{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## General Information" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook is used to train a simple neural network model to predict the chemistry in the barite benchmark (50x50 grid). The training data is stored in the repository using **git large file storage** and can be downloaded after the installation of git lfs using the `git lfs pull` command.\n", "\n", "It is then recommended to create a Python environment using miniconda. The necessary dependencies are contained in `environment.yml` and can be installed using `conda env create -f environment.yml`.\n", "\n", "The data set is divided into a design and result part and consists of the iterations of a reference simulation. The design part of the data set contains the chemical concentrations at time $t$ and the result part at time $t+1$, which are to be learned by the model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup Libraries" ] }, { "cell_type": "code", "execution_count": 197, "metadata": {}, "outputs": [], "source": [ "import keras\n", "import tensorflow as tf\n", "import h5py\n", "import numpy as np\n", "import pandas as pd\n", "import time\n", "import sklearn.model_selection as sk\n", "import matplotlib.pyplot as plt\n", "from sklearn.cluster import KMeans\n", "from sklearn.pipeline import Pipeline, make_pipeline\n", "from sklearn.preprocessing import StandardScaler, MinMaxScaler\n", "from imblearn.over_sampling import SMOTE\n", "from imblearn.under_sampling import RandomUnderSampler\n", "from imblearn.over_sampling import RandomOverSampler\n", "from collections import Counter\n", "import os\n", "from preprocessing import *\n", "from sklearn import set_config\n", "from importlib import reload\n", "set_config(transform_output = \"pandas\")" ] }, { "cell_type": "code", "execution_count": 136, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The 
autoreload extension is already loaded. To reload it, use:\n", " %reload_ext autoreload\n" ] } ], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Define parameters" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Layer precision and hidden-layer activation.\n",
"dtype = \"float32\"\n",
"activation = \"relu\"\n",
"\n",
"# Seed the Python, NumPy and backend random generators so that a fresh\n",
"# Restart Kernel -> Run All starts from the same random state.\n",
"seed = 42\n",
"keras.utils.set_random_seed(seed)\n",
"\n",
"lr = 0.001\n",
"batch_size = 512\n",
"epochs = 50 # default 400 epochs\n",
"\n",
"# Step-wise decay: the learning rate is multiplied by 0.9 every 2000 optimizer steps.\n",
"lr_schedule = keras.optimizers.schedules.ExponentialDecay(\n",
"    initial_learning_rate=lr,\n",
"    decay_steps=2000,\n",
"    decay_rate=0.9,\n",
"    staircase=True,\n",
")\n",
"\n",
"# One optimizer instance per model.\n",
"optimizer_simple = keras.optimizers.Adam(learning_rate=lr_schedule)\n",
"optimizer_large = keras.optimizers.Adam(learning_rate=lr_schedule)\n",
"optimizer_paper = keras.optimizers.Adam(learning_rate=lr_schedule)\n",
"\n",
"loss = keras.losses.Huber()\n",
"\n",
"sample_fraction = 0.8"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Setup the model" ] },
{ "cell_type": "code", "execution_count": 150, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Model: \"sequential_7\"\n",
"\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential_7\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_27 (Dense) │ (None, 128) │ 1,280 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_28 (Dense) │ (None, 128) │ 16,512 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_29 (Dense) │ (None, 9) │ 1,161 │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_27 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m1,280\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_28 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m16,512\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_29 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m1,161\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Total params: 18,953 (74.04 KB)\n", "\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m18,953\u001b[0m (74.04 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Trainable params: 18,953 (74.04 KB)\n", "\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m18,953\u001b[0m (74.04 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Non-trainable params: 0 (0.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"# Small model: 9 inputs -> 128 -> 128 -> 9 outputs.\n",
"# dtype and activation are read from the parameter cell instead of being\n",
"# repeated as string literals, so changing them there affects this model.\n",
"model_simple = keras.Sequential(\n",
"    [\n",
"        keras.Input(shape=(9,), dtype=dtype),\n",
"        keras.layers.Dense(units=128, activation=activation, dtype=dtype),\n",
"        keras.layers.Dense(units=128, activation=activation, dtype=dtype),\n",
"        keras.layers.Dense(units=9, dtype=dtype),\n",
"    ]\n",
")\n",
"\n",
"model_simple.compile(optimizer=optimizer_simple, loss=loss)\n",
"model_simple.summary()"
] },
{ "cell_type": "code", "execution_count": 161, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Model: \"sequential_8\"\n",
"\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential_8\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_30 (Dense) │ (None, 512) │ 5,120 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_31 (Dense) │ (None, 1024) │ 525,312 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_32 (Dense) │ (None, 512) │ 524,800 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_33 (Dense) │ (None, 9) │ 4,617 │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_30 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m5,120\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_31 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1024\u001b[0m) │ \u001b[38;5;34m525,312\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_32 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m524,800\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_33 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m4,617\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Total params: 1,059,849 (4.04 MB)\n", "\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,059,849\u001b[0m (4.04 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Trainable params: 1,059,849 (4.04 MB)\n", "\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,059,849\u001b[0m (4.04 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Non-trainable params: 0 (0.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"# Large model: 9 inputs -> 512 -> 1024 -> 512 -> 9 outputs.\n",
"# The hidden-layer activation is read from the parameter cell.\n",
"model_large = keras.Sequential(\n",
"    [\n",
"        keras.layers.Input(shape=(9,), dtype=dtype),\n",
"        keras.layers.Dense(512, activation=activation, dtype=dtype),\n",
"        keras.layers.Dense(1024, activation=activation, dtype=dtype),\n",
"        keras.layers.Dense(512, activation=activation, dtype=dtype),\n",
"        keras.layers.Dense(9, dtype=dtype),\n",
"    ]\n",
")\n",
"\n",
"model_large.compile(optimizer=optimizer_large, loss=loss)\n",
"model_large.summary()\n"
] },
{ "cell_type": "code", "execution_count": 140, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Model: \"sequential_5\"\n",
"\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential_5\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_19 (Dense) │ (None, 128) │ 1,664 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_20 (Dense) │ (None, 256) │ 33,024 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_21 (Dense) │ (None, 512) │ 131,584 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_22 (Dense) │ (None, 256) │ 131,328 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_23 (Dense) │ (None, 12) │ 3,084 │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_19 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m1,664\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_20 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m33,024\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_21 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m131,584\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_22 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m131,328\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_23 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m3,084\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Total params: 300,684 (1.15 MB)\n", "\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m300,684\u001b[0m (1.15 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Trainable params: 300,684 (1.15 MB)\n", "\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m300,684\u001b[0m (1.15 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Non-trainable params: 0 (0.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"# model from paper\n",
"# (see https://doi.org/10.1007/s11242-022-01779-3 model for the complex chemistry)\n",
"# 12 inputs -> 128 -> 256 -> 512 -> 256 -> 12 outputs; the hidden-layer\n",
"# activation is read from the parameter cell.\n",
"model_paper = keras.Sequential(\n",
"    [\n",
"        keras.layers.Input(shape=(12,), dtype=dtype),\n",
"        keras.layers.Dense(128, activation=activation, dtype=dtype),\n",
"        keras.layers.Dense(256, activation=activation, dtype=dtype),\n",
"        keras.layers.Dense(512, activation=activation, dtype=dtype),\n",
"        keras.layers.Dense(256, activation=activation, dtype=dtype),\n",
"        keras.layers.Dense(12, dtype=dtype),\n",
"    ]\n",
")\n",
"\n",
"model_paper.compile(optimizer=optimizer_paper, loss=loss)\n",
"model_paper.summary()"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Define transformer functions" ] },
{ "cell_type": "code", "execution_count": 141, "metadata": {}, "outputs": [], "source": [
"def Safelog(val):\n",
"    \"\"\"Sign-preserving log10 of a single number.\n",
"\n",
"    Returns log10(val) for val > 0, -log10(-val) for val < 0 and 0 for\n",
"    anything else. The comparisons need a scalar truth value, so arrays or\n",
"    Series must be processed element by element.\n",
"\n",
"    Safeexp undoes this mapping only for inputs with 0 < |val| < 1.\n",
"    \"\"\"\n",
"    if val > 0:\n",
"        return np.log10(val)\n",
"    if val < 0:\n",
"        return -np.log10(-val)\n",
"    return 0\n",
"\n",
"\n",
"def Safeexp(val):\n",
"    \"\"\"Map a Safelog output back to the original number.\n",
"\n",
"    Returns -(10 ** -val) for val > 0, 10 ** val for val < 0 and 0 for\n",
"    anything else. This is the inverse of Safelog only when the original\n",
"    magnitude was below 1, because Safelog then flips the sign of the result.\n",
"    \"\"\"\n",
"    if val > 0:\n",
"        # Unary minus binds weaker than **, so this is -(10 ** -val).\n",
"        return -(10 ** -val)\n",
"    if val < 0:\n",
"        return 10 ** val\n",
"    return 0"
] },
{ "cell_type": "code", "execution_count": 142, "metadata": {}, "outputs": [], "source": [ "# ? 
Why does the charge use a different logarithm than the other species?\n",
"\n",
"# Forward transform applied per column before scaling.\n",
"func_dict_in = {\n",
"    \"H\" : np.log1p,\n",
"    \"O\" : np.log1p,\n",
"    \"Charge\" : Safelog,\n",
"    \"H_0_\" : np.log1p,\n",
"    \"O_0_\" : np.log1p,\n",
"    \"Ba\" : np.log1p,\n",
"    \"Cl\" : np.log1p,\n",
"    \"S_2_\" : np.log1p,\n",
"    \"S_6_\" : np.log1p,\n",
"    \"Sr\" : np.log1p,\n",
"    \"Barite\" : np.log1p,\n",
"    \"Celestite\" : np.log1p,\n",
"}\n",
"\n",
"# Matching inverse transform for every column of func_dict_in.\n",
"func_dict_out = {\n",
"    \"H\" : np.expm1,\n",
"    \"O\" : np.expm1,\n",
"    \"Charge\" : Safeexp,\n",
"    \"H_0_\" : np.expm1,\n",
"    \"O_0_\" : np.expm1,\n",
"    \"Ba\" : np.expm1,\n",
"    \"Cl\" : np.expm1,\n",
"    \"S_2_\" : np.expm1,\n",
"    \"S_6_\" : np.expm1,\n",
"    \"Sr\" : np.expm1,\n",
"    \"Barite\" : np.expm1,\n",
"    \"Celestite\" : np.expm1,\n",
"}\n"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Read data from `.h5` file and convert it to a `pandas.DataFrame`" ] },
{ "cell_type": "code", "execution_count": 143, "metadata": {}, "outputs": [], "source": [
"# The file is looked up relative to the current working directory\n",
"# (fetch it with `git lfs pull`).\n",
"# Open read-only inside a context manager so the handle is closed even if\n",
"# building the DataFrames fails.\n",
"with h5py.File(\"Barite_50_Data_training.h5\", \"r\") as data_file:\n",
"    design = data_file[\"design\"]\n",
"    results = data_file[\"result\"]\n",
"\n",
"    # The stored arrays are transposed before being wrapped in a DataFrame.\n",
"    df_design = pd.DataFrame(np.array(design[\"data\"]).transpose(), columns = np.array(design[\"names\"].asstr()))\n",
"    df_results = pd.DataFrame(np.array(results[\"data\"]).transpose(), columns = np.array(results[\"names\"].asstr()))"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Preprocess Data\n", "\n", "The data are preprocessed in the following way:\n", "\n", "1. Label data points in the `design` dataset with `reactive` and `non-reactive` labels using k-means clustering.\n", "2. Transform the `design` and `results` data sets into log-scaled data.\n", "3. Split data into training and test sets.\n", "4. 
Learn scaler on training data for `design` and `results` together (option `global`) or individually (option `individual`).\n", "5. Transform training and test data.\n", "6. Split training data into training and validation dataset." ] },
{ "cell_type": "code", "execution_count": 182, "metadata": {}, "outputs": [], "source": [ "species_columns = [\"H\", \"O\", \"Charge\", \"Ba\", \"Cl\", \"S_6_\", \"Sr\", \"Barite\", \"Celestite\"]" ] },
{ "cell_type": "code", "execution_count": 183, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/signer/bin/miniconda3/envs/training/lib/python3.11/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (2). Possibly due to duplicate points in X.\n", " return fit_method(estimator, *args, **kwargs)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Amount class 0 before: 0.9879169719169719\n", "Amount class 1 before: 0.012083028083028084\n" ] } ], "source": [
"X_train, X_val, X_test, y_train, y_val, y_test, scaler_X, scaler_y = preprocessing_training(\n",
"    df_design[species_columns],\n",
"    df_results[species_columns],\n",
"    func_dict_in,\n",
"    func_dict_out,\n",
"    \"off\",\n",
"    'global',\n",
"    0.1,\n",
")"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Custom Loss function" ] },
{ "cell_type": "code", "execution_count": 146, "metadata": {}, "outputs": [], "source": [
"def custom_loss_H20(df_design_log, df_result_log, data_min_log, data_max_log, func_dict_out, postprocess):\n",
"    \"\"\"Huber loss plus a penalty on deviations of the H/O ratio from 2.\n",
"\n",
"    df_design_log and df_result_log are passed to the Huber loss in that\n",
"    order, i.e. as (y_true, y_pred). postprocess is called as\n",
"    postprocess(df_result_log, func_dict_out, data_min_log, data_max_log) and\n",
"    its result must support ['H'] and ['O'] column access.\n",
"\n",
"    NOTE(review): access by column name presumably requires a DataFrame\n",
"    rather than a tensor -- confirm before using this inside model.fit.\n",
"    \"\"\"\n",
"    df_result = postprocess(df_result_log, func_dict_out, data_min_log, data_max_log)\n",
"    # keras.losses.Huber is a class: it has to be instantiated and called.\n",
"    # Adding the class object itself to a number raises a TypeError.\n",
"    huber = keras.losses.Huber()(df_design_log, df_result_log)\n",
"    ratio_penalty = np.sum(((df_result['H'] / df_result['O']) - 2)**2)\n",
"    return huber + ratio_penalty\n",
"\n",
"\n",
"def loss_wrapper(data_min_log, data_max_log, func_dict_out, postprocess):\n",
"    \"\"\"Bind the extra arguments of custom_loss_H20 and return a two-argument loss.\"\"\"\n",
"    def loss(df_design_log, df_result_log):\n",
"        return custom_loss_H20(df_design_log, df_result_log, data_min_log, data_max_log, func_dict_out, postprocess)\n",
"    return loss"
] },
{ "cell_type": "code", "execution_count": 
199, "metadata": {}, "outputs": [], "source": [
"def custom_loss(scaler_X, scaler_y, FuncTransform, dict_in, dict_out, columns, mass_balance_weight=0.1):\n",
"    \"\"\"Build a loss: Huber term plus weighted Ba and Sr mass-balance penalties.\n",
"\n",
"    Both arguments of the returned loss are un-scaled with scaler_y and then\n",
"    passed through FuncTransform(dict_in, dict_out).inverse_transform. The\n",
"    penalties are |(Ba + Barite)_pred - (Ba + Barite)_true| and the same for\n",
"    Sr + Celestite, each multiplied by mass_balance_weight.\n",
"\n",
"    scaler_X is accepted for call compatibility but is not used.\n",
"\n",
"    NOTE(review): the scaler and DataFrame calls presumably need concrete\n",
"    arrays; confirm that this loss runs and yields gradients in model.fit.\n",
"    \"\"\"\n",
"    transform = FuncTransform(dict_in, dict_out)\n",
"    huber = keras.losses.Huber()\n",
"\n",
"    def to_physical(scaled):\n",
"        # Targets and model outputs are both in target space, so both are\n",
"        # un-scaled with scaler_y rather than the input scaler.\n",
"        frame = pd.DataFrame(scaler_y.inverse_transform(scaled), columns = columns)\n",
"        return transform.inverse_transform(frame)\n",
"\n",
"    def loss(results, predicted):\n",
"        predicted = to_physical(predicted)\n",
"        results = to_physical(results)\n",
"\n",
"        dBa = tf.keras.backend.abs((predicted[\"Ba\"] + predicted[\"Barite\"]) - (results[\"Ba\"] + results[\"Barite\"]))\n",
"        dSr = tf.keras.backend.abs((predicted[\"Sr\"] + predicted[\"Celestite\"]) - (results[\"Sr\"] + results[\"Celestite\"]))\n",
"        # The Huber class lives in keras.losses and is called on an instance.\n",
"        total_loss = huber(results, predicted) + mass_balance_weight * dBa + mass_balance_weight * dSr\n",
"\n",
"        return total_loss\n",
"\n",
"    return loss"
] },
{ "cell_type": "code", "execution_count": 200, "metadata": {}, "outputs": [], "source": [ "model_simple.compile(optimizer=optimizer_simple, loss=custom_loss(scaler_X, scaler_y, FuncTransform, func_dict_in, func_dict_out, species_columns))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Train the model" ] },
{ "cell_type": "code", "execution_count": 188, "metadata": {}, "outputs": [], "source": [
"def model_training(model, n_epochs=None):\n",
"    \"\"\"Fit model on the training split, print the elapsed time, return the history.\n",
"\n",
"    The \"Class\" column is removed from all four frames before fitting.\n",
"    n_epochs defaults to the notebook-level epochs setting; training may stop\n",
"    earlier because EarlyStopping watches the training loss with patience 3.\n",
"    \"\"\"\n",
"    if n_epochs is None:\n",
"        # Fall back to the value configured in the parameter cell.\n",
"        n_epochs = epochs\n",
"\n",
"    def without_class(frame):\n",
"        return frame.iloc[:, frame.columns != \"Class\"]\n",
"\n",
"    callback = keras.callbacks.EarlyStopping(monitor='loss', patience=3)\n",
"\n",
"    # perf_counter is a monotonic clock intended for measuring durations.\n",
"    start = time.perf_counter()\n",
"    history = model.fit(\n",
"        without_class(X_train),\n",
"        without_class(y_train),\n",
"        batch_size = batch_size,\n",
"        epochs = n_epochs,\n",
"        validation_data = (without_class(X_val), without_class(y_val)),\n",
"        callbacks = [callback],\n",
"    )\n",
"    end = time.perf_counter()\n",
"\n",
"    print(\"Training took {} seconds\".format(end - start))\n",
"    return history"
] },
{ "cell_type": "code", "execution_count": 203, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| \n", " | H | \n", "O | \n", "Charge | \n", "H_0_ | \n", "O_0_ | \n", "Ba | \n", "Cl | \n", "S_2_ | \n", "S_6_ | \n", "Sr | \n", "Barite | \n", "Celestite | \n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", "111.012434 | \n", "55.506219 | \n", "-7.748371e-09 | \n", "1.541375e-13 | \n", "0.000000e+00 | \n", "1.916106e-02 | \n", "4.082138e-02 | \n", "1.440295e-20 | \n", "6.403132e-07 | \n", "0.001250 | \n", "0.002250 | \n", "0.999244 | \n", "
| 1 | \n", "111.012434 | \n", "55.506220 | \n", "-4.672647e-09 | \n", "0.000000e+00 | \n", "6.197818e-14 | \n", "9.637903e-03 | \n", "2.189133e-02 | \n", "0.000000e+00 | \n", "7.449873e-07 | \n", "0.001309 | \n", "0.002308 | \n", "0.999244 | \n", "
| 2 | \n", "111.012434 | \n", "55.508129 | \n", "-1.325997e-09 | \n", "1.416456e-13 | \n", "0.000000e+00 | \n", "5.839037e-07 | \n", "5.887491e-04 | \n", "2.981540e-17 | \n", "4.781732e-04 | \n", "0.000772 | \n", "0.001294 | \n", "0.999847 | \n", "
| 3 | \n", "111.012434 | \n", "55.508676 | \n", "-1.223549e-09 | \n", "4.449984e-14 | \n", "0.000000e+00 | \n", "4.536952e-07 | \n", "2.183551e-05 | \n", "3.803886e-19 | \n", "6.148996e-04 | \n", "0.000625 | \n", "0.001010 | \n", "0.999995 | \n", "
| 4 | \n", "111.012434 | \n", "55.508699 | \n", "-1.216518e-09 | \n", "0.000000e+00 | \n", "4.098861e-14 | \n", "4.496372e-07 | \n", "8.098352e-07 | \n", "0.000000e+00 | \n", "6.205560e-04 | \n", "0.000621 | \n", "0.001000 | \n", "1.000000 | \n", "
| ... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
| 2502495 | \n", "111.012434 | \n", "55.507488 | \n", "3.573897e-09 | \n", "0.000000e+00 | \n", "1.374167e-10 | \n", "9.953042e-07 | \n", "2.266555e-03 | \n", "0.000000e+00 | \n", "3.178142e-04 | \n", "0.001450 | \n", "0.001000 | \n", "1.000014 | \n", "
| 2502496 | \n", "111.012434 | \n", "55.507501 | \n", "3.494199e-09 | \n", "0.000000e+00 | \n", "1.378928e-10 | \n", "9.814307e-07 | \n", "2.217997e-03 | \n", "0.000000e+00 | \n", "3.210423e-04 | \n", "0.001429 | \n", "0.001000 | \n", "1.000010 | \n", "
| 2502497 | \n", "111.012434 | \n", "55.507512 | \n", "3.429947e-09 | \n", "0.000000e+00 | \n", "1.376072e-10 | \n", "9.704342e-07 | \n", "2.179066e-03 | \n", "0.000000e+00 | \n", "3.236905e-04 | \n", "0.001412 | \n", "0.001000 | \n", "1.000006 | \n", "
| 2502498 | \n", "111.012434 | \n", "55.507520 | \n", "3.381818e-09 | \n", "0.000000e+00 | \n", "1.368903e-10 | \n", "9.632999e-07 | \n", "2.149820e-03 | \n", "0.000000e+00 | \n", "3.257170e-04 | \n", "0.001400 | \n", "0.001000 | \n", "1.000004 | \n", "
| 2502499 | \n", "111.012434 | \n", "55.507525 | \n", "3.349044e-09 | \n", "0.000000e+00 | \n", "1.378174e-10 | \n", "9.563975e-07 | \n", "2.129912e-03 | \n", "0.000000e+00 | \n", "3.271123e-04 | \n", "0.001391 | \n", "0.001000 | \n", "1.000001 | \n", "
2502500 rows × 12 columns
\n", "| \n", " | H | \n", "O | \n", "Charge | \n", "H_0_ | \n", "O_0_ | \n", "Ba | \n", "Cl | \n", "S_2_ | \n", "S_6_ | \n", "Sr | \n", "Barite | \n", "Celestite | \n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", "111.012434 | \n", "55.508192 | \n", "-7.779554e-09 | \n", "2.697041e-26 | \n", "2.210590e-15 | \n", "2.041069e-02 | \n", "4.082138e-02 | \n", "0.000000e+00 | \n", "0.000494 | \n", "0.000494 | \n", "0.001 | \n", "1.000000 | \n", "
| 1 | \n", "111.012434 | \n", "55.508427 | \n", "-4.736083e-09 | \n", "1.446346e-26 | \n", "2.473481e-15 | \n", "1.094567e-02 | \n", "2.189133e-02 | \n", "0.000000e+00 | \n", "0.000553 | \n", "0.000553 | \n", "0.001 | \n", "1.000000 | \n", "
| 2 | \n", "111.012434 | \n", "55.508691 | \n", "-1.311169e-09 | \n", "3.889826e-28 | \n", "2.769320e-15 | \n", "2.943745e-04 | \n", "5.887491e-04 | \n", "0.000000e+00 | \n", "0.000619 | \n", "0.000619 | \n", "0.001 | \n", "1.000000 | \n", "
| 3 | \n", "111.012434 | \n", "55.508698 | \n", "-1.220023e-09 | \n", "1.442658e-29 | \n", "2.777193e-15 | \n", "1.091776e-05 | \n", "2.183551e-05 | \n", "0.000000e+00 | \n", "0.000620 | \n", "0.000620 | \n", "0.001 | \n", "1.000000 | \n", "
| 4 | \n", "111.012434 | \n", "55.508699 | \n", "-1.216643e-09 | \n", "5.350528e-31 | \n", "2.777485e-15 | \n", "4.049176e-07 | \n", "8.098352e-07 | \n", "0.000000e+00 | \n", "0.000620 | \n", "0.000620 | \n", "0.001 | \n", "1.000000 | \n", "
| ... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
| 2502495 | \n", "111.012434 | \n", "55.507488 | \n", "3.573728e-09 | \n", "5.424062e-145 | \n", "1.375204e-10 | \n", "9.953520e-07 | \n", "2.266555e-03 | \n", "5.509534e-149 | \n", "0.000318 | \n", "0.001450 | \n", "0.001 | \n", "1.000014 | \n", "
| 2502496 | \n", "111.012434 | \n", "55.507501 | \n", "3.494007e-09 | \n", "2.011675e-146 | \n", "1.377139e-10 | \n", "9.817216e-07 | \n", "2.217997e-03 | \n", "2.043375e-150 | \n", "0.000321 | \n", "0.001429 | \n", "0.001 | \n", "1.000010 | \n", "
| 2502497 | \n", "111.012434 | \n", "55.507512 | \n", "3.429764e-09 | \n", "7.460897e-148 | \n", "1.377819e-10 | \n", "9.706451e-07 | \n", "2.179066e-03 | \n", "7.578467e-152 | \n", "0.000324 | \n", "0.001412 | \n", "0.001 | \n", "1.000006 | \n", "
| 2502498 | \n", "111.012434 | \n", "55.507520 | \n", "3.381745e-09 | \n", "2.767237e-149 | \n", "1.371144e-10 | \n", "9.621074e-07 | \n", "2.149820e-03 | \n", "2.810844e-153 | \n", "0.000326 | \n", "0.001400 | \n", "0.001 | \n", "1.000004 | \n", "
| 2502499 | \n", "111.012434 | \n", "55.507525 | \n", "3.348864e-09 | \n", "5.321610e-151 | \n", "1.376026e-10 | \n", "9.564401e-07 | \n", "2.129912e-03 | \n", "5.405468e-155 | \n", "0.000327 | \n", "0.001391 | \n", "0.001 | \n", "1.000001 | \n", "
2502500 rows × 12 columns
\n", "| \n", " | H | \n", "O | \n", "Charge | \n", "H_0_ | \n", "O_0_ | \n", "Ba | \n", "Cl | \n", "S_2_ | \n", "S_6_ | \n", "Sr | \n", "Barite | \n", "Celestite | \n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", "111.012434 | \n", "55.508700 | \n", "-1.216415e-09 | \n", "0.0 | \n", "2.217711e-13 | \n", "4.495355e-07 | \n", "1.532249e-12 | \n", "0.0 | \n", "0.000621 | \n", "0.000620 | \n", "0.001000 | \n", "1.000000 | \n", "
| 1 | \n", "111.012434 | \n", "55.508700 | \n", "-1.222504e-09 | \n", "0.0 | \n", "1.312902e-12 | \n", "4.500359e-07 | \n", "1.044603e-08 | \n", "0.0 | \n", "0.000621 | \n", "0.000620 | \n", "0.001000 | \n", "1.000000 | \n", "
| 2 | \n", "111.012434 | \n", "55.508699 | \n", "-1.220407e-09 | \n", "0.0 | \n", "1.614098e-12 | \n", "4.500563e-07 | \n", "4.907802e-07 | \n", "0.0 | \n", "0.000621 | \n", "0.000620 | \n", "0.001000 | \n", "1.000000 | \n", "
| 3 | \n", "111.012434 | \n", "55.508695 | \n", "-1.216831e-09 | \n", "0.0 | \n", "2.293739e-12 | \n", "4.504482e-07 | \n", "4.772370e-06 | \n", "0.0 | \n", "0.000620 | \n", "0.000622 | \n", "0.001000 | \n", "1.000000 | \n", "
| 4 | \n", "111.012434 | \n", "55.508679 | \n", "-1.216842e-09 | \n", "0.0 | \n", "2.641545e-12 | \n", "4.534070e-07 | \n", "2.200220e-05 | \n", "0.0 | \n", "0.000616 | \n", "0.000626 | \n", "0.001000 | \n", "1.000000 | \n", "
| ... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
| 995 | \n", "111.012434 | \n", "55.506410 | \n", "2.170844e-08 | \n", "0.0 | \n", "4.777415e-11 | \n", "1.550089e-04 | \n", "9.784038e-02 | \n", "0.0 | \n", "0.000048 | \n", "0.048814 | \n", "0.149476 | \n", "0.846327 | \n", "
| 996 | \n", "111.012434 | \n", "55.506410 | \n", "2.166504e-08 | \n", "0.0 | \n", "4.808612e-11 | \n", "1.559012e-04 | \n", "9.785750e-02 | \n", "0.0 | \n", "0.000048 | \n", "0.048821 | \n", "0.151708 | \n", "0.844093 | \n", "
| 997 | \n", "111.012434 | \n", "55.506409 | \n", "2.162167e-08 | \n", "0.0 | \n", "4.811205e-11 | \n", "1.567226e-04 | \n", "9.787459e-02 | \n", "0.0 | \n", "0.000048 | \n", "0.048829 | \n", "0.153945 | \n", "0.841856 | \n", "
| 998 | \n", "111.012434 | \n", "55.506409 | \n", "2.157995e-08 | \n", "0.0 | \n", "4.815004e-11 | \n", "1.574812e-04 | \n", "9.789167e-02 | \n", "0.0 | \n", "0.000048 | \n", "0.048836 | \n", "0.156185 | \n", "0.839614 | \n", "
| 999 | \n", "111.012434 | \n", "55.506409 | \n", "2.153938e-08 | \n", "0.0 | \n", "4.815067e-11 | \n", "1.581835e-04 | \n", "9.790872e-02 | \n", "0.0 | \n", "0.000048 | \n", "0.048844 | \n", "0.158428 | \n", "0.837370 | \n", "
1000 rows × 12 columns
\n", "| \n", " | H | \n", "O | \n", "Charge | \n", "H_0_ | \n", "O_0_ | \n", "Ba | \n", "Cl | \n", "S_2_ | \n", "S_6_ | \n", "Sr | \n", "Barite | \n", "Celestite | \n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", "111.012428 | \n", "55.508682 | \n", "-1.193141e-09 | \n", "3.397941e-15 | \n", "2.128961e-13 | \n", "-0.000012 | \n", "0.000021 | \n", "-1.191799e-17 | \n", "0.000620 | \n", "0.000630 | \n", "0.000985 | \n", "1.000231 | \n", "
| 1 | \n", "111.012428 | \n", "55.508682 | \n", "-1.202185e-09 | \n", "2.918474e-15 | \n", "1.019832e-12 | \n", "-0.000016 | \n", "0.000008 | \n", "-7.920033e-18 | \n", "0.000620 | \n", "0.000630 | \n", "0.000946 | \n", "1.000121 | \n", "
| 2 | \n", "111.012428 | \n", "55.508682 | \n", "-1.203471e-09 | \n", "1.785468e-15 | \n", "2.398433e-12 | \n", "-0.000015 | \n", "-0.000013 | \n", "-9.158671e-18 | \n", "0.000620 | \n", "0.000627 | \n", "0.000913 | \n", "0.999977 | \n", "
| 3 | \n", "111.012428 | \n", "55.508682 | \n", "-1.199235e-09 | \n", "1.746077e-15 | \n", "2.357316e-12 | \n", "-0.000016 | \n", "-0.000011 | \n", "-9.246642e-18 | \n", "0.000619 | \n", "0.000629 | \n", "0.000916 | \n", "0.999936 | \n", "
| 4 | \n", "111.012428 | \n", "55.508682 | \n", "-1.197043e-09 | \n", "1.533956e-15 | \n", "2.670053e-12 | \n", "-0.000019 | \n", "0.000003 | \n", "-9.144569e-18 | \n", "0.000615 | \n", "0.000635 | \n", "0.000943 | \n", "0.999769 | \n", "
| ... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
| 995 | \n", "111.012428 | \n", "55.506416 | \n", "2.189995e-08 | \n", "-3.792428e-15 | \n", "4.785985e-11 | \n", "0.000279 | \n", "0.097642 | \n", "1.182626e-16 | \n", "0.000051 | \n", "0.048738 | \n", "0.149016 | \n", "0.844585 | \n", "
| 996 | \n", "111.012428 | \n", "55.506416 | \n", "2.188781e-08 | \n", "-3.910681e-15 | \n", "4.822730e-11 | \n", "0.000279 | \n", "0.097665 | \n", "1.169111e-16 | \n", "0.000051 | \n", "0.048751 | \n", "0.151241 | \n", "0.842272 | \n", "
| 997 | \n", "111.012428 | \n", "55.506416 | \n", "2.184875e-08 | \n", "-3.749360e-15 | \n", "4.824932e-11 | \n", "0.000279 | \n", "0.097687 | \n", "1.146696e-16 | \n", "0.000051 | \n", "0.048761 | \n", "0.153477 | \n", "0.840015 | \n", "
| 998 | \n", "111.012428 | \n", "55.506416 | \n", "2.180415e-08 | \n", "-3.500642e-15 | \n", "4.816323e-11 | \n", "0.000279 | \n", "0.097710 | \n", "1.119656e-16 | \n", "0.000051 | \n", "0.048770 | \n", "0.155720 | \n", "0.837773 | \n", "
| 999 | \n", "111.012428 | \n", "55.506416 | \n", "2.177183e-08 | \n", "-3.375572e-15 | \n", "4.822685e-11 | \n", "0.000279 | \n", "0.097732 | \n", "1.095921e-16 | \n", "0.000050 | \n", "0.048780 | \n", "0.157962 | \n", "0.835504 | \n", "
1000 rows × 12 columns
\n", "