model-training/POET_Training.ipynb
2025-02-11 17:17:43 +01:00

1142 lines
139 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## General Information"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook trains a simple neural network to predict the chemistry of the barite benchmark (50x50 grid). The training data is stored in the repository with **Git Large File Storage (Git LFS)**; once Git LFS is installed, the data can be downloaded with the `git lfs pull` command.\n",
"\n",
"It is then recommended to create a Python environment using miniconda. The necessary dependencies are contained in `environment.yml` and can be installed using `conda env create -f environment.yml`.\n",
"\n",
"The data set consists of the iterations of a reference simulation and is divided into a design part and a result part. The design part contains the chemical concentrations at time $t$; the result part contains the concentrations at time $t+1$, which the model learns to predict."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup Libraries"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2025-02-11 15:40:37.253319: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
"2025-02-11 15:40:37.275142: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running Keras in version 3.6.0\n"
]
}
],
"source": [
"import keras\n",
"import h5py\n",
"import numpy as np\n",
"import pandas as pd\n",
"import time\n",
"import sklearn.model_selection as sk\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.cluster import KMeans\n",
"from sklearn.pipeline import Pipeline, make_pipeline\n",
"from sklearn.preprocessing import StandardScaler, MinMaxScaler\n",
"from imblearn.over_sampling import SMOTE\n",
"from imblearn.under_sampling import RandomUnderSampler\n",
"from imblearn.over_sampling import RandomOverSampler\n",
"from collections import Counter\n",
"import os\n",
"from preprocessing import *\n",
"from sklearn import set_config\n",
"set_config(transform_output = \"pandas\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define parameters"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [],
"source": [
"# Numeric precision and activation name for the network layers.\n",
"dtype = \"float32\"\n",
"activation = \"relu\"\n",
"\n",
"# Optimiser settings.\n",
"lr = 0.001\n",
"batch_size = 512\n",
"epochs = 50 # default 400 epochs\n",
"\n",
"# Staircase decay: the learning rate is multiplied by 0.9 after every 2000 optimiser steps.\n",
"lr_schedule = keras.optimizers.schedules.ExponentialDecay(\n",
"    initial_learning_rate=lr, decay_steps=2000, decay_rate=0.9, staircase=True\n",
")\n",
"\n",
"# One separate Adam instance per model, all driven by the same schedule.\n",
"optimizer_simple, optimizer_large, optimizer_paper = (\n",
"    keras.optimizers.Adam(learning_rate=lr_schedule) for _ in range(3)\n",
")\n",
"\n",
"# Loss handed to every model.compile call below.\n",
"loss = keras.losses.MeanSquaredError()\n",
"\n",
"sample_fraction = 0.8"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup the model"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential_4\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential_4\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_17 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,664</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_18 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">16,512</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_19 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">12</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,548</span> │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_17 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m1,664\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_18 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m16,512\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_19 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m1,548\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">19,724</span> (77.05 KB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m19,724\u001b[0m (77.05 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">19,724</span> (77.05 KB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m19,724\u001b[0m (77.05 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# small model\n",
"# Layer dtype and activation are taken from the parameter cell instead of being\n",
"# repeated as string literals, so changing the settings there affects this model too.\n",
"model_simple = keras.Sequential(\n",
"    [\n",
"        keras.Input(shape=(12,), dtype=dtype),\n",
"        keras.layers.Dense(units=128, activation=activation, dtype=dtype),\n",
"        keras.layers.Dense(units=128, activation=activation, dtype=dtype),\n",
"        # linear output layer with 12 units, matching the 12 inputs\n",
"        keras.layers.Dense(units=12, dtype=dtype),\n",
"    ]\n",
")\n",
"\n",
"model_simple.compile(optimizer=optimizer_simple, loss=loss)\n",
"model_simple.summary()"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential_5\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential_5\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_20 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">512</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">6,656</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_21 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1024</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">525,312</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_22 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">512</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">524,800</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_23 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">12</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">6,156</span> │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_20 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m6,656\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_21 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1024\u001b[0m) │ \u001b[38;5;34m525,312\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_22 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m524,800\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_23 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m6,156\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,062,924</span> (4.05 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,062,924\u001b[0m (4.05 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,062,924</span> (4.05 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,062,924\u001b[0m (4.05 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# large model\n",
"# Hidden layers use the `activation` setting from the parameter cell rather than a literal.\n",
"model_large = keras.Sequential(\n",
"    [keras.layers.Input(shape=(12,), dtype=dtype)]\n",
"    + [keras.layers.Dense(width, activation=activation, dtype=dtype) for width in (512, 1024, 512)]\n",
"    # linear output layer with 12 units, matching the 12 inputs\n",
"    + [keras.layers.Dense(12, dtype=dtype)]\n",
")\n",
"\n",
"model_large.compile(optimizer=optimizer_large, loss=loss)\n",
"model_large.summary()\n"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential_6\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential_6\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_24 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,664</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_25 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">33,024</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_26 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">512</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">131,584</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_27 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">131,328</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_28 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">12</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">3,084</span> │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_24 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m1,664\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_25 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m33,024\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_26 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m131,584\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_27 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m131,328\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_28 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m3,084\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">300,684</span> (1.15 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m300,684\u001b[0m (1.15 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">300,684</span> (1.15 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m300,684\u001b[0m (1.15 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# model from paper\n",
"# (see https://doi.org/10.1007/s11242-022-01779-3 model for the complex chemistry)\n",
"# Hidden layers use the `activation` setting from the parameter cell rather than a literal.\n",
"model_paper = keras.Sequential(\n",
"    [keras.layers.Input(shape=(12,), dtype=dtype)]\n",
"    + [keras.layers.Dense(width, activation=activation, dtype=dtype) for width in (128, 256, 512, 256)]\n",
"    # linear output layer with 12 units, matching the 12 inputs\n",
"    + [keras.layers.Dense(12, dtype=dtype)]\n",
")\n",
"\n",
"model_paper.compile(optimizer=optimizer_paper, loss=loss)\n",
"model_paper.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define transformer functions"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [],
"source": [
"def Safelog(val):\n",
"    \"\"\"Sign-aware base-10 logarithm of a scalar.\n",
"\n",
"    Returns ``log10(val)`` for positive ``val``, ``-log10(-val)`` for negative\n",
"    ``val`` and ``0`` for anything else (zero, or a value that compares neither\n",
"    greater nor less than zero).\n",
"    \"\"\"\n",
"    if val < 0:\n",
"        # mirror the logarithm of the magnitude\n",
"        return -np.log10(-val)\n",
"    if val > 0:\n",
"        return np.log10(val)\n",
"    return 0\n",
"\n",
"def Safeexp(val):\n",
"    \"\"\"Map a signed-log10 value back to linear scale.\n",
"\n",
"    Positive input gives ``-(10 ** -val)``, negative input gives ``10 ** val``\n",
"    and anything else gives ``0``.\n",
"\n",
"    This reverses ``Safelog`` only for original values whose magnitude lies in\n",
"    (0, 1): there the logarithm flips the sign, and this function flips it back.\n",
"    Magnitudes of 1 or more are not recovered.\n",
"    \"\"\"\n",
"    if val < 0:\n",
"        return 10 ** val\n",
"    if val > 0:\n",
"        return -(10 ** -val)\n",
"    return 0"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [],
"source": [
"# TODO: Why does Charge use a different logarithm than the other species?\n",
"# NOTE(review): presumably because Charge can be negative, where log1p is not usable - confirm.\n",
"\n",
"# Column names in the order in which the transforms are registered.\n",
"_LOG_COLUMNS = (\"H\", \"O\", \"Charge\", \"H_0_\", \"O_0_\", \"Ba\", \"Cl\", \"S_2_\", \"S_6_\", \"Sr\", \"Barite\", \"Celestite\")\n",
"\n",
"# Forward transforms: Safelog for Charge, log(1 + x) for every other column.\n",
"func_dict_in = {name: (Safelog if name == \"Charge\" else np.log1p) for name in _LOG_COLUMNS}\n",
"\n",
"# Matching back-transforms: Safeexp for Charge, exp(x) - 1 for every other column.\n",
"func_dict_out = {name: (Safeexp if name == \"Charge\" else np.expm1) for name in _LOG_COLUMNS}\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Read data from `.h5` file and convert it to a `pandas.DataFrame`"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [],
"source": [
"# The file is looked up relative to the current working directory.\n",
"# It is opened read-only inside a context manager so the handle is released\n",
"# even if building the data frames raises.\n",
"with h5py.File(\"Barite_50_Data_training.h5\", \"r\") as data_file:\n",
"    design = data_file[\"design\"]\n",
"    results = data_file[\"result\"]\n",
"\n",
"    # The stored arrays are transposed before the entries of \"names\" are used as column labels.\n",
"    df_design = pd.DataFrame(np.array(design[\"data\"]).transpose(), columns = design[\"names\"].asstr())\n",
"    df_results = pd.DataFrame(np.array(results[\"data\"]).transpose(), columns = results[\"names\"].asstr())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preprocess Data\n",
"\n",
"The data are preprocessed in the following way:\n",
"\n",
"1. Label data points in the `design` dataset with `reactive` and `non-reactive` labels using kmeans clustering\n",
"2. Transform `design` and `results` data set into log-scaled data.\n",
"3. Split data into training and test sets.\n",
"4. Fit the scaler on the training data, either jointly for `design` and `results` (option `global`) or separately for each (option `individual`).\n",
"5. Transform training and test data.\n",
"6. Split training data into training and validation dataset."
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/signer/bin/miniconda3/envs/training/lib/python3.11/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (2). Possibly due to duplicate points in X.\n",
" return fit_method(estimator, *args, **kwargs)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Amount class 0 before: 0.9879169719169719\n",
"Amount class 1 before: 0.012083028083028084\n"
]
}
],
"source": [
"# NOTE(review): 'individual' matches the per-dataset scaler option listed in the steps above;\n",
"# the meaning of \"off\" and 0.1 is defined by preprocessing_training - confirm in preprocessing.py.\n",
"(X_train, X_val, X_test,\n",
" y_train, y_val, y_test,\n",
" scaler_X, scaler_y) = preprocessing_training(\n",
"    df_design,\n",
"    df_results,\n",
"    func_dict_in,\n",
"    func_dict_out,\n",
"    \"off\",\n",
"    'individual',\n",
"    0.1,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Custom Loss function"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [],
"source": [
"def custom_loss_H20(df_design_log, df_result_log, data_min_log, data_max_log, func_dict_out, postprocess):\n",
"    \"\"\"Huber loss plus a penalty on deviations of the H/O ratio from 2.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    df_design_log : first argument forwarded by ``loss_wrapper``.\n",
"        NOTE(review): Keras calls a loss as (y_true, y_pred), so this is used as the\n",
"        target of the Huber term even though the name suggests model inputs - confirm.\n",
"    df_result_log : second argument forwarded by ``loss_wrapper``; it is back-transformed\n",
"        with ``postprocess`` before the ratio penalty is computed.\n",
"    data_min_log, data_max_log, func_dict_out : passed through to ``postprocess``.\n",
"    postprocess : callable invoked as\n",
"        ``postprocess(df_result_log, func_dict_out, data_min_log, data_max_log)``;\n",
"        its result must be indexable by 'H' and 'O'.\n",
"\n",
"    Returns\n",
"    -------\n",
"    Sum of the Huber loss and the summed squared deviation of H/O from 2.\n",
"    \"\"\"\n",
"    df_result = postprocess(df_result_log, func_dict_out, data_min_log, data_max_log)\n",
"    # keras.losses.Huber is a class: adding the class itself to a number raises TypeError,\n",
"    # so it is instantiated and evaluated on the two arguments first.\n",
"    huber_term = keras.losses.Huber()(df_design_log, df_result_log)\n",
"    # NOTE(review): np.sum assumes concrete array-like values; inside a compiled Keras\n",
"    # training step this may need a backend op instead - confirm before using this loss.\n",
"    ratio_penalty = np.sum(((df_result['H'] / df_result['O']) - 2)**2)\n",
"    return huber_term + ratio_penalty\n",
"\n",
"def loss_wrapper(data_min_log, data_max_log, func_dict_out, postprocess):\n",
"    \"\"\"Bind the extra arguments of ``custom_loss_H20`` and return a two-argument loss.\n",
"\n",
"    The returned function takes ``(df_design_log, df_result_log)`` and forwards them,\n",
"    together with the values captured here, to ``custom_loss_H20``.\n",
"    \"\"\"\n",
"    def loss(df_design_log, df_result_log):\n",
"        return custom_loss_H20(\n",
"            df_design_log=df_design_log,\n",
"            df_result_log=df_result_log,\n",
"            data_min_log=data_min_log,\n",
"            data_max_log=data_max_log,\n",
"            func_dict_out=func_dict_out,\n",
"            postprocess=postprocess,\n",
"        )\n",
"\n",
"    return loss"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the model"
]
},
{
"cell_type": "code",
"execution_count": 117,
"metadata": {},
"outputs": [],
"source": [
"# measure time\n",
"def model_training(model, max_epochs=40):\n",
"    \"\"\"Fit ``model`` on the training split and print the elapsed wall-clock time.\n",
"\n",
"    Reads the notebook globals ``X_train``/``y_train`` (fit data), ``X_val``/``y_val``\n",
"    (validation data) and ``batch_size``. A column named \"Class\" is excluded from\n",
"    every frame before it is passed to Keras. Training stops early once the training\n",
"    loss has not improved for 3 consecutive epochs.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    model : compiled Keras model.\n",
"    max_epochs : int, upper bound on the number of epochs. Defaults to 40, the value\n",
"        that was previously hard-coded; pass ``max_epochs=epochs`` to apply the\n",
"        setting from the parameter cell.\n",
"\n",
"    Returns\n",
"    -------\n",
"    The history object returned by ``model.fit`` (previously discarded).\n",
"    \"\"\"\n",
"    def without_class(frame):\n",
"        # keep every column except \"Class\"\n",
"        return frame.iloc[:, frame.columns != \"Class\"]\n",
"\n",
"    start = time.time()\n",
"    # NOTE(review): early stopping watches the training loss, not val_loss - confirm this is intended.\n",
"    callback = keras.callbacks.EarlyStopping(monitor='loss', patience=3)\n",
"    history = model.fit(without_class(X_train),\n",
"                        without_class(y_train),\n",
"                        batch_size = batch_size,\n",
"                        epochs = max_epochs,\n",
"                        validation_data = (without_class(X_val), without_class(y_val)),\n",
"                        callbacks = [callback])\n",
"\n",
"    end = time.time()\n",
"\n",
"    print(\"Training took {} seconds\".format(end - start))\n",
"    return history"
]
},
{
"cell_type": "code",
"execution_count": 118,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 1.8876e-06 - val_loss: 1.8770e-06\n",
"Epoch 2/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.9907e-06 - val_loss: 1.8834e-06\n",
"Epoch 3/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.6876e-06 - val_loss: 1.8772e-06\n",
"Epoch 4/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.5167e-06 - val_loss: 1.8927e-06\n",
"Epoch 5/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 2.0424e-06 - val_loss: 1.8783e-06\n",
"Epoch 6/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 1.5759e-06 - val_loss: 1.8775e-06\n",
"Epoch 7/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.7919e-06 - val_loss: 1.8742e-06\n",
"Epoch 8/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.9016e-06 - val_loss: 1.8740e-06\n",
"Epoch 9/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.7540e-06 - val_loss: 1.8810e-06\n",
"Epoch 10/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.9548e-06 - val_loss: 1.8773e-06\n",
"Epoch 11/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 2.0289e-06 - val_loss: 1.8712e-06\n",
"Epoch 12/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.6964e-06 - val_loss: 1.8717e-06\n",
"Epoch 13/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 2.1222e-06 - val_loss: 1.8696e-06\n",
"Epoch 14/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.9142e-06 - val_loss: 1.8694e-06\n",
"Epoch 15/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.6718e-06 - val_loss: 1.8697e-06\n",
"Epoch 16/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.5208e-06 - val_loss: 1.8694e-06\n",
"Epoch 17/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.8481e-06 - val_loss: 1.8691e-06\n",
"Epoch 18/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.6943e-06 - val_loss: 1.8693e-06\n",
"Epoch 19/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.8389e-06 - val_loss: 1.8690e-06\n",
"Epoch 20/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.9054e-06 - val_loss: 1.8691e-06\n",
"Epoch 21/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.6646e-06 - val_loss: 1.8702e-06\n",
"Epoch 22/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.8724e-06 - val_loss: 1.8692e-06\n",
"Epoch 23/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.7165e-06 - val_loss: 1.8690e-06\n",
"Epoch 24/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.8717e-06 - val_loss: 1.8690e-06\n",
"Epoch 25/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.8996e-06 - val_loss: 1.8690e-06\n",
"Epoch 26/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.8146e-06 - val_loss: 1.8691e-06\n",
"Epoch 27/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.9670e-06 - val_loss: 1.8689e-06\n",
"Epoch 28/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.7449e-06 - val_loss: 1.8690e-06\n",
"Epoch 29/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 2.0738e-06 - val_loss: 1.8690e-06\n",
"Epoch 30/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 2ms/step - loss: 1.5405e-06 - val_loss: 1.8690e-06\n",
"Epoch 31/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.6470e-06 - val_loss: 1.8692e-06\n",
"Epoch 32/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.7679e-06 - val_loss: 1.8690e-06\n",
"Epoch 33/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.7011e-06 - val_loss: 1.8690e-06\n",
"Epoch 34/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 2.0628e-06 - val_loss: 1.8690e-06\n",
"Epoch 35/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.7559e-06 - val_loss: 1.8690e-06\n",
"Epoch 36/40\n",
"\u001b[1m3960/3960\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 1.8712e-06 - val_loss: 1.8690e-06\n",
"Training took 245.66379165649414 seconds\n"
]
}
],
"source": [
"model_training(model_simple)"
]
},
{
"cell_type": "code",
"execution_count": 102,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 14ms/step\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7f43e0c41750>"
]
},
"execution_count": 102,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# The barite benchmark uses a 50x50 grid, so every iteration contributes 2500 consecutive rows.\n",
"GRID_SIZE = 50\n",
"CELLS_PER_ITERATION = GRID_SIZE * GRID_SIZE\n",
"\n",
"species = \"Ba\"\n",
"iterations = 10\n",
"cell_offset = 50\n",
"\n",
"# Row position of the tracked cell within each of the first `iterations` iterations.\n",
"cell_rows = [i * CELLS_PER_ITERATION + cell_offset - 1 for i in range(iterations)]\n",
"y_design = df_design.iloc[cell_rows]\n",
"y_results = df_results.iloc[cell_rows]\n",
"steps = np.arange(iterations)\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.plot(steps, y_design[species], label=\"Design\")\n",
"ax.plot(steps, y_results[species], label=\"Results\")\n",
"\n",
"# Forward transform and scaling of the design rows, prediction, then the two inverse steps.\n",
"y = FuncTransform(func_dict_in, func_dict_out).fit_transform(y_design)\n",
"y = scaler_X.transform(y)\n",
"feature_columns = y.columns[y.columns != \"Class\"]\n",
"# NOTE(review): this cell predicts with `model_large`, while the evaluation cells use `model_simple`;\n",
"# confirm which model is intended here.\n",
"prediction = pd.DataFrame(model_large.predict(y[feature_columns]), columns=feature_columns)\n",
"prediction_back = pd.DataFrame(scaler_X.inverse_transform(prediction), columns=prediction.columns)\n",
"# NOTE(review): FuncTransform is built with (func_dict_out, func_dict_in) here but with\n",
"# (func_dict_in, func_dict_out) for the forward transform above; confirm the argument order.\n",
"prediction_back = FuncTransform(func_dict_out, func_dict_in).inverse_transform(prediction_back)\n",
"\n",
"ax.plot(steps, prediction_back[species], label=\"Prediction\")\n",
"ax.set_xlabel(\"Iteration\")\n",
"ax.set_ylabel(species)\n",
"ax.set_title(species + \" Concentration over Iterations\")\n",
"# The three lines carry labels; without a legend they are indistinguishable.\n",
"ax.legend()\n",
"plt.show()\n",
"\n",
"# Spatial distribution of the species in the results at a single timestep.\n",
"timestep = 1000\n",
"first_row = timestep * CELLS_PER_ITERATION\n",
"snapshot = df_results[species].iloc[first_row:first_row + CELLS_PER_ITERATION]\n",
"fig_map, ax_map = plt.subplots()\n",
"ax_map.imshow(snapshot.to_numpy().reshape(GRID_SIZE, GRID_SIZE), origin=\"lower\")\n",
"ax_map.set_title(f\"{species} (results) at timestep {timestep}\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Training loss for every epoch, written to loss_all.png.\n",
"fig, ax = plt.subplots()\n",
"train_loss = history.history[\"loss\"]\n",
"ax.plot(train_loss, \"o-\", label=\"Training Loss\")\n",
"ax.set_xlabel(\"Epoch\")\n",
"ax.set_ylabel(\"Loss (Huber)\")\n",
"# A logarithmic y-axis can be enabled with ax.set_yscale(\"log\").\n",
"ax.grid(\"on\")\n",
"\n",
"fig.savefig(\"loss_all.png\", dpi=300)\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Training loss with the first epoch left out, written to loss_1_to_end.png.\n",
"fig, ax = plt.subplots()\n",
"train_loss_tail = history.history[\"loss\"][1:]\n",
"ax.plot(train_loss_tail, \"o-\", label=\"Training Loss\")\n",
"ax.set_xlabel(\"Epoch\")\n",
"ax.set_ylabel(\"Loss (Huber)\")\n",
"# A logarithmic y-axis can be enabled with ax.set_yscale(\"log\").\n",
"ax.grid(\"on\")\n",
"fig.savefig(\"loss_1_to_end.png\", dpi=300)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test the model"
]
},
{
"cell_type": "code",
"execution_count": 103,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m7821/7821\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 354us/step - loss: 2.4332e-06\n"
]
},
{
"data": {
"text/plain": [
"2.221618160547223e-06"
]
},
"execution_count": 103,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Loss of model_simple on the complete test set; the \"Class\" column is excluded on both sides.\n",
"X_test_features = X_test.loc[:, X_test.columns != \"Class\"]\n",
"y_test_targets = y_test.loc[:, y_test.columns != \"Class\"]\n",
"model_simple.evaluate(X_test_features, y_test_targets)"
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m7727/7727\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 309us/step - loss: 1.3519e-06\n"
]
},
{
"data": {
"text/plain": [
"1.1662884844554355e-06"
]
},
"execution_count": 104,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Loss on the non-reactive test samples (Class == 0).\n",
"# Columns are selected by name rather than with iloc[:, :-1], so the result does not depend on\n",
"# \"Class\" being the last column and matches the selection used for the full test set.\n",
"non_reactive = X_test[\"Class\"] == 0\n",
"model_simple.evaluate(\n",
"    X_test.loc[non_reactive, X_test.columns != \"Class\"],\n",
"    y_test.loc[non_reactive, y_test.columns != \"Class\"],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m94/94\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 347us/step - loss: 8.9877e-05\n"
]
},
{
"data": {
"text/plain": [
"8.905206777853891e-05"
]
},
"execution_count": 105,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Loss on the reactive test samples (Class == 1).\n",
"# Columns are selected by name rather than with iloc[:, :-1], so the result does not depend on\n",
"# \"Class\" being the last column and matches the selection used for the full test set.\n",
"reactive = X_test[\"Class\"] == 1\n",
"model_simple.evaluate(\n",
"    X_test.loc[reactive, X_test.columns != \"Class\"],\n",
"    y_test.loc[reactive, y_test.columns != \"Class\"],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test Mass Balance"
]
},
{
"cell_type": "code",
"execution_count": 106,
"metadata": {},
"outputs": [],
"source": [
"def mass_balance(model, X, scaler_X, func_dict_in, func_dict_out):\n",
"    \"\"\"Return the per-sample mass-balance error of a model prediction.\n",
"\n",
"    The model is evaluated on the scaled features in ``X``. Input and prediction are\n",
"    then both back-transformed (``scaler_X.inverse_transform`` followed by\n",
"    ``FuncTransform(...).inverse_transform``) before the totals are compared.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    model : object providing ``predict`` (e.g. a Keras model)\n",
"    X : pandas.DataFrame\n",
"        Scaled, transformed features; a \"Class\" column is ignored if present.\n",
"    scaler_X : fitted scaler providing ``inverse_transform``\n",
"    func_dict_in, func_dict_out : dict\n",
"        Passed unchanged to ``FuncTransform``.\n",
"\n",
"    Returns\n",
"    -------\n",
"    pandas.Series\n",
"        ``abs(d(Ba + Barite)) + abs(d(Sr + Celestite))`` for every row of ``X``,\n",
"        indexed 0..len(X)-1.\n",
"    \"\"\"\n",
"    columns = X.columns[X.columns != \"Class\"]\n",
"    features = X[columns]\n",
"\n",
"    def _back_transform(values):\n",
"        # Undo the min/max scaling first, then the functional (log) transform.\n",
"        # np.asarray drops any index so input and prediction always align row by row.\n",
"        unscaled = pd.DataFrame(np.asarray(scaler_X.inverse_transform(values)), columns=columns)\n",
"        return FuncTransform(func_dict_in, func_dict_out).inverse_transform(unscaled)\n",
"\n",
"    # Predict on the scaled features (the same representation passed to evaluate/predict\n",
"    # elsewhere in the notebook) and keep the result local instead of reading a global.\n",
"    prediction = pd.DataFrame(model.predict(features), columns=columns)\n",
"\n",
"    X_back = _back_transform(features)\n",
"    prediction_back = _back_transform(prediction)\n",
"\n",
"    # Total Ba (dissolved + Barite) and total Sr (dissolved + Celestite) before and after.\n",
"    dBa = np.abs((prediction_back[\"Ba\"] + prediction_back[\"Barite\"]) - (X_back[\"Ba\"] + X_back[\"Barite\"]))\n",
"    dSr = np.abs((prediction_back[\"Sr\"] + prediction_back[\"Celestite\"]) - (X_back[\"Sr\"] + X_back[\"Celestite\"]))\n",
"\n",
"    return dBa + dSr"
]
},
{
"cell_type": "code",
"execution_count": 107,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m7821/7821\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 345us/step\n"
]
}
],
"source": [
"# Predict the whole test set with model_simple; \"Class\" is left out of the inputs and the\n",
"# prediction is labelled with the remaining column names.\n",
"columns = X_test.columns[X_test.columns != \"Class\"]\n",
"raw_prediction = model_simple.predict(X_test.loc[:, columns])\n",
"prediction = pd.DataFrame(raw_prediction, columns=columns)"
]
},
{
"cell_type": "code",
"execution_count": 108,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m7821/7821\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 357us/step\n"
]
}
],
"source": [
"mass_balance = mass_balance(model_simple, X_test, scaler_X, func_dict_in, func_dict_out)"
]
},
{
"cell_type": "code",
"execution_count": 116,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1163 0.000526\n",
"1175 0.000948\n",
"2425 0.000854\n",
"3231 0.000993\n",
"3964 0.000145\n",
" ... \n",
"243165 0.000928\n",
"248138 0.000698\n",
"248832 0.000915\n",
"248963 0.000696\n",
"249253 0.000932\n",
"Length: 115, dtype: float64"
]
},
"execution_count": 116,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mass_balance[mass_balance < 0.001]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save the model"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"# Write the model to a .keras file.\n",
"# NOTE(review): this saves `model`, whereas the cells above train and evaluate `model_simple`;\n",
"# confirm that `model` refers to the intended network.\n",
"model_path = \"Barite_50_Model_additional_species.keras\"\n",
"model.save(model_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Legacy Code"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"def log_scale(df_design, df_result, func_dict):\n",
" \n",
" df_design = df_design.copy()\n",
" df_result = df_result.copy()\n",
" \n",
" for key in df_design.keys():\n",
" if key != \"Class\":\n",
" df_design[key] = np.vectorize(func_dict[key])(df_design[key])\n",
" df_result[key] = np.vectorize(func_dict[key])(df_result[key])\n",
" \n",
" return df_design, df_result\n",
"\n",
"# Get minimum and maximum values for each column\n",
"def get_min_max(df_design, df_result):\n",
" \n",
" min_vals_des = df_design.min()\n",
" max_vals_des = df_design.max()\n",
" \n",
" min_vals_res = df_result.min()\n",
" max_vals_res = df_result.max()\n",
"\n",
" # minimum of input and output data to get global minimum/maximum\n",
" data_min = np.minimum(min_vals_des, min_vals_res).to_dict()\n",
" data_max = np.maximum(max_vals_des, max_vals_res).to_dict()\n",
"\n",
" return data_min, data_max\n",
"\n",
"df_design_log, df_results_log = log_scale(df_design, df_results, func_dict_in)\n",
"data_min_log, data_max_log = get_min_max(df_design_log, df_results_log)\n",
"\n",
"train_min_log, train_max_log = get_min_max(X_train_log, y_train_log)\n",
"test_min_log, test_max_log = get_min_max(X_test_log, y_test_log)\n",
"\n",
"X_train_preprocess = preprocess(X_train_log, func_dict_in, train_min_log, train_max_log)\n",
"y_train_preprocess = preprocess(y_train_log, func_dict_in, train_min_log, train_max_log)\n",
"\n",
"X_test_preprocess = preprocess(X_test_log, func_dict_in, test_min_log, test_max_log)\n",
"y_test_preprocess = preprocess(y_test_log, func_dict_in, test_min_log, test_max_log)\n",
"\n",
"X_train_log, y_train_log = log_scale(X_train, y_train, func_dict_in)\n",
"X_test_log, y_test_log = log_scale(X_test, y_test, func_dict_in)\n",
"\n",
"\n",
"def preprocess(data, func_dict, data_min, data_max):\n",
" data = data.copy()\n",
" for key in data.keys():\n",
" if key != \"Class\":\n",
" data[key] = (data[key] - data_min[key]) / (data_max[key] - data_min[key])\n",
"\n",
" return data\n",
"\n",
"def postprocess(data, func_dict, data_min, data_max):\n",
" data = data.copy()\n",
" for key in data.keys():\n",
" if key != \"Class\":\n",
" data[key] = data[key] * (data_max[key] - data_min[key]) + data_min[key]\n",
" data[key] = np.vectorize(func_dict[key])(data[key])\n",
" return data\n",
"\n",
"X_train, X_val, y_train, y_val = sk.train_test_split(X_train_preprocess, y_train_preprocess, test_size = 0.1)\n",
"\n",
"pp_design = preprocess(df_design_log, func_dict_in, data_min_log, data_max_log)\n",
"pp_results = preprocess(df_results_log, func_dict_in, data_min_log, data_max_log)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "training",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}