model-training/POET_Training.ipynb

1532 lines
163 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## General Information"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook is used to train a simple neural network model to predict the chemistry in the barite benchmark (50x50 grid). The training data is stored in the repository using **git large file storage** and can be downloaded after the installation of git lfs using the `git lfs pull` command.\n",
"\n",
"It is then recommended to create a Python environment using miniconda. The necessary dependencies are contained in `environment.yml` and can be installed using `conda env create -f environment.yml`.\n",
"\n",
"The data set is divided into a design and result part and consists of the iterations of a reference simulation. The design part of the data set contains the chemical concentrations at time $t$ and the result part at time $t+1$, which are to be learned by the model."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup Libraries"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2025-01-23 14:37:53.766781: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
"2025-01-23 14:37:53.786741: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running Keras in version 3.8.0\n"
]
}
],
"source": [
"import keras\n",
"import h5py\n",
"import numpy as np\n",
"import pandas as pd\n",
"import time\n",
"import sklearn.model_selection as sk\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.cluster import KMeans\n",
"from sklearn.pipeline import Pipeline, make_pipeline\n",
"from sklearn.preprocessing import StandardScaler, MinMaxScaler\n",
"from imblearn.over_sampling import SMOTE\n",
"from imblearn.under_sampling import RandomUnderSampler\n",
"from imblearn.over_sampling import RandomOverSampler\n",
"from collections import Counter\n",
"import os\n",
"from preprocessing import *\n",
"from sklearn import set_config\n",
"from importlib import reload\n",
"set_config(transform_output = \"pandas\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define parameters"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"# Numeric precision and hidden-layer activation intended for the model definitions.\n",
"dtype = \"float32\"\n",
"activation = \"relu\"\n",
"\n",
"# Seed Python's `random`, NumPy and the Keras backend so that weight\n",
"# initialisation and shuffling are repeatable after Restart & Run All.\n",
"random_seed = 42\n",
"keras.utils.set_random_seed(random_seed)\n",
"\n",
"# Optimisation hyper-parameters.\n",
"lr = 0.001\n",
"batch_size = 512\n",
"epochs = 50 # default 400 epochs\n",
"\n",
"# Step-wise decay: the learning rate is multiplied by 0.9 after every\n",
"# 2000 optimizer steps (staircase=True makes the decay discrete).\n",
"lr_schedule = keras.optimizers.schedules.ExponentialDecay(\n",
"    initial_learning_rate=lr,\n",
"    decay_steps=2000,\n",
"    decay_rate=0.9,\n",
"    staircase=True\n",
")\n",
"\n",
"# One Adam instance per model, so that no optimizer state is shared between them.\n",
"optimizer_simple = keras.optimizers.Adam(learning_rate=lr_schedule)\n",
"optimizer_large = keras.optimizers.Adam(learning_rate=lr_schedule)\n",
"optimizer_paper = keras.optimizers.Adam(learning_rate=lr_schedule)\n",
"\n",
"\n",
"loss = keras.losses.MeanSquaredError()\n",
"\n",
"# NOTE(review): not referenced in the visible cells - presumably used by\n",
"# later sampling code; confirm or remove.\n",
"sample_fraction = 0.8"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup the model"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential_2\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential_2\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_7 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,664</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_8 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">16,512</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_9 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">12</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,548</span> │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_7 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m1,664\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_8 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m16,512\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_9 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m1,548\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">19,724</span> (77.05 KB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m19,724\u001b[0m (77.05 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">19,724</span> (77.05 KB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m19,724\u001b[0m (77.05 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# small model: 12 inputs -> 128 -> 128 -> 12 outputs\n",
"# Precision and activation come from the parameter cell (`dtype`, `activation`)\n",
"# instead of being hard-coded, so changing them there actually takes effect.\n",
"model_simple = keras.Sequential(\n",
"    [\n",
"        keras.Input(shape=(12,), dtype=dtype),\n",
"        keras.layers.Dense(units=128, activation=activation, dtype=dtype),\n",
"        keras.layers.Dense(units=128, activation=activation, dtype=dtype),\n",
"        # no activation: linear output layer\n",
"        keras.layers.Dense(units=12, dtype=dtype),\n",
"    ]\n",
")\n",
"\n",
"model_simple.compile(optimizer=optimizer_simple, loss=loss)\n",
"model_simple.summary()"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential_5\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential_5\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_20 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">512</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">6,656</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_21 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1024</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">525,312</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_22 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">512</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">524,800</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_23 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">12</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">6,156</span> │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_20 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m6,656\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_21 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1024\u001b[0m) │ \u001b[38;5;34m525,312\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_22 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m524,800\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_23 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m6,156\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,062,924</span> (4.05 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,062,924\u001b[0m (4.05 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,062,924</span> (4.05 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,062,924\u001b[0m (4.05 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# large model: 12 inputs -> 512 -> 1024 -> 512 -> 12 outputs\n",
"# The hidden-layer activation comes from the parameter cell (`activation`)\n",
"# instead of a hard-coded 'relu', so changing it there actually takes effect.\n",
"model_large = keras.Sequential(\n",
"    [\n",
"        keras.layers.Input(shape=(12,), dtype=dtype),\n",
"        keras.layers.Dense(512, activation=activation, dtype=dtype),\n",
"        keras.layers.Dense(1024, activation=activation, dtype=dtype),\n",
"        keras.layers.Dense(512, activation=activation, dtype=dtype),\n",
"        # no activation: linear output layer\n",
"        keras.layers.Dense(12, dtype=dtype),\n",
"    ]\n",
")\n",
"\n",
"model_large.compile(optimizer=optimizer_large, loss=loss)\n",
"model_large.summary()\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential_6\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential_6\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_24 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,664</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_25 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">33,024</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_26 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">512</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">131,584</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_27 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">131,328</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_28 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">12</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">3,084</span> │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_24 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m1,664\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_25 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m33,024\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_26 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m131,584\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_27 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m131,328\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_28 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m3,084\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">300,684</span> (1.15 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m300,684\u001b[0m (1.15 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">300,684</span> (1.15 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m300,684\u001b[0m (1.15 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# model from paper: 12 inputs -> 128 -> 256 -> 512 -> 256 -> 12 outputs\n",
"# (see https://doi.org/10.1007/s11242-022-01779-3 model for the complex chemistry)\n",
"# The hidden-layer activation comes from the parameter cell (`activation`)\n",
"# instead of a hard-coded 'relu', so changing it there actually takes effect.\n",
"model_paper = keras.Sequential(\n",
"    [\n",
"        keras.layers.Input(shape=(12,), dtype=dtype),\n",
"        keras.layers.Dense(128, activation=activation, dtype=dtype),\n",
"        keras.layers.Dense(256, activation=activation, dtype=dtype),\n",
"        keras.layers.Dense(512, activation=activation, dtype=dtype),\n",
"        keras.layers.Dense(256, activation=activation, dtype=dtype),\n",
"        # no activation: linear output layer\n",
"        keras.layers.Dense(12, dtype=dtype),\n",
"    ]\n",
")\n",
"\n",
"model_paper.compile(optimizer=optimizer_paper, loss=loss)\n",
"model_paper.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define transformer functions"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def Safelog(val):\n",
"    \"\"\"Signed base-10 logarithm that also accepts zero and negative values.\n",
"\n",
"    Maps ``val > 0`` to ``log10(val)``, ``val < 0`` to ``-log10(-val)`` and\n",
"    anything else (zero, NaN) to ``0``.\n",
"\n",
"    Args:\n",
"        val: a scalar, or an array-like (NumPy array, pandas column, ...).\n",
"\n",
"    Returns:\n",
"        For a scalar, the same value the scalar-only version returned.\n",
"        For an array-like, a float ``ndarray`` of the same shape, so the\n",
"        function can be applied to a whole column like ``np.log1p``.\n",
"\n",
"    Note:\n",
"        ``Safeexp`` only undoes this transform for ``0 < |val| < 1`` (and 0).\n",
"        Larger magnitudes collide with smaller values of the opposite sign,\n",
"        e.g. ``Safelog(10) == Safelog(-0.1) == 1``.\n",
"    \"\"\"\n",
"    if np.ndim(val) == 0:\n",
"        # Scalar path, unchanged: stays cheap when called element-wise.\n",
"        if val > 0:\n",
"            return np.log10(val)\n",
"        elif val < 0:\n",
"            return -np.log10(-val)\n",
"        else:\n",
"            return 0\n",
"\n",
"    # Array path: masks keep log10 away from zero/NaN entries, which stay 0.\n",
"    arr = np.asarray(val, dtype=float)\n",
"    out = np.zeros_like(arr)\n",
"    pos = arr > 0\n",
"    neg = arr < 0\n",
"    out[pos] = np.log10(arr[pos])\n",
"    out[neg] = -np.log10(-arr[neg])\n",
"    return out\n",
"\n",
"def Safeexp(val):\n",
"    \"\"\"Inverse of ``Safelog`` for original values with ``0 < |x| < 1`` (and 0).\n",
"\n",
"    Maps ``val > 0`` to ``-10 ** -val``, ``val < 0`` to ``10 ** val`` and\n",
"    anything else (zero, NaN) to ``0``. The sign flip is intentional:\n",
"    ``Safelog`` sends a small positive x to a negative logarithm and a small\n",
"    negative x to a positive one.\n",
"\n",
"    Args:\n",
"        val: a scalar, or an array-like (NumPy array, pandas column, ...).\n",
"\n",
"    Returns:\n",
"        For a scalar, the same value the scalar-only version returned.\n",
"        For an array-like, a float ``ndarray`` of the same shape.\n",
"    \"\"\"\n",
"    if np.ndim(val) == 0:\n",
"        # Scalar path, unchanged: stays cheap when called element-wise.\n",
"        if val > 0:\n",
"            return -10 ** -val\n",
"        elif val < 0:\n",
"            return 10 ** val\n",
"        else:\n",
"            return 0\n",
"\n",
"    # Array path: entries that are neither > 0 nor < 0 stay 0.\n",
"    arr = np.asarray(val, dtype=float)\n",
"    out = np.zeros_like(arr)\n",
"    pos = arr > 0\n",
"    neg = arr < 0\n",
"    out[pos] = -np.power(10.0, -arr[pos])\n",
"    out[neg] = np.power(10.0, arr[neg])\n",
"    return out"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Per-column transform functions handed to `preprocessing_training`.\n",
"# Each row pairs a column name with its forward function (-> func_dict_in)\n",
"# and the matching inverse (-> func_dict_out), so both dictionaries always\n",
"# share the same keys in the same order. np.expm1 is the inverse of np.log1p.\n",
"#\n",
"# NOTE(review): open question from the original author - why does `Charge`\n",
"# use a different logarithm than the other species? Presumably because it\n",
"# can take either sign and spans many orders of magnitude - TODO confirm.\n",
"_transform_pairs = [\n",
"    (\"H\", np.log1p, np.expm1),\n",
"    (\"O\", np.log1p, np.expm1),\n",
"    (\"Charge\", Safelog, Safeexp),\n",
"    (\"H_0_\", np.log1p, np.expm1),\n",
"    (\"O_0_\", np.log1p, np.expm1),\n",
"    (\"Ba\", np.log1p, np.expm1),\n",
"    (\"Cl\", np.log1p, np.expm1),\n",
"    (\"S_2_\", np.log1p, np.expm1),\n",
"    (\"S_6_\", np.log1p, np.expm1),\n",
"    (\"Sr\", np.log1p, np.expm1),\n",
"    (\"Barite\", np.log1p, np.expm1),\n",
"    (\"Celestite\", np.log1p, np.expm1),\n",
"]\n",
"\n",
"func_dict_in = {name: forward for name, forward, _ in _transform_pairs}\n",
"func_dict_out = {name: inverse for name, _, inverse in _transform_pairs}\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Read data from `.h5` file and convert it to a `pandas.DataFrame`"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"# The path is relative to the notebook's working directory; the file is\n",
"# tracked with git lfs (fetch it with `git lfs pull`).\n",
"# The `with` block closes the file even if one of the reads below raises.\n",
"with h5py.File(\"Barite_50_Data_training.h5\", \"r\") as data_file:\n",
"    design = data_file[\"design\"]\n",
"    results = data_file[\"result\"]\n",
"\n",
"    # Transposed so that each entry of `names` labels one DataFrame column.\n",
"    # np.array(...) copies the data into memory, so the frames stay usable\n",
"    # after the file has been closed.\n",
"    df_design = pd.DataFrame(np.array(design[\"data\"]).transpose(), columns = np.array(design[\"names\"].asstr()))\n",
"    df_results = pd.DataFrame(np.array(results[\"data\"]).transpose(), columns = np.array(results[\"names\"].asstr()))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preprocess Data\n",
"\n",
"The data are preprocessed in the following way:\n",
"\n",
"1. Label data points in the `design` dataset with `reactive` and `non-reactive` labels using k-means clustering.\n",
"2. Transform `design` and `results` data set into log-scaled data.\n",
"3. Split data into training and test sets.\n",
"4. Fit the scaler on the training data, either jointly for `design` and `results` (option `global`) or separately for each (option `individual`).\n",
"5. Transform training and test data.\n",
"6. Split training data into training and validation dataset."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/hannessigner/miniforge3/envs/ai/lib/python3.12/site-packages/sklearn/base.py:1474: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (2). Possibly due to duplicate points in X.\n",
" return fit_method(estimator, *args, **kwargs)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Amount class 0 before: 0.9879169719169719\n",
"Amount class 1 before: 0.012083028083028084\n"
]
}
],
"source": [
"# Run the preprocessing steps listed above; the result is unpacked into the\n",
"# train / validation / test parts of X and y plus the two scalers.\n",
"(\n",
"    X_train, X_val, X_test,\n",
"    y_train, y_val, y_test,\n",
"    scaler_X, scaler_y,\n",
") = preprocessing_training(\n",
"    df_design,\n",
"    df_results,\n",
"    func_dict_in,\n",
"    func_dict_out,\n",
"    \"over\",        # NOTE(review): presumably the resampling strategy (oversampling) - confirm in preprocessing.py\n",
"    'individual',  # scaler option, see step 4 of the list above\n",
"    0.1,           # NOTE(review): presumably the held-out fraction - confirm in preprocessing.py\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Custom Loss function"
]
},
{
"cell_type": "code",
"execution_count": 164,
"metadata": {},
"outputs": [],
"source": [
"def custom_loss_H20(df_design_log, df_result_log, data_min_log, data_max_log, func_dict_out, postprocess):\n",
"    \"\"\"Huber loss plus a penalty on the deviation of the H/O ratio from 2.\n",
"\n",
"    Args:\n",
"        df_design_log: reference values, passed to the Huber loss as its first\n",
"            argument. NOTE(review): the name suggests the design (input) data,\n",
"            but a Keras loss receives ``y_true`` in this position - confirm.\n",
"        df_result_log: predicted values; passed to the Huber loss as its second\n",
"            argument and to ``postprocess``.\n",
"        data_min_log, data_max_log, func_dict_out: forwarded unchanged to\n",
"            ``postprocess``.\n",
"        postprocess: callable invoked as ``postprocess(df_result_log,\n",
"            func_dict_out, data_min_log, data_max_log)``; its result must\n",
"            support lookup of ``'H'`` and ``'O'``.\n",
"\n",
"    Returns:\n",
"        The Huber loss plus ``sum((H / O - 2) ** 2)`` over the post-processed result.\n",
"    \"\"\"\n",
"    df_result = postprocess(df_result_log, func_dict_out, data_min_log, data_max_log)\n",
"    # keras.losses.Huber is a class: it has to be instantiated and called on\n",
"    # the two value sets. Adding the class object itself raised a TypeError.\n",
"    huber = keras.losses.Huber()(df_design_log, df_result_log)\n",
"    # NOTE(review): np.sum may not work on symbolic Keras tensors during\n",
"    # training; if `postprocess` returns tensors, keras.ops.sum is probably\n",
"    # needed here - confirm.\n",
"    ratio_penalty = np.sum(((df_result['H'] / df_result['O']) - 2)**2)\n",
"    return huber + ratio_penalty\n",
"\n",
"def loss_wrapper(data_min_log, data_max_log, func_dict_out, postprocess):\n",
"    \"\"\"Bind the post-processing arguments and return a two-argument loss.\n",
"\n",
"    The returned function takes ``(df_design_log, df_result_log)`` and\n",
"    forwards them, together with the bound arguments, to ``custom_loss_H20``.\n",
"    \"\"\"\n",
"    def loss(df_design_log, df_result_log):\n",
"        return custom_loss_H20(\n",
"            df_design_log=df_design_log,\n",
"            df_result_log=df_result_log,\n",
"            data_min_log=data_min_log,\n",
"            data_max_log=data_max_log,\n",
"            func_dict_out=func_dict_out,\n",
"            postprocess=postprocess,\n",
"        )\n",
"\n",
"    return loss"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the model"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 2ms/step - loss: 0.0018 - val_loss: 3.6601e-05\n",
"Epoch 2/20\n",
"\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 2ms/step - loss: 3.6899e-05 - val_loss: 3.6822e-05\n",
"Epoch 3/20\n",
"\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 2ms/step - loss: 3.5005e-05 - val_loss: 3.5655e-05\n",
"Epoch 4/20\n",
"\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 2ms/step - loss: 3.4032e-05 - val_loss: 3.3455e-05\n",
"Epoch 5/20\n",
"\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 2ms/step - loss: 3.3279e-05 - val_loss: 3.3064e-05\n",
"Epoch 6/20\n",
"\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 2ms/step - loss: 3.3023e-05 - val_loss: 3.3338e-05\n",
"Epoch 7/20\n",
"\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 2ms/step - loss: 3.2532e-05 - val_loss: 3.2765e-05\n",
"Epoch 8/20\n",
"\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 2ms/step - loss: 3.2749e-05 - val_loss: 3.2730e-05\n",
"Epoch 9/20\n",
"\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 2ms/step - loss: 3.2961e-05 - val_loss: 3.2593e-05\n",
"Epoch 10/20\n",
"\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 2ms/step - loss: 3.2573e-05 - val_loss: 3.2576e-05\n",
"Epoch 11/20\n",
"\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 2ms/step - loss: 3.2442e-05 - val_loss: 3.2507e-05\n",
"Epoch 12/20\n",
"\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 2ms/step - loss: 3.2135e-05 - val_loss: 3.2548e-05\n",
"Epoch 13/20\n",
"\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 2ms/step - loss: 3.2451e-05 - val_loss: 3.2482e-05\n",
"Epoch 14/20\n",
"\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 2ms/step - loss: 3.2296e-05 - val_loss: 3.2475e-05\n",
"Epoch 15/20\n",
"\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 2ms/step - loss: 3.2081e-05 - val_loss: 3.2470e-05\n",
"Epoch 16/20\n",
"\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 2ms/step - loss: 3.2440e-05 - val_loss: 3.2471e-05\n",
"Epoch 17/20\n",
"\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 2ms/step - loss: 3.2050e-05 - val_loss: 3.2460e-05\n",
"Epoch 18/20\n",
"\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 2ms/step - loss: 3.2444e-05 - val_loss: 3.2452e-05\n",
"Epoch 19/20\n",
"\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 2ms/step - loss: 3.2259e-05 - val_loss: 3.2452e-05\n",
"Epoch 20/20\n",
"\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 2ms/step - loss: 3.2442e-05 - val_loss: 3.2448e-05\n",
"Training took 276.5459449291229 seconds\n"
]
}
],
"source": [
"# NOTE(review): `model_training` is not defined in any visible cell of this\n",
"# notebook; presumably it comes from `from preprocessing import *` - confirm.\n",
"# The stored output below shows 20 epochs while the parameter cell sets\n",
"# epochs = 50, so it was presumably produced with other settings; re-run to refresh.\n",
"model_training(model_simple)"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step \n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAAGsCAYAAAB968WXAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAJJVJREFUeJzt3XtU1HXi//HXyGUUZabQEBA0UpMMUVfTUEszNc3KatduhppWa5ul+au8VF+rXYVqa7fditR2Oe3pwh5Tyzorm61CF6+orKSVlpom4KVwIExQeP/+cJ2aQHIA3zD6fJzD+crn8573vOet7Ty/M58BhzHGCAAAwJJmjb0AAABwdiE+AACAVcQHAACwivgAAABWER8AAMAq4gMAAFhFfAAAAKuIDwAAYBXxAQAArCI+AACAVU06Pj788ENde+21iomJkcPh0Ntvv33a73Pv3r26/fbb1bp1a4WFhalHjx7asGFDneebM2eO+vXrp7CwMJ1zzjmndJvFixfrqquuUps2beRwOJSXl1dtzPz58zVo0CC5XC45HA4dOnSo2pji4mKlpKTI7XbL7XYrJSWl2rgpU6aoV69ecjqd6tGjR7U5srOzNWrUKEVHR6tly5bq0aOHXn/99ZOu/ZNPPlFwcHCNc9Vm165dmjhxouLj49WiRQt17NhRs2fPVkVFhV/zAACaviYdH2VlZerevbteeOEFK/dXXFys/v37KyQkRMuWLdPWrVv17LPP1hoN559/vrKzs096vqKiQqNHj9Y999xzyusoKytT//79lZaWdtIxhw8f1vDhwzVr1qyTjrntttuUl5enrKwsZWVlKS8vTykpKT5jjDGaMGGCbr755hrnWLVqlZKSkrRo0SJt3rxZEyZM0NixY/Xuu+9WG+vxeDR27FhdeeWVp/hIf/T555+rqqpK8+bN05YtW/SnP/1JL7/8cq2PDwAQoEyAkGSWLFnic6y8vNw89NBDJiYmxoSFhZk+ffqYlStX1vk+pk+fbgYMGODXbTp06HBK95mRkWHcbrdfc+/cudNIMps2bTrpmJUrVxpJpri42Of41q1bjSSzZs0a77HVq1cbSebzzz+vNs/s2bNN9+7dT2ldV199tbnjjjuqHb/55pvNo48+etK5/v73v5uEhATjdDpNly5dzIsvvljr/Tz99NMmPj7+lNYEAAgcTfqVj19yxx136JNPPlFmZqY2b96s0aNHa/jw4dq+fXud5lu6dKl69+6t0aNHKzIyUj179tSCBQsaeNV2rF69Wm63W3379vUeu/TSS+V2u7Vq1ap6ze3xeBQREeFzLCMjQ1999ZVmz55d420WLFigRx55RHPmzNFnn32muXPn6rHHHtOrr77q1/0AAAJfwMbHV199pTfffFMLFy7UZZddpo4dO+rBBx/UgAEDlJGRUac5d+zYofT0dHXu3Fn//ve/NWnSJN1///36xz/+0cCrP/2KiooUGRlZ7XhkZKSKiorqPO9bb72l9evX64477vAe2759u2bMmKHXX39dwcHBNd7u97//vZ599lndeOONio+P14033qgHHnhA8+bNq3H8V199pb/+9a+aNGlSndcKAGiaan6mCAAbN26UMUYXXnihz/Hy8nK1bt1a0vGLGOPj42ud59577/VeU1JVVaXevXtr7ty5kqSePXtqy5YtSk9P19ixYyVJkyZN0muvvea9/eHDhzVixAgFBQV5j23dulXt27ev/4OsJ4fDUe2YMabG46ciOztb48eP14IFC3TxxRdLkiorK3XbbbfpiSeeqPZ3ccKBAwe0Z88eTZw4UXfddZf3+LFjx+R2u6uNLygo0PDhwzV69GjdeeeddVorAKDpCtj4qKqqUlBQkDZs2ODzxC9JrVq1kiS1a9dOn332Wa3znHvuud4/R0dHq2vXrj7nL7roIi1atMj7/ZNPPqkHH3zQ+/2gQYP01FNP+by9ERMT4/8DamBRUVHat29fteMHDhxQ27
Zt/Z4vJydH1157rZ577jlviElSaWmpcnNztWnTJk2ePFnS8b8bY4yCg4P1/vvve0NlwYIFPvskqdrfXUFBga644golJydr/vz5fq8TAND0BWx89OzZU5WVldq/f78uu+yyGseEhIQoISHhlOfs37+/vvjiC59j27ZtU4cOHbzfR0ZG+rydERwcrHbt2qlTp05+PoLTKzk5WR6PR+vWrVOfPn0kSWvXrpXH41G/fv38mis7O1vXXHONnnrqKd19990+51wul/Lz832OvfTSS1qxYoXeeustxcfHq2XLlmrXrp127NihMWPGnPR+9u7dqyuuuEK9evVSRkaGmjUL2HcFAQC1aNLx8f333+vLL7/0fr9z507l5eUpIiJCF154ocaMGaOxY8fq2WefVc+ePXXw4EGtWLFC3bp109VXX+33/T3wwAPq16+f5s6dq5tuuknr1q3T/Pnz6/X/ge/evVvfffeddu/ercrKSu/P7OjUqZP3FZqEhASlpqbqhhtukCTv+IKCAknyBlFUVJSioqIkHb+mo6ioyLs/+fn5Cg8PV/v27RUREaGLLrpIw4cP11133eW9ruLuu+/WNddcoy5dunjX9+WXX+r7779XUVGRfvjhB+/6unbtqtDQUGVnZ2vkyJGaMmWKfv3rX3uvFwkNDVVERISaNWumxMREn8ccGRmp5s2b+xx//PHHdf/998vlcmnEiBEqLy9Xbm6uiouLNW3aNBUUFGjQoEFq3769/vjHP+rAgQPe2554zACAM0Qjf9qmVic+Rvrzr3HjxhljjKmoqDD/93//Z84//3wTEhJioqKizA033GA2b95c5/t89913TWJionE6nSYhIcHMnz+/1vG/9FHbcePG1fgYfnobSSYjI8P7fUZGRo23mT17tnfM7Nmzaxzz03m+/fZbM2bMGBMeHm7Cw8PNmDFjqn0kd+DAgTXOs3PnzlrXP3DgwJM+5pN91Pb11183PXr0MKGhoebcc881l19+uVm8eHGtj7mJ/xMFANSBwxhjTmvdAAAA/ARvqgMAAKuIDwAAYFWTu+C0qqpKBQUFCg8Pr/PPowAAAHYZY1RaWqqYmJhf/LRik4uPgoICxcXFNfYyAABAHezZs0exsbG1jmly8REeHi7p+OJdLlcjrwYAAJyKkpISxcXFeZ/Ha9Pk4uPEWy0ul4v4AAAgwJzKJRNccAoAAKwiPgAAgFXEBwAAsIr4AAAAVhEfAADAKuIDAABYRXwAAACriA8AAGAV8QEAAKwiPgAAgFXEBwAAsIr4AAAAVjW5Xyx3uhljtPNgmb7+9rAOV1Tqh6OVOlpZpSpjVGWOn6+qOv7nKmNk/vd/T3xf27w//vln56qN/ek5U8u5k9+wrnPWOG8jqGUr7a2hKewES5Dk+99P462hsVfQVP4uGnsFTeS/zTNccDOHHhnZtfHuv9HuuRFUHKvSxFfX66PtBxt7KQAANJrQ4GbEhy3/zN2jj7YflMMhXRTlUqvmwWoREqSQoGYKaiY1czjUzOGQw3Hiz/rf98f/7HBIDv34q4J//luDq/8W4VrG1nK7Wu/D53Yn/7XFtc1Z81rtawJLaBr70AQW0fgrUJNYxM//O2mUNTT+EprALjSNfTiTBTVr3Ksuzqr4WPXl8Vc8/t/QCzV5cOdGXg0AAGens+qC0+37v5ckdYs9p3EXAgDAWeysiY+KY1XadbBMktQ5slUjrwYAgLPXWfO2S+mRo7qscxvtPfSDot3NG3s5AACctc6a+GjdyqmMO/o09jIAADjrnTVvuwAAgKaB+AAAAFYRHwAAwCriAwAAWEV8AAAAq4gPAABgFfEBAACsIj4AAIBVxAcAALCK+AAAAFYRHwAAwCriAwAAWEV8AAAAq4gPAABgFfEBAACsIj4AAIBVxAcAALCK+AAAAFYRHwAAwCriAwAAWEV8AAAAq4gPAABgFfEBAACsIj4AAIBVxAcAALCK+AAAAFYRHwAAwCriAwAAWEV8AAAAq4gPAA
BgFfEBAACsIj4AAIBVfsVHenq6kpKS5HK55HK5lJycrGXLlnnPP/7440pISFDLli117rnnasiQIVq7dm2DLxoAAAQuv+IjNjZWaWlpys3NVW5urgYPHqxRo0Zpy5YtkqQLL7xQL7zwgvLz8/Xxxx/r/PPP17Bhw3TgwIHTsngAABB4HMYYU58JIiIi9Mwzz2jixInVzpWUlMjtduuDDz7QlVdeeUrznbiNx+ORy+Wqz9IAAIAl/jx/B9f1TiorK7Vw4UKVlZUpOTm52vmKigrNnz9fbrdb3bt3P+k85eXlKi8v91k8AAA4c/kdH/n5+UpOTtaRI0fUqlUrLVmyRF27dvWef++993TLLbfo8OHDio6O1vLly9WmTZuTzpeamqonnniibqsHAAABx++3XSoqKrR7924dOnRIixYt0iuvvKKcnBxvgJSVlamwsFAHDx7UggULtGLFCq1du1aRkZE1zlfTKx9xcXG87QIAQADx522Xel/zMWTIEHXs2FHz5s2r8Xznzp01YcIEzZw585Tm45oPAAACjz/P3/X+OR/GGJ9XLvw9DwAAzi5+XfMxa9YsjRgxQnFxcSotLVVmZqays7OVlZWlsrIyzZkzR9ddd52io6P17bff6qWXXtI333yj0aNHn671AwCAAONXfOzbt08pKSkqLCyU2+1WUlKSsrKyNHToUB05ckSff/65Xn31VR08eFCtW7fWJZdcoo8++kgXX3zx6Vo/AAAIMPW+5qOhcc0HAACBx+o1HwAAAP4gPgAAgFXEBwAAsIr4AAAAVhEfAADAKuIDAABYRXwAAACriA8AAGAV8QEAAKwiPgAAgFXEBwAAsIr4AAAAVhEfAADAKuIDAABYRXwAAACriA8AAGAV8QEAAKwiPgAAgFXEBwAAsIr4AAAAVhEfAADAKuIDAABYRXwAAACriA8AAGAV8QEAAKwiPgAAgFXEBwAAsIr4AAAAVhEfAADAKuIDAABYRXwAAACriA8AAGAV8QEAAKwiPgAAgFXEBwAAsIr4AAAAVhEfAADAKuIDAABYRXwAAACriA8AAGAV8QEAAKwiPgAAgFXEBwAAsIr4AAAAVhEfAADAKuIDAABYRXwAAACriA8AAGAV8QEAAKwiPgAAgFXEBwAAsMqv+EhPT1dSUpJcLpdcLpeSk5O1bNkySdLRo0c1ffp0devWTS1btlRMTIzGjh2rgoKC07JwAAAQmPyKj9jYWKWlpSk3N1e5ubkaPHiwRo0apS1btujw4cPauHGjHnvsMW3cuFGLFy/Wtm3bdN11152utQMAgADkMMaY+kwQERGhZ555RhMnTqx2bv369erTp4++/vprtW/f/pTmKykpkdvtlsfjkcvlqs/SAACAJf48fwfX9U4qKyu1cOFClZWVKTk5ucYxHo9HDodD55xzzknnKS8vV3l5uff7kpKSui4JAAAEAL8vOM3Pz1erVq3kdDo1adIkLVmyRF27dq027siRI5oxY4Zuu+22WgsoNTVVbrfb+xUXF+fvkgAAQADx+22XiooK7d69W4cOHdKiRYv0yiuvKCcnxydAjh49qtGjR2v37t3Kzs6uNT5qeuUjLi6Ot10AAAgg/rztUu9rPoYMGaKOHTtq3rx5ko6Hx0033aQdO3ZoxYoVat26tV/zcc0HAACBx8o1HycYY7yvXJwIj+3bt2vlypV+hwcAADjz+RUfs2bN0ogRIxQXF6fS0lJlZmYqOztbWVlZOnbsmH7zm99o48aNeu+991RZWamioiJJxz8RExoaeloeAAAACCx+xce+ffuUkpKiwsJCud1uJSUlKSsrS0OHDtWuXbu0dOlSSVKPHj18brdy5UoNGjSoodYMAAACWL2v+WhoXPMBAEDg8ef5m9/tAgAArCI+AACAVcQHAACwivgAAABWER8AAMAq4gMAAFhFfAAAAKuIDwAAYBXxAQAArCI+AACAVcQHAACwivgAAABWER8AAMAq4gMAAFhFfAAAAKuIDwAAYB
XxAQAArCI+AACAVcQHAACwivgAAABWER8AAMAq4gMAAFhFfAAAAKuIDwAAYBXxAQAArCI+AACAVcQHAACwivgAAABWER8AAMAq4gMAAFhFfAAAAKuIDwAAYBXxAQAArCI+AACAVcQHAACwivgAAABWER8AAMAq4gMAAFhFfAAAAKuIDwAAYBXxAQAArCI+AACAVcQHAACwivgAAABWER8AAMAq4gMAAFhFfAAAAKuIDwAAYBXxAQAArCI+AACAVcQHAACwyq/4SE9PV1JSklwul1wul5KTk7Vs2TLv+cWLF+uqq65SmzZt5HA4lJeX19DrBQAAAc6v+IiNjVVaWppyc3OVm5urwYMHa9SoUdqyZYskqaysTP3791daWtppWSwAAAh8DmOMqc8EEREReuaZZzRx4kTvsV27dik+Pl6bNm1Sjx49/JqvpKREbrdbHo9HLperPksDAACW+PP8HVzXO6msrNTChQtVVlam5OTkuk6j8vJylZeXe78vKSmp81wAAKDp8/uC0/z8fLVq1UpOp1OTJk3SkiVL1LVr1zovIDU1VW632/sVFxdX57kAAEDT53d8dOnSRXl5eVqzZo3uuecejRs3Tlu3bq3zAmbOnCmPx+P92rNnT53nAgAATZ/fb7uEhoaqU6dOkqTevXtr/fr1ev755zVv3rw6LcDpdMrpdNbptgAAIPDU++d8GGN8rtkAAACojV+vfMyaNUsjRoxQXFycSktLlZmZqezsbGVlZUmSvvvuO+3evVsFBQWSpC+++EKSFBUVpaioqAZeOgAACER+vfKxb98+paSkqEuXLrryyiu1du1aZWVlaejQoZKkpUuXqmfPnho5cqQk6ZZbblHPnj318ssvN/zKAQBAQKr3z/loaPycDwAAAo8/z9/8bhcAAGAV8QEAAKwiPgAAgFXEBwAAsIr4AAAAVhEfAADAKuIDAABYRXwAAACriA8AAGAV8QEAAKwiPgAAgFXEBwAAsIr4AAAAVhEfAADAKuIDAABYRXwAAACriA8AAGAV8QEAAKwiPgAAgFXEBwAAsIr4AAAAVhEfAADAKuIDAABYRXwAAACriA8AAGAV8QEAAKwiPgAAgFXEBwAAsIr4AAAAVhEfAADAKuIDAABYRXwAAACriA8AAGAV8QEAAKwiPgAAgFXEBwAAsIr4AAAAVhEfAADAKuIDAABYRXwAAACriA8AAGAV8QEAAKwiPgAAgFXEBwAAsIr4AAAAVhEfAADAKuIDAABYRXwAAACriA8AAGAV8QEAAKwiPgAAgFXEBwAAsMqv+EhPT1dSUpJcLpdcLpeSk5O1bNky73ljjB5//HHFxMSoRYsWGjRokLZs2dLgiwYAAIHLr/iIjY1VWlqacnNzlZubq8GDB2vUqFHewHj66af13HPP6YUXXtD69esVFRWloUOHqrS09LQsHgAABB6HMcbUZ4KIiAg988wzmjBhgmJiYjR16lRNnz5dklReXq62bdvqqaee0m9/+9tTmq+kpERut1sej0cul6s+SwMAAJb48/xd52s+KisrlZmZqbKyMiUnJ2vnzp0qKirSsGHDvGOcTqcGDhyoVatWnXSe8vJylZSU+HwBAIAzl9/xkZ+fr1atWsnpdGrSpElasmSJunbtqqKiIklS27Ztfca3bdvWe64mqampcrvd3q+4uDh/lwQAAAKI3/HRpUsX5eXlac2aNbrnnns0btw4bd261Xve4XD4jDfGVDv2UzNnzpTH4/F+7dmzx98lAQCAABLs7w1CQ0PVqVMnSVLv3r21fv16Pf/8897rPIqKihQdHe0dv3///mqvhvyU0+mU0+n0dxkAACBA1fvnfBhjVF5ervj4eEVFRWn58uXecxUVFcrJyVG/fv3qezcAAOAM4dcrH7NmzdKIESMUFxen0tJSZWZmKjs7W1lZWXI4HJo6darmzp2rzp07q3Pnzpo7d67CwsJ02223na71AwCAAONXfOzbt08pKSkqLCyU2+1WUlKSsr
KyNHToUEnSww8/rB9++EG/+93vVFxcrL59++r9999XeHj4aVk8AAAIPPX+OR8NjZ/zAQBA4LHycz4AAADqgvgAAABWER8AAMAq4gMAAFhFfAAAAKuIDwAAYBXxAQAArCI+AACAVcQHAACwivgAAABWER8AAMAq4gMAAFhFfAAAAKuIDwAAYBXxAQAArCI+AACAVcQHAACwivgAAABWER8AAMAq4gMAAFhFfAAAAKuIDwAAYBXxAQAArCI+AACAVcQHAACwivgAAABWER8AAMAq4gMAAFhFfAAAAKuIDwAAYBXxAQAArCI+AACAVcQHAACwivgAAABWER8AAMAq4gMAAFhFfAAAAKuIDwAAYBXxAQAArCI+AACAVcQHAACwivgAAABWER8AAMAq4gMAAFhFfAAAAKuIDwAAYBXxAQAArCI+AACAVcQHAACwivgAAABWER8AAMAqv+IjNTVVl1xyicLDwxUZGanrr79eX3zxhc+Yffv2afz48YqJiVFYWJiGDx+u7du3N+iiAQBA4PIrPnJycnTvvfdqzZo1Wr58uY4dO6Zhw4aprKxMkmSM0fXXX68dO3bonXfe0aZNm9ShQwcNGTLEOwYAAJzdHMYYU9cbHzhwQJGRkcrJydHll1+ubdu2qUuXLvr000918cUXS5IqKysVGRmpp556SnfeeecvzllSUiK32y2PxyOXy1XXpQEAAIv8ef6u1zUfHo9HkhQRESFJKi8vlyQ1b97cOyYoKEihoaH6+OOPa5yjvLxcJSUlPl8AAODMVef4MMZo2rRpGjBggBITEyVJCQkJ6tChg2bOnKni4mJVVFQoLS1NRUVFKiwsrHGe1NRUud1u71dcXFxdlwQAAAJAneNj8uTJ2rx5s958803vsZCQEC1atEjbtm1TRESEwsLClJ2drREjRigoKKjGeWbOnCmPx+P92rNnT12XBAAAAkBwXW503333aenSpfrwww8VGxvrc65Xr17Ky8uTx+NRRUWFzjvvPPXt21e9e/eucS6n0ymn01mXZQAAgADk1ysfxhhNnjxZixcv1ooVKxQfH3/SsW63W+edd562b9+u3NxcjRo1qt6LBQAAgc+vVz7uvfdevfHGG3rnnXcUHh6uoqIiScdDo0WLFpKkhQsX6rzzzlP79u2Vn5+vKVOm6Prrr9ewYcMafvUAACDg+BUf6enpkqRBgwb5HM/IyND48eMlSYWFhZo2bZr27dun6OhojR07Vo899liDLBYAAAS+ev2cj9OBn/MBAEDgsfZzPgAAAPxFfAAAAKuIDwAAYBXxAQAArCI+AACAVcQHAACwivgAAABWER8AAMAq4gMAAFhFfAAAAKuIDwAAYBXxAQAArCI+AACAVcQHAACwivgAAABWER8AAMAq4gMAAFhFfAAAAKuIDwAAYBXxAQAArCI+AACAVcQHAACwivgAAABWER8AAMAq4gMAAFhFfAAAAKuIDwAAYBXxAQAArCI+AACAVcQHAACwivgAAABWER8AAMAq4gMAAFhFfAAAAKuIDwAAYFVwYy/AGmOko4cbexUAADQNIWGSw9Eod332xMfRw9LcmMZeBQAATcOsAim0ZaPcNW+7AAAAq86eVz5Cwo5XHgAAOP682EjOnvhwOBrt5SUAAPAj3nYBAABWER8AAMAq4gMAAFhFfAAAAKuIDwAAYBXxAQAArCI+AACAVcQHAACwivgAAABWER8AAMAq4gMAAFhFfAAAAKuIDwAAYFWT+622xhhJUklJSSOvBAAAnKoTz9snnsdr0+Tio7S0VJIUFxfXyCsBAAD+Ki0tldvtrnWMw5xKolhUVVWlgoIChYeHy+FwNOjcJSUliouL0549e+RyuRp0bvyIfbaHvbaDfbaDfbbjdO2zMUalpaWKiYlRs2a1X9XR5F75aNasmWJjY0/rfbhcLv5hW8A+28Ne28E+28E+23E69vmXXvE4gQtOAQCAVcQHAACw6qyKD6fTqdmzZ8vpdDb2Us5o7L
M97LUd7LMd7LMdTWGfm9wFpwAA4Mx2Vr3yAQAAGh/xAQAArCI+AACAVcQHAACw6qyJj5deeknx8fFq3ry5evXqpY8++qixlxRQUlNTdckllyg8PFyRkZG6/vrr9cUXX/iMMcbo8ccfV0xMjFq0aKFBgwZpy5YtPmPKy8t13333qU2bNmrZsqWuu+46ffPNNzYfSkBJTU2Vw+HQ1KlTvcfY54azd+9e3X777WrdurXCwsLUo0cPbdiwwXueva6/Y8eO6dFHH1V8fLxatGihCy64QE8++aSqqqq8Y9hn/3344Ye69tprFRMTI4fDobffftvnfEPtaXFxsVJSUuR2u+V2u5WSkqJDhw7V/wGYs0BmZqYJCQkxCxYsMFu3bjVTpkwxLVu2NF9//XVjLy1gXHXVVSYjI8N8+umnJi8vz4wcOdK0b9/efP/9994xaWlpJjw83CxatMjk5+ebm2++2URHR5uSkhLvmEmTJpl27dqZ5cuXm40bN5orrrjCdO/e3Rw7dqwxHlaTtm7dOnP++eebpKQkM2XKFO9x9rlhfPfdd6ZDhw5m/PjxZu3atWbnzp3mgw8+MF9++aV3DHtdf3/4wx9M69atzXvvvWd27txpFi5caFq1amX+/Oc/e8ewz/7717/+ZR555BGzaNEiI8ksWbLE53xD7enw4cNNYmKiWbVqlVm1apVJTEw011xzTb3Xf1bER58+fcykSZN8jiUkJJgZM2Y00ooC3/79+40kk5OTY4wxpqqqykRFRZm0tDTvmCNHjhi3221efvllY4wxhw4dMiEhISYzM9M7Zu/evaZZs2YmKyvL7gNo4kpLS03nzp3N8uXLzcCBA73xwT43nOnTp5sBAwac9Dx73TBGjhxpJkyY4HPsxhtvNLfffrsxhn1uCD+Pj4ba061btxpJZs2aNd4xq1evNpLM559/Xq81n/Fvu1RUVGjDhg0aNmyYz/Fhw4Zp1apVjbSqwOfxeCRJERERkqSdO3eqqKjIZ5+dTqcGDhzo3ecNGzbo6NGjPmNiYmKUmJjI38XP3HvvvRo5cqSGDBnic5x9bjhLly5V7969NXr0aEVGRqpnz55asGCB9zx73TAGDBig//znP9q2bZsk6b///a8+/vhjXX311ZLY59OhofZ09erVcrvd6tu3r3fMpZdeKrfbXe99b3K/WK6hHTx4UJWVlWrbtq3P8bZt26qoqKiRVhXYjDGaNm2aBgwYoMTEREny7mVN+/z11197x4SGhurcc8+tNoa/ix9lZmZq48aNWr9+fbVz7HPD2bFjh9LT0zVt2jTNmjVL69at0/333y+n06mxY8ey1w1k+vTp8ng8SkhIUFBQkCorKzVnzhzdeuutkvg3fTo01J4WFRUpMjKy2vyRkZH13vczPj5OcDgcPt8bY6odw6mZPHmyNm/erI8//rjaubrsM38XP9qzZ4+mTJmi999/X82bNz/pOPa5/qqqqtS7d2/NnTtXktSzZ09t2bJF6enpGjt2rHcce10///znP/Xaa6/pjTfe0MUXX6y8vDxNnTpVMTExGjdunHcc+9zwGmJPaxrfEPt+xr/t0qZNGwUFBVWrtP3791erQvyy++67T0uXLtXKlSsVGxvrPR4VFSVJte5zVFSUKioqVFxcfNIxZ7sNGzZo//796tWrl4KDgxUcHKycnBz95S9/UXBwsHef2Of6i46OVteuXX2OXXTRRdq9e7ck/k03lIceekgzZszQLbfcom7duiklJUUPPPCAUlNTJbHPp0ND7WlUVJT27dtXbf4DBw7Ue9/P+PgIDQ1Vr169tHz5cp/jy5cvV79+/RppVYHHGKPJkydr8eLFWrFiheLj433Ox8fHKyoqymefKyoqlJOT493nXr16KSQkxGdMYWGhPv30U/4u/ufKK69Ufn6+8vLyvF+9e/fWmDFjlJeXpwsuuIB9biD9+/ev9nHxbdu2qUOHDpL4N91QDh8+rGbNfJ9qgoKCvB+1ZZ8bXkPtaXJysj
wej9atW+cds3btWnk8nvrve70uVw0QJz5q+7e//c1s3brVTJ061bRs2dLs2rWrsZcWMO655x7jdrtNdna2KSws9H4dPnzYOyYtLc243W6zePFik5+fb2699dYaP9oVGxtrPvjgA7Nx40YzePDgs/rjcqfip592MYZ9bijr1q0zwcHBZs6cOWb79u3m9ddfN2FhYea1117zjmGv62/cuHGmXbt23o/aLl682LRp08Y8/PDD3jHss/9KS0vNpk2bzKZNm4wk89xzz5lNmzZ5f4REQ+3p8OHDTVJSklm9erVZvXq16datGx+19ceLL75oOnToYEJDQ82vfvUr70dEcWok1fiVkZHhHVNVVWVmz55toqKijNPpNJdffrnJz8/3meeHH34wkydPNhEREaZFixbmmmuuMbt377b8aALLz+ODfW447777rklMTDROp9MkJCSY+fPn+5xnr+uvpKTETJkyxbRv3940b97cXHDBBeaRRx4x5eXl3jHss/9WrlxZ4/8mjxs3zhjTcHv67bffmjFjxpjw8HATHh5uxowZY4qLi+u9focxxtTvtRMAAIBTd8Zf8wEAAJoW4gMAAFhFfAAAAKuIDwAAYBXxAQAArCI+AACAVcQHAACwivgAAABWER8AAMAq4gMAAFhFfAAAAKuIDwAAYNX/BxJdbIcUM2/sAAAAAElFTkSuQmCC",
"text/plain": [
"<matplotlib.image.AxesImage at 0x7fa3db2b7010>"
]
},
"execution_count": 69,
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# Compare the reference results with the predictions of `model_simple` for a\n",
"# single grid cell that is tracked over the first `iterations` iterations.\n",
"species = \"O\"                  # column to plot\n",
"iterations = 1000              # number of iterations to track\n",
"cell_offset = 60               # 1-based position of the tracked cell within one iteration\n",
"cells_per_iteration = 50 * 50  # rows per iteration (50x50 grid)\n",
"y_differences = []\n",
"\n",
"df_design_transformed_scaled = scaler_X.transform(FuncTransform(func_dict_in, func_dict_out).fit_transform(df_design))\n",
"df_results_transformed_scaled = scaler_X.transform(FuncTransform(func_dict_in, func_dict_out).fit_transform(df_results))\n",
"\n",
"# Row position = iteration * cells_per_iteration + cell, so a single positional\n",
"# lookup selects the tracked cell from every iteration (no row-by-row appends).\n",
"cell_rows = np.arange(iterations) * cells_per_iteration + cell_offset - 1\n",
"y_design = df_design_transformed_scaled.iloc[cell_rows]\n",
"y_results = df_results_transformed_scaled.iloc[cell_rows]\n",
"\n",
"prediction = model_simple.predict(y_design.iloc[:, y_design.columns != \"Class\"])\n",
"prediction = pd.DataFrame(prediction, columns=y_results.columns)\n",
"\n",
"# Undo the scaling and the function transform to get back to the original value range.\n",
"y_results_back = FuncTransform(func_dict_in, func_dict_out).inverse_transform(pd.DataFrame(scaler_X.inverse_transform(y_results), columns=df_results.columns))\n",
"prediction_back = FuncTransform(func_dict_in, func_dict_out).inverse_transform(pd.DataFrame(scaler_X.inverse_transform(prediction), columns=df_results.columns))\n",
"\n",
"steps = np.arange(iterations)\n",
"\n",
"# Figure 1: transformed and scaled values, i.e. the space the model works in.\n",
"# plt.plot(steps, y_design[species], label=\"Design\")\n",
"plt.plot(steps, y_results[species], label=\"Results\")\n",
"plt.plot(steps, prediction[species], label=\"Prediction\")\n",
"plt.xlabel('Iteration')\n",
"plt.ylabel(species)\n",
"plt.title(species+' Concentration over Iterations')\n",
"plt.legend()  # the labels above are only shown when a legend is drawn\n",
"plt.show()\n",
"\n",
"# Figure 2: the same comparison after the back-transformation.\n",
"plt.plot(steps, y_results_back[species], label=\"Results\")\n",
"plt.plot(steps, prediction_back[species], label=\"Prediction\")\n",
"plt.xlabel('Iteration')\n",
"plt.ylabel(species)\n",
"plt.title(species+' Concentration over Iterations (back-transformed)')\n",
"plt.legend()\n",
"plt.show()\n",
"\n",
"# Figure 3: Barite on the 50x50 grid at a single timestep.\n",
"timestep = 1000\n",
"first_row = timestep * cells_per_iteration\n",
"barite_grid = df_results[\"Barite\"].iloc[first_row:first_row + cells_per_iteration].to_numpy().reshape(50, 50)\n",
"plt.imshow(barite_grid, origin='lower')\n",
"plt.colorbar(label=\"Barite\")\n",
"plt.title(\"Barite at timestep \" + str(timestep))\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>H</th>\n",
" <th>O</th>\n",
" <th>Charge</th>\n",
" <th>H_0_</th>\n",
" <th>O_0_</th>\n",
" <th>Ba</th>\n",
" <th>Cl</th>\n",
" <th>S_2_</th>\n",
" <th>S_6_</th>\n",
" <th>Sr</th>\n",
" <th>Barite</th>\n",
" <th>Celestite</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>111.012434</td>\n",
" <td>55.508700</td>\n",
" <td>-1.216415e-09</td>\n",
" <td>0.0</td>\n",
" <td>2.217711e-13</td>\n",
" <td>4.495355e-07</td>\n",
" <td>1.532249e-12</td>\n",
" <td>0.0</td>\n",
" <td>0.000621</td>\n",
" <td>0.000620</td>\n",
" <td>0.001000</td>\n",
" <td>1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>111.012434</td>\n",
" <td>55.508700</td>\n",
" <td>-1.222504e-09</td>\n",
" <td>0.0</td>\n",
" <td>1.312902e-12</td>\n",
" <td>4.500359e-07</td>\n",
" <td>1.044603e-08</td>\n",
" <td>0.0</td>\n",
" <td>0.000621</td>\n",
" <td>0.000620</td>\n",
" <td>0.001000</td>\n",
" <td>1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>111.012434</td>\n",
" <td>55.508699</td>\n",
" <td>-1.220407e-09</td>\n",
" <td>0.0</td>\n",
" <td>1.614098e-12</td>\n",
" <td>4.500563e-07</td>\n",
" <td>4.907802e-07</td>\n",
" <td>0.0</td>\n",
" <td>0.000621</td>\n",
" <td>0.000620</td>\n",
" <td>0.001000</td>\n",
" <td>1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>111.012434</td>\n",
" <td>55.508695</td>\n",
" <td>-1.216831e-09</td>\n",
" <td>0.0</td>\n",
" <td>2.293739e-12</td>\n",
" <td>4.504482e-07</td>\n",
" <td>4.772370e-06</td>\n",
" <td>0.0</td>\n",
" <td>0.000620</td>\n",
" <td>0.000622</td>\n",
" <td>0.001000</td>\n",
" <td>1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>111.012434</td>\n",
" <td>55.508679</td>\n",
" <td>-1.216842e-09</td>\n",
" <td>0.0</td>\n",
" <td>2.641545e-12</td>\n",
" <td>4.534070e-07</td>\n",
" <td>2.200220e-05</td>\n",
" <td>0.0</td>\n",
" <td>0.000616</td>\n",
" <td>0.000626</td>\n",
" <td>0.001000</td>\n",
" <td>1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>995</th>\n",
" <td>111.012434</td>\n",
" <td>55.506410</td>\n",
" <td>2.170844e-08</td>\n",
" <td>0.0</td>\n",
" <td>4.777415e-11</td>\n",
" <td>1.550089e-04</td>\n",
" <td>9.784038e-02</td>\n",
" <td>0.0</td>\n",
" <td>0.000048</td>\n",
" <td>0.048814</td>\n",
" <td>0.149476</td>\n",
" <td>0.846327</td>\n",
" </tr>\n",
" <tr>\n",
" <th>996</th>\n",
" <td>111.012434</td>\n",
" <td>55.506410</td>\n",
" <td>2.166504e-08</td>\n",
" <td>0.0</td>\n",
" <td>4.808612e-11</td>\n",
" <td>1.559012e-04</td>\n",
" <td>9.785750e-02</td>\n",
" <td>0.0</td>\n",
" <td>0.000048</td>\n",
" <td>0.048821</td>\n",
" <td>0.151708</td>\n",
" <td>0.844093</td>\n",
" </tr>\n",
" <tr>\n",
" <th>997</th>\n",
" <td>111.012434</td>\n",
" <td>55.506409</td>\n",
" <td>2.162167e-08</td>\n",
" <td>0.0</td>\n",
" <td>4.811205e-11</td>\n",
" <td>1.567226e-04</td>\n",
" <td>9.787459e-02</td>\n",
" <td>0.0</td>\n",
" <td>0.000048</td>\n",
" <td>0.048829</td>\n",
" <td>0.153945</td>\n",
" <td>0.841856</td>\n",
" </tr>\n",
" <tr>\n",
" <th>998</th>\n",
" <td>111.012434</td>\n",
" <td>55.506409</td>\n",
" <td>2.157995e-08</td>\n",
" <td>0.0</td>\n",
" <td>4.815004e-11</td>\n",
" <td>1.574812e-04</td>\n",
" <td>9.789167e-02</td>\n",
" <td>0.0</td>\n",
" <td>0.000048</td>\n",
" <td>0.048836</td>\n",
" <td>0.156185</td>\n",
" <td>0.839614</td>\n",
" </tr>\n",
" <tr>\n",
" <th>999</th>\n",
" <td>111.012434</td>\n",
" <td>55.506409</td>\n",
" <td>2.153938e-08</td>\n",
" <td>0.0</td>\n",
" <td>4.815067e-11</td>\n",
" <td>1.581835e-04</td>\n",
" <td>9.790872e-02</td>\n",
" <td>0.0</td>\n",
" <td>0.000048</td>\n",
" <td>0.048844</td>\n",
" <td>0.158428</td>\n",
" <td>0.837370</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1000 rows × 12 columns</p>\n",
"</div>"
],
"text/plain": [
" H O Charge H_0_ O_0_ Ba \\\n",
"0 111.012434 55.508700 -1.216415e-09 0.0 2.217711e-13 4.495355e-07 \n",
"1 111.012434 55.508700 -1.222504e-09 0.0 1.312902e-12 4.500359e-07 \n",
"2 111.012434 55.508699 -1.220407e-09 0.0 1.614098e-12 4.500563e-07 \n",
"3 111.012434 55.508695 -1.216831e-09 0.0 2.293739e-12 4.504482e-07 \n",
"4 111.012434 55.508679 -1.216842e-09 0.0 2.641545e-12 4.534070e-07 \n",
".. ... ... ... ... ... ... \n",
"995 111.012434 55.506410 2.170844e-08 0.0 4.777415e-11 1.550089e-04 \n",
"996 111.012434 55.506410 2.166504e-08 0.0 4.808612e-11 1.559012e-04 \n",
"997 111.012434 55.506409 2.162167e-08 0.0 4.811205e-11 1.567226e-04 \n",
"998 111.012434 55.506409 2.157995e-08 0.0 4.815004e-11 1.574812e-04 \n",
"999 111.012434 55.506409 2.153938e-08 0.0 4.815067e-11 1.581835e-04 \n",
"\n",
" Cl S_2_ S_6_ Sr Barite Celestite \n",
"0 1.532249e-12 0.0 0.000621 0.000620 0.001000 1.000000 \n",
"1 1.044603e-08 0.0 0.000621 0.000620 0.001000 1.000000 \n",
"2 4.907802e-07 0.0 0.000621 0.000620 0.001000 1.000000 \n",
"3 4.772370e-06 0.0 0.000620 0.000622 0.001000 1.000000 \n",
"4 2.200220e-05 0.0 0.000616 0.000626 0.001000 1.000000 \n",
".. ... ... ... ... ... ... \n",
"995 9.784038e-02 0.0 0.000048 0.048814 0.149476 0.846327 \n",
"996 9.785750e-02 0.0 0.000048 0.048821 0.151708 0.844093 \n",
"997 9.787459e-02 0.0 0.000048 0.048829 0.153945 0.841856 \n",
"998 9.789167e-02 0.0 0.000048 0.048836 0.156185 0.839614 \n",
"999 9.790872e-02 0.0 0.000048 0.048844 0.158428 0.837370 \n",
"\n",
"[1000 rows x 12 columns]"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Back-transformed reference results of the tracked cell. `y_results_back` is\n",
"# computed in the plotting cell from exactly this expression, so it is reused\n",
"# here instead of repeating the inverse transform.\n",
"y_results_back"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>H</th>\n",
" <th>O</th>\n",
" <th>Charge</th>\n",
" <th>H_0_</th>\n",
" <th>O_0_</th>\n",
" <th>Ba</th>\n",
" <th>Cl</th>\n",
" <th>S_2_</th>\n",
" <th>S_6_</th>\n",
" <th>Sr</th>\n",
" <th>Barite</th>\n",
" <th>Celestite</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>111.012428</td>\n",
" <td>55.508682</td>\n",
" <td>-1.193141e-09</td>\n",
" <td>3.397941e-15</td>\n",
" <td>2.128961e-13</td>\n",
" <td>-0.000012</td>\n",
" <td>0.000021</td>\n",
" <td>-1.191799e-17</td>\n",
" <td>0.000620</td>\n",
" <td>0.000630</td>\n",
" <td>0.000985</td>\n",
" <td>1.000231</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>111.012428</td>\n",
" <td>55.508682</td>\n",
" <td>-1.202185e-09</td>\n",
" <td>2.918474e-15</td>\n",
" <td>1.019832e-12</td>\n",
" <td>-0.000016</td>\n",
" <td>0.000008</td>\n",
" <td>-7.920033e-18</td>\n",
" <td>0.000620</td>\n",
" <td>0.000630</td>\n",
" <td>0.000946</td>\n",
" <td>1.000121</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>111.012428</td>\n",
" <td>55.508682</td>\n",
" <td>-1.203471e-09</td>\n",
" <td>1.785468e-15</td>\n",
" <td>2.398433e-12</td>\n",
" <td>-0.000015</td>\n",
" <td>-0.000013</td>\n",
" <td>-9.158671e-18</td>\n",
" <td>0.000620</td>\n",
" <td>0.000627</td>\n",
" <td>0.000913</td>\n",
" <td>0.999977</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>111.012428</td>\n",
" <td>55.508682</td>\n",
" <td>-1.199235e-09</td>\n",
" <td>1.746077e-15</td>\n",
" <td>2.357316e-12</td>\n",
" <td>-0.000016</td>\n",
" <td>-0.000011</td>\n",
" <td>-9.246642e-18</td>\n",
" <td>0.000619</td>\n",
" <td>0.000629</td>\n",
" <td>0.000916</td>\n",
" <td>0.999936</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>111.012428</td>\n",
" <td>55.508682</td>\n",
" <td>-1.197043e-09</td>\n",
" <td>1.533956e-15</td>\n",
" <td>2.670053e-12</td>\n",
" <td>-0.000019</td>\n",
" <td>0.000003</td>\n",
" <td>-9.144569e-18</td>\n",
" <td>0.000615</td>\n",
" <td>0.000635</td>\n",
" <td>0.000943</td>\n",
" <td>0.999769</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>995</th>\n",
" <td>111.012428</td>\n",
" <td>55.506416</td>\n",
" <td>2.189995e-08</td>\n",
" <td>-3.792428e-15</td>\n",
" <td>4.785985e-11</td>\n",
" <td>0.000279</td>\n",
" <td>0.097642</td>\n",
" <td>1.182626e-16</td>\n",
" <td>0.000051</td>\n",
" <td>0.048738</td>\n",
" <td>0.149016</td>\n",
" <td>0.844585</td>\n",
" </tr>\n",
" <tr>\n",
" <th>996</th>\n",
" <td>111.012428</td>\n",
" <td>55.506416</td>\n",
" <td>2.188781e-08</td>\n",
" <td>-3.910681e-15</td>\n",
" <td>4.822730e-11</td>\n",
" <td>0.000279</td>\n",
" <td>0.097665</td>\n",
" <td>1.169111e-16</td>\n",
" <td>0.000051</td>\n",
" <td>0.048751</td>\n",
" <td>0.151241</td>\n",
" <td>0.842272</td>\n",
" </tr>\n",
" <tr>\n",
" <th>997</th>\n",
" <td>111.012428</td>\n",
" <td>55.506416</td>\n",
" <td>2.184875e-08</td>\n",
" <td>-3.749360e-15</td>\n",
" <td>4.824932e-11</td>\n",
" <td>0.000279</td>\n",
" <td>0.097687</td>\n",
" <td>1.146696e-16</td>\n",
" <td>0.000051</td>\n",
" <td>0.048761</td>\n",
" <td>0.153477</td>\n",
" <td>0.840015</td>\n",
" </tr>\n",
" <tr>\n",
" <th>998</th>\n",
" <td>111.012428</td>\n",
" <td>55.506416</td>\n",
" <td>2.180415e-08</td>\n",
" <td>-3.500642e-15</td>\n",
" <td>4.816323e-11</td>\n",
" <td>0.000279</td>\n",
" <td>0.097710</td>\n",
" <td>1.119656e-16</td>\n",
" <td>0.000051</td>\n",
" <td>0.048770</td>\n",
" <td>0.155720</td>\n",
" <td>0.837773</td>\n",
" </tr>\n",
" <tr>\n",
" <th>999</th>\n",
" <td>111.012428</td>\n",
" <td>55.506416</td>\n",
" <td>2.177183e-08</td>\n",
" <td>-3.375572e-15</td>\n",
" <td>4.822685e-11</td>\n",
" <td>0.000279</td>\n",
" <td>0.097732</td>\n",
" <td>1.095921e-16</td>\n",
" <td>0.000050</td>\n",
" <td>0.048780</td>\n",
" <td>0.157962</td>\n",
" <td>0.835504</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1000 rows × 12 columns</p>\n",
"</div>"
],
"text/plain": [
" H O Charge H_0_ O_0_ \\\n",
"0 111.012428 55.508682 -1.193141e-09 3.397941e-15 2.128961e-13 \n",
"1 111.012428 55.508682 -1.202185e-09 2.918474e-15 1.019832e-12 \n",
"2 111.012428 55.508682 -1.203471e-09 1.785468e-15 2.398433e-12 \n",
"3 111.012428 55.508682 -1.199235e-09 1.746077e-15 2.357316e-12 \n",
"4 111.012428 55.508682 -1.197043e-09 1.533956e-15 2.670053e-12 \n",
".. ... ... ... ... ... \n",
"995 111.012428 55.506416 2.189995e-08 -3.792428e-15 4.785985e-11 \n",
"996 111.012428 55.506416 2.188781e-08 -3.910681e-15 4.822730e-11 \n",
"997 111.012428 55.506416 2.184875e-08 -3.749360e-15 4.824932e-11 \n",
"998 111.012428 55.506416 2.180415e-08 -3.500642e-15 4.816323e-11 \n",
"999 111.012428 55.506416 2.177183e-08 -3.375572e-15 4.822685e-11 \n",
"\n",
" Ba Cl S_2_ S_6_ Sr Barite Celestite \n",
"0 -0.000012 0.000021 -1.191799e-17 0.000620 0.000630 0.000985 1.000231 \n",
"1 -0.000016 0.000008 -7.920033e-18 0.000620 0.000630 0.000946 1.000121 \n",
"2 -0.000015 -0.000013 -9.158671e-18 0.000620 0.000627 0.000913 0.999977 \n",
"3 -0.000016 -0.000011 -9.246642e-18 0.000619 0.000629 0.000916 0.999936 \n",
"4 -0.000019 0.000003 -9.144569e-18 0.000615 0.000635 0.000943 0.999769 \n",
".. ... ... ... ... ... ... ... \n",
"995 0.000279 0.097642 1.182626e-16 0.000051 0.048738 0.149016 0.844585 \n",
"996 0.000279 0.097665 1.169111e-16 0.000051 0.048751 0.151241 0.842272 \n",
"997 0.000279 0.097687 1.146696e-16 0.000051 0.048761 0.153477 0.840015 \n",
"998 0.000279 0.097710 1.119656e-16 0.000051 0.048770 0.155720 0.837773 \n",
"999 0.000279 0.097732 1.095921e-16 0.000050 0.048780 0.157962 0.835504 \n",
"\n",
"[1000 rows x 12 columns]"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Back-transformed model prediction for the tracked cell. `prediction_back` is\n",
"# computed in the plotting cell from the same inverse transform, so it is reused\n",
"# here instead of being recomputed.\n",
"prediction_back"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Training loss over all epochs; the figure is also written to loss_all.png.\n",
"plt.plot(history.history[\"loss\"], \"o-\", label=\"Training Loss\")\n",
"plt.xlabel(\"Epoch\")\n",
"# plt.yscale('log')\n",
"plt.ylabel(\"Loss (Huber)\")\n",
"plt.title(\"Training loss\")\n",
"plt.legend()    # without a legend the label above is never shown\n",
"plt.grid(True)  # grid() expects a bool; the string 'on' only worked because it is truthy\n",
"\n",
"plt.savefig(\"loss_all.png\", dpi=300)\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Training loss from the second entry onwards (the first entry is skipped);\n",
"# the figure is also written to loss_1_to_end.png.\n",
"loss_history = history.history[\"loss\"]\n",
"# Plot against the real epoch index so that skipping the first entry does not\n",
"# shift every point one epoch to the left.\n",
"plt.plot(range(1, len(loss_history)), loss_history[1:], \"o-\", label=\"Training Loss\")\n",
"plt.xlabel(\"Epoch\")\n",
"# plt.yscale('log')\n",
"plt.ylabel(\"Loss (Huber)\")\n",
"plt.title(\"Training loss (first epoch omitted)\")\n",
"plt.legend()\n",
"plt.grid(True)  # grid() expects a bool, not the string 'on'\n",
"plt.savefig(\"loss_1_to_end.png\", dpi=300)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test the model"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m 20/7821\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m44s\u001b[0m 6ms/step - loss: 5.1914e-06"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m7821/7821\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m29s\u001b[0m 4ms/step - loss: 1.0395e-06\n"
]
},
{
"data": {
"text/plain": [
"9.875676596493577e-07"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Evaluate model_large on the complete test set; the \"Class\" column is\n",
"# excluded from both the inputs and the targets.\n",
"test_inputs = X_test.iloc[:, X_test.columns != \"Class\"]\n",
"test_targets = y_test.iloc[:, y_test.columns != \"Class\"]\n",
"model_large.evaluate(test_inputs, test_targets)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m7727/7727\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m29s\u001b[0m 4ms/step - loss: 5.4493e-07\n"
]
},
{
"data": {
"text/plain": [
"5.075861508885282e-07"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# test on non-reactive data (rows with Class == 0)\n",
"non_reactive = X_test[\"Class\"] == 0\n",
"# Drop \"Class\" by name, as in the evaluation on all test data, instead of\n",
"# relying on it being the last column of X_test and y_test.\n",
"model_large.evaluate(\n",
"    X_test.loc[non_reactive, X_test.columns != \"Class\"],\n",
"    y_test.loc[non_reactive, y_test.columns != \"Class\"],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m94/94\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step - loss: 4.0710e-05\n"
]
}
],
"source": [
"# Store the result under its own name: assigning it to `mass_balance` would\n",
"# overwrite the function and make this cell fail when it is run a second time.\n",
"mass_balance_result = mass_balance(model_simple, X_test, scaler_X, func_dict_in, func_dict_out)"
]
},
{
"cell_type": "code",
"execution_count": 116,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4.047931361128576e-05"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# test on reactive data (rows with Class == 1)\n",
"reactive = X_test[\"Class\"] == 1\n",
"# Drop \"Class\" by name, as in the evaluation on all test data, instead of\n",
"# relying on it being the last column of X_test and y_test.\n",
"model_large.evaluate(\n",
"    X_test.loc[reactive, X_test.columns != \"Class\"],\n",
"    y_test.loc[reactive, y_test.columns != \"Class\"],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save the model"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"# Write the model to disk as a .keras file.\n",
"# NOTE(review): the cells shown here use `model_simple` / `model_large`;\n",
"# confirm that `model` is the object that is meant to be saved.\n",
"model_path = \"Barite_50_Model_additional_species.keras\"\n",
"model.save(model_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Legacy Code"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code below is kept for reference only; it lives in a markdown cell and is not executed.\n",
"\n",
"```python\n",
"def log_scale(df_design, df_result, func_dict):\n",
"\n",
"    df_design = df_design.copy()\n",
"    df_result = df_result.copy()\n",
"\n",
"    for key in df_design.keys():\n",
"        if key != \"Class\":\n",
"            df_design[key] = np.vectorize(func_dict[key])(df_design[key])\n",
"            df_result[key] = np.vectorize(func_dict[key])(df_result[key])\n",
"\n",
"    return df_design, df_result\n",
"\n",
"# Get minimum and maximum values for each column\n",
"def get_min_max(df_design, df_result):\n",
"\n",
"    min_vals_des = df_design.min()\n",
"    max_vals_des = df_design.max()\n",
"\n",
"    min_vals_res = df_result.min()\n",
"    max_vals_res = df_result.max()\n",
"\n",
"    # minimum of input and output data to get global minimum/maximum\n",
"    data_min = np.minimum(min_vals_des, min_vals_res).to_dict()\n",
"    data_max = np.maximum(max_vals_des, max_vals_res).to_dict()\n",
"\n",
"    return data_min, data_max\n",
"\n",
"df_design_log, df_results_log = log_scale(df_design, df_results, func_dict_in)\n",
"data_min_log, data_max_log = get_min_max(df_design_log, df_design_log)\n",
"\n",
"train_min_log, train_max_log = get_min_max(X_train_log, y_train_log)\n",
"test_min_log, test_max_log = get_min_max(X_test_log, y_test_log)\n",
"\n",
"X_train_preprocess = preprocess(X_train_log, func_dict_in, train_min_log, train_max_log)\n",
"y_train_preprocess = preprocess(y_train_log, func_dict_in, train_min_log, train_max_log)\n",
"\n",
"X_test_preprocess = preprocess(X_test_log, func_dict_in, test_min_log, test_max_log)\n",
"y_test_preprocess = preprocess(y_test_log, func_dict_in, test_min_log, test_max_log)\n",
"\n",
"X_train_log, y_train_log = log_scale(X_train, y_train, func_dict_in)\n",
"X_test_log, y_test_log = log_scale(X_test, y_test, func_dict_in)\n",
"\n",
"\n",
"def preprocess(data, func_dict, data_min, data_max):\n",
"    data = data.copy()\n",
"    for key in data.keys():\n",
"        if key != \"Class\":\n",
"            data[key] = (data[key] - data_min[key]) / (data_max[key] - data_min[key])\n",
"\n",
"    return data\n",
"\n",
"def postprocess(data, func_dict, data_min, data_max):\n",
"    data = data.copy()\n",
"    for key in data.keys():\n",
"        if key != \"Class\":\n",
"            data[key] = data[key] * (data_max[key] - data_min[key]) + data_min[key]\n",
"            data[key] = np.vectorize(func_dict[key])(data[key])\n",
"    return data\n",
"\n",
"X_train, X_val, y_train, y_val = sk.train_test_split(X_train_preprocess, y_train_preprocess, test_size = 0.1)\n",
"\n",
"pp_design = preprocess(df_design_log, func_dict_in, data_min_log, data_max_log)\n",
"pp_results = preprocess(df_results_log, func_dict_in, data_min_log, data_max_log)\n",
"```"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ai",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}