mirror of
https://git.gfz-potsdam.de/naaice/model-training.git
synced 2025-12-15 19:58:22 +01:00
1018 lines
123 KiB
Plaintext
1018 lines
123 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## General Information"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"This notebook is used to train a simple neural network model to predict the chemistry in the barite benchmark (50x50 grid). The training data is stored in the repository using **git large file storage** and can be downloaded after the installation of git lfs using the `git lfs pull` command.\n",
|
|
"\n",
|
|
"It is then recommended to create a Python environment using miniconda. The necessary dependencies are contained in `environment.yml` and can be installed using `conda env create -f environment.yml`.\n",
|
|
"\n",
|
|
"The data set is divided into a design and result part and consists of the iterations of a reference simulation. The design part of the data set contains the chemical concentrations at time $t$ and the result part at time $t+1$, which are to be learned by the model."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Setup Libraries"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2025-01-15 16:24:49.275664: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
|
|
"2025-01-15 16:24:49.404820: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
|
|
"To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Running Keras in version 3.6.0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import keras\n",
|
|
"print(\"Running Keras in version {}\".format(keras.__version__))\n",
|
|
"\n",
|
|
"import h5py\n",
|
|
"import numpy as np\n",
|
|
"import pandas as pd\n",
|
|
"import time\n",
|
|
"import sklearn.model_selection as sk\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"from sklearn.cluster import KMeans\n",
|
|
"from imblearn.over_sampling import SMOTE\n",
|
|
"from imblearn.under_sampling import RandomUnderSampler\n",
|
|
"from imblearn.over_sampling import RandomOverSampler\n",
|
|
"from collections import Counter\n",
|
|
"import os"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Define parameters"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 141,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"dtype = \"float32\"\n",
|
|
"activation = \"relu\"\n",
|
|
"\n",
|
|
"lr = 0.001\n",
|
|
"batch_size = 512\n",
|
|
"epochs = 50 # default 400 epochs\n",
|
|
"\n",
|
|
"lr_schedule = keras.optimizers.schedules.ExponentialDecay(\n",
|
|
" initial_learning_rate=lr,\n",
|
|
" decay_steps=2000,\n",
|
|
" decay_rate=0.9,\n",
|
|
" staircase=True\n",
|
|
")\n",
|
|
"\n",
|
|
"optimizer_simple = keras.optimizers.Adam(learning_rate=lr_schedule)\n",
|
|
"optimizer_large = keras.optimizers.Adam(learning_rate=lr_schedule)\n",
|
|
"\n",
|
|
"loss = keras.losses.Huber()\n",
|
|
"\n",
|
|
"sample_fraction = 0.8"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Setup the model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 142,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential_5\"</span>\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\u001b[1mModel: \"sequential_5\"\u001b[0m\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
|
|
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n",
|
|
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
|
|
"│ dense_17 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,664</span> │\n",
|
|
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
|
|
"│ dense_18 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">16,512</span> │\n",
|
|
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
|
|
"│ dense_19 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">12</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,548</span> │\n",
|
|
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
|
|
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
|
|
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
|
|
"│ dense_17 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m1,664\u001b[0m │\n",
|
|
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
|
|
"│ dense_18 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m16,512\u001b[0m │\n",
|
|
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
|
|
"│ dense_19 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m1,548\u001b[0m │\n",
|
|
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">19,724</span> (77.05 KB)\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m19,724\u001b[0m (77.05 KB)\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">19,724</span> (77.05 KB)\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m19,724\u001b[0m (77.05 KB)\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"model_simple = keras.Sequential(\n",
|
|
" [\n",
|
|
" keras.Input(shape = (12,), dtype = \"float32\"),\n",
|
|
" keras.layers.Dense(units = 128, activation = \"relu\", dtype = \"float32\"),\n",
|
|
" keras.layers.Dense(units = 128, activation = \"relu\", dtype = \"float32\"),\n",
|
|
" keras.layers.Dense(units = 12, dtype = \"float32\")\n",
|
|
" ]\n",
|
|
")\n",
|
|
"\n",
|
|
"model_simple.compile(optimizer=optimizer_simple, loss = loss)\n",
|
|
"model_simple.summary()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 143,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential_6\"</span>\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\u001b[1mModel: \"sequential_6\"\u001b[0m\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
|
|
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n",
|
|
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
|
|
"│ dense_20 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">512</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">6,656</span> │\n",
|
|
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
|
|
"│ dense_21 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1024</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">525,312</span> │\n",
|
|
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
|
|
"│ dense_22 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">512</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">524,800</span> │\n",
|
|
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
|
|
"│ dense_23 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">12</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">6,156</span> │\n",
|
|
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
|
|
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
|
|
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
|
|
"│ dense_20 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m6,656\u001b[0m │\n",
|
|
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
|
|
"│ dense_21 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1024\u001b[0m) │ \u001b[38;5;34m525,312\u001b[0m │\n",
|
|
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
|
|
"│ dense_22 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m524,800\u001b[0m │\n",
|
|
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
|
|
"│ dense_23 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m6,156\u001b[0m │\n",
|
|
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,062,924</span> (4.05 MB)\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,062,924\u001b[0m (4.05 MB)\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,062,924</span> (4.05 MB)\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,062,924\u001b[0m (4.05 MB)\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
|
|
"</pre>\n"
|
|
],
|
|
"text/plain": [
|
|
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"model_large = keras.Sequential(\n",
|
|
" [keras.layers.Input(shape=(12,), dtype=dtype),\n",
|
|
" keras.layers.Dense(512, activation='relu', dtype=dtype),\n",
|
|
" keras.layers.Dense(1024, activation='relu', dtype=dtype),\n",
|
|
" keras.layers.Dense(512, activation='relu', dtype=dtype),\n",
|
|
" keras.layers.Dense(12, dtype=dtype)\n",
|
|
" ])\n",
|
|
"\n",
|
|
"model_large.compile(optimizer=optimizer_large, loss = loss)\n",
|
|
"model_large.summary()\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Define some functions and helper classes"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def Safelog(val):\n",
|
|
" # get range of vector\n",
|
|
" if val > 0:\n",
|
|
" return np.log10(val)\n",
|
|
" elif val < 0:\n",
|
|
" return -np.log10(-val)\n",
|
|
" else:\n",
|
|
" return 0\n",
|
|
"\n",
|
|
"def Safeexp(val):\n",
|
|
" if val > 0:\n",
|
|
" return -10 ** -val\n",
|
|
" elif val < 0:\n",
|
|
" return 10 ** val\n",
|
|
" else:\n",
|
|
" return 0\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# ? Why does the charge is using another logarithm than the other species\n",
|
|
"\n",
|
|
"func_dict_in = {\n",
|
|
" \"H\" : np.log1p,\n",
|
|
" \"O\" : np.log1p,\n",
|
|
" \"Charge\" : Safelog,\n",
|
|
" \"H_0_\" : np.log1p,\n",
|
|
" \"O_0_\" : np.log1p,\n",
|
|
" \"Ba\" : np.log1p,\n",
|
|
" \"Cl\" : np.log1p,\n",
|
|
" \"S_2_\" : np.log1p,\n",
|
|
" \"S_6_\" : np.log1p,\n",
|
|
" \"Sr\" : np.log1p,\n",
|
|
" \"Barite\" : np.log1p,\n",
|
|
" \"Celestite\" : np.log1p,\n",
|
|
"}\n",
|
|
"\n",
|
|
"func_dict_out = {\n",
|
|
" \"H\" : np.expm1,\n",
|
|
" \"O\" : np.expm1,\n",
|
|
" \"Charge\" : Safeexp,\n",
|
|
" \"H_0_\" : np.expm1,\n",
|
|
" \"O_0_\" : np.expm1,\n",
|
|
" \"Ba\" : np.expm1,\n",
|
|
" \"Cl\" : np.expm1,\n",
|
|
" \"S_2_\" : np.expm1,\n",
|
|
" \"S_6_\" : np.expm1,\n",
|
|
" \"Sr\" : np.expm1,\n",
|
|
" \"Barite\" : np.expm1,\n",
|
|
" \"Celestite\" : np.expm1,\n",
|
|
"}\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Read data from `.h5` file and convert it to a `pandas.DataFrame`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 42,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# os.chdir('/mnt/beegfs/home/signer/projects/model-training')\n",
|
|
"data_file = h5py.File(\"Barite_50_Data_training.h5\")\n",
|
|
"\n",
|
|
"design = data_file[\"design\"]\n",
|
|
"results = data_file[\"result\"]\n",
|
|
"\n",
|
|
"df_design = pd.DataFrame(np.array(design[\"data\"]).transpose(), columns = design[\"names\"].asstr())\n",
|
|
"df_results = pd.DataFrame(np.array(results[\"data\"]).transpose(), columns = results[\"names\"].asstr())\n",
|
|
"\n",
|
|
"data_file.close()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Classify each cell with kmeans"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 43,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/home/signer/bin/miniconda3/envs/training/lib/python3.11/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (2). Possibly due to duplicate points in X.\n",
|
|
" return fit_method(estimator, *args, **kwargs)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# widget with slider for the index\n",
|
|
"\n",
|
|
"class_label_design = np.array([])\n",
|
|
"class_label_result = np.array([])\n",
|
|
"\n",
|
|
"\n",
|
|
"i = 1000\n",
|
|
"for i in range(0,1001):\n",
|
|
" field_design = np.array(df_design['Barite'][(i*2500):(i*2500+2500)]).reshape(50,50)\n",
|
|
" field_result = np.array(df_results['Barite'][(i*2500):(i*2500+2500)]).reshape(50,50)\n",
|
|
" \n",
|
|
" kmeans_design = KMeans(n_clusters=2, random_state=0).fit(field_design.reshape(-1,1))\n",
|
|
" kmeans_result = KMeans(n_clusters=2, random_state=0).fit(field_result.reshape(-1,1))\n",
|
|
" \n",
|
|
" class_label_design = np.append(class_label_design.astype(int), kmeans_design.labels_)\n",
|
|
" class_label_result = np.append(class_label_result.astype(int), kmeans_result.labels_)\n",
|
|
" \n",
|
|
"\n",
|
|
"\n",
|
|
"class_label_design = pd.DataFrame(class_label_design, columns = [\"Class\"])\n",
|
|
"class_label_result = pd.DataFrame(class_label_result, columns = [\"Class\"])\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 44,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"if(\"Class\" in df_design.columns and \"Class\" in df_results.columns):\n",
|
|
" print(\"Class column already exists\")\n",
|
|
"else:\n",
|
|
" df_design = pd.concat([df_design, class_label_design], axis=1)\n",
|
|
" df_results = pd.concat([df_results, class_label_design], axis=1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 45,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Amount class 0: 0.9879380619380619\n",
|
|
"Amount class 1: 0.012061938061938062\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"counter = Counter(df_design.iloc[:,-1])\n",
|
|
"print(\"Amount class 0:\", counter[0] / (counter[0] + counter[1]) )\n",
|
|
"print(\"Amount class 1:\", counter[1] / (counter[0] + counter[1]) )\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"<matplotlib.contour.QuadContourSet at 0x7c17915fe310>"
|
|
]
|
|
},
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"i=800\n",
|
|
"\n",
|
|
"plt.imshow(np.array(df_results['Barite'][(i*2500):(i*2500+2500)]).reshape(50,50), interpolation='bicubic', origin='lower')\n",
|
|
"plt.contour(np.array(df_results['Class'][(i*2500):(i*2500+2500)]).reshape(50,50), levels=[0.1], colors='red', origin='lower')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Split into Training and Testing datsets"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 126,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"X_train, X_test, y_train, y_test = sk.train_test_split(df_design, df_results, test_size = 0.2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Perform Over and Under Sampling on dataset to balance classes"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 109,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def balancer(design, target, strategy, sample_fraction=0.5):\n",
|
|
" counter = Counter(design.iloc[:,-1])\n",
|
|
" print(\"Amount class 0 before:\", counter[0] / (counter[0] + counter[1]) )\n",
|
|
" print(\"Amount class 1 before:\", counter[1] / (counter[0] + counter[1]) )\n",
|
|
" \n",
|
|
" number_features = (df_design.columns != \"Class\").sum()\n",
|
|
" if(\"Class\" not in design.columns):\n",
|
|
" if(\"Class\" in target.columns):\n",
|
|
" classes = target['Class']\n",
|
|
" else:\n",
|
|
" raise(\"No class column found\")\n",
|
|
" else:\n",
|
|
" classes = design['Class']\n",
|
|
" df = pd.concat([design.loc[:,design.columns != \"Class\"], target.loc[:, design.columns != \"Class\"], classes], axis=1)\n",
|
|
" \n",
|
|
" if strategy == 'smote':\n",
|
|
" print(\"Using SMOTE strategy\")\n",
|
|
" smote = SMOTE(sampling_strategy=sample_fraction)\n",
|
|
" df_resampled, classes_resampled = smote.fit_resample(df.loc[:, df.columns != \"Class\"], df.loc[:, df.columns == \"Class\"])\n",
|
|
" \n",
|
|
" elif strategy == 'over':\n",
|
|
" print(\"Using Oversampling\")\n",
|
|
" over = RandomOverSampler()\n",
|
|
" df_resampled, classes_resampled = over.fit_resample(df.loc[:, df.columns != \"Class\"], df.loc[:, df.columns == \"Class\"])\n",
|
|
" \n",
|
|
" elif strategy == 'under':\n",
|
|
" print(\"Using Undersampling\")\n",
|
|
" under = RandomUnderSampler()\n",
|
|
" df_resampled, classes_resampled = under.fit_resample(df.loc[:, df.columns != \"Class\"], df.loc[:, df.columns == \"Class\"])\n",
|
|
"\n",
|
|
" counter = Counter(classes_resampled[\"Class\"])\n",
|
|
" print(\"Amount class 0 after:\", counter[0] / (counter[0] + counter[1]) )\n",
|
|
" print(\"Amount class 1 after:\", counter[1] / (counter[0] + counter[1]) )\n",
|
|
" \n",
|
|
" design_resampled = pd.concat([df_resampled.iloc[:,0:number_features], classes_resampled], axis=1)\n",
|
|
" target_resampled = pd.concat([df_resampled.iloc[:,number_features:], classes_resampled], axis=1)\n",
|
|
" \n",
|
|
" return design_resampled, target_resampled "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 127,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Amount class 0 before: 0.9878911088911089\n",
|
|
"Amount class 1 before: 0.012108891108891108\n",
|
|
"Using Oversampling\n",
|
|
"Amount class 0 after: 0.5\n",
|
|
"Amount class 1 after: 0.5\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"X_train, y_train = balancer(X_train, y_train, 'over')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Define Scaling and Normalization Functions"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 87,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def log_scale(df_design, df_result, func_dict):\n",
|
|
" \n",
|
|
" df_design = df_design.copy()\n",
|
|
" df_result = df_result.copy()\n",
|
|
" \n",
|
|
" for key in df_design.keys():\n",
|
|
" if key != \"Class\":\n",
|
|
" df_design[key] = np.vectorize(func_dict[key])(df_design[key])\n",
|
|
" df_result[key] = np.vectorize(func_dict[key])(df_result[key])\n",
|
|
" \n",
|
|
" return df_design, df_result\n",
|
|
"\n",
|
|
"# Get minimum and maximum values for each column\n",
|
|
"def get_min_max(df_design, df_result):\n",
|
|
" \n",
|
|
" min_vals_des = df_design.min()\n",
|
|
" max_vals_des = df_design.max()\n",
|
|
" \n",
|
|
" min_vals_res = df_result.min()\n",
|
|
" max_vals_res = df_result.max()\n",
|
|
"\n",
|
|
" # minimum of input and output data to get global minimum/maximum\n",
|
|
" data_min = np.minimum(min_vals_des, min_vals_res).to_dict()\n",
|
|
" data_max = np.maximum(max_vals_des, max_vals_res).to_dict()\n",
|
|
"\n",
|
|
" return data_min, data_max\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 88,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"ename": "KeyboardInterrupt",
|
|
"evalue": "",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
|
"Cell \u001b[0;32mIn[88], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m df_design_log, df_results_log \u001b[38;5;241m=\u001b[39m log_scale(df_design, df_results, func_dict_in)\n\u001b[1;32m 2\u001b[0m data_min_log, data_max_log \u001b[38;5;241m=\u001b[39m get_min_max(df_design_log, df_results_log)\n",
|
|
"Cell \u001b[0;32mIn[87], line 8\u001b[0m, in \u001b[0;36mlog_scale\u001b[0;34m(df_design, df_result, func_dict)\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m df_design\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mClass\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m----> 8\u001b[0m df_design[key] \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mvectorize(func_dict[key])(df_design[key])\n\u001b[1;32m 9\u001b[0m df_result[key] \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mvectorize(func_dict[key])(df_result[key])\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m df_design, df_result\n",
|
|
"File \u001b[0;32m~/bin/miniconda3/envs/training/lib/python3.11/site-packages/numpy/lib/function_base.py:2372\u001b[0m, in \u001b[0;36mvectorize.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 2369\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_init_stage_2(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 2370\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\n\u001b[0;32m-> 2372\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_as_normal(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
|
|
"File \u001b[0;32m~/bin/miniconda3/envs/training/lib/python3.11/site-packages/numpy/lib/function_base.py:2365\u001b[0m, in \u001b[0;36mvectorize._call_as_normal\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 2362\u001b[0m vargs \u001b[38;5;241m=\u001b[39m [args[_i] \u001b[38;5;28;01mfor\u001b[39;00m _i \u001b[38;5;129;01min\u001b[39;00m inds]\n\u001b[1;32m 2363\u001b[0m vargs\u001b[38;5;241m.\u001b[39mextend([kwargs[_n] \u001b[38;5;28;01mfor\u001b[39;00m _n \u001b[38;5;129;01min\u001b[39;00m names])\n\u001b[0;32m-> 2365\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_vectorize_call(func\u001b[38;5;241m=\u001b[39mfunc, args\u001b[38;5;241m=\u001b[39mvargs)\n",
|
|
"File \u001b[0;32m~/bin/miniconda3/envs/training/lib/python3.11/site-packages/numpy/lib/function_base.py:2455\u001b[0m, in \u001b[0;36mvectorize._vectorize_call\u001b[0;34m(self, func, args)\u001b[0m\n\u001b[1;32m 2452\u001b[0m \u001b[38;5;66;03m# Convert args to object arrays first\u001b[39;00m\n\u001b[1;32m 2453\u001b[0m inputs \u001b[38;5;241m=\u001b[39m [asanyarray(a, dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mobject\u001b[39m) \u001b[38;5;28;01mfor\u001b[39;00m a \u001b[38;5;129;01min\u001b[39;00m args]\n\u001b[0;32m-> 2455\u001b[0m outputs \u001b[38;5;241m=\u001b[39m ufunc(\u001b[38;5;241m*\u001b[39minputs)\n\u001b[1;32m 2457\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m ufunc\u001b[38;5;241m.\u001b[39mnout \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 2458\u001b[0m res \u001b[38;5;241m=\u001b[39m asanyarray(outputs, dtype\u001b[38;5;241m=\u001b[39motypes[\u001b[38;5;241m0\u001b[39m])\n",
|
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"df_design_log, df_results_log = log_scale(df_design, df_results, func_dict_in)\n",
|
|
"data_min_log, data_max_log = get_min_max(df_design_log, df_results_log)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 128,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"X_train_log, y_train_log = log_scale(X_train, y_train, func_dict_in)\n",
|
|
"X_test_log, y_test_log = log_scale(X_test, y_test, func_dict_in)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 129,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"train_min_log, train_max_log = get_min_max(X_train_log, y_train_log)\n",
|
|
"test_min_log, test_max_log = get_min_max(X_test_log, y_test_log)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 114,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def preprocess(data, func_dict, data_min, data_max):\n",
|
|
" data = data.copy()\n",
|
|
" for key in data.keys():\n",
|
|
" if key != \"Class\":\n",
|
|
" data[key] = (data[key] - data_min[key]) / (data_max[key] - data_min[key])\n",
|
|
"\n",
|
|
" return data\n",
|
|
"\n",
|
|
"def postprocess(data, func_dict, data_min, data_max):\n",
|
|
" data = data.copy()\n",
|
|
" for key in data.keys():\n",
|
|
" if key != \"Class\":\n",
|
|
" data[key] = data[key] * (data_max[key] - data_min[key]) + data_min[key]\n",
|
|
" data[key] = np.vectorize(func_dict[key])(data[key])\n",
|
|
" return data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Preprocess the data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"pp_design = preprocess(df_design_log, func_dict_in, data_min_log, data_max_log)\n",
|
|
"pp_results = preprocess(df_results_log, func_dict_in, data_min_log, data_max_log)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 130,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Scale both splits with the *training* min/max: the model is fit on train-scaled data,\n",
|
|
"# so test inputs must go through the same mapping (per-split statistics would shift\n",
|
|
"# the test features and leak test-set information into preprocessing).\n",
|
|
"X_train_preprocess = preprocess(X_train_log, func_dict_in, train_min_log, train_max_log)\n",
|
|
"y_train_preprocess = preprocess(y_train_log, func_dict_in, train_min_log, train_max_log)\n",
|
|
"\n",
|
|
"X_test_preprocess = preprocess(X_test_log, func_dict_in, train_min_log, train_max_log)\n",
|
|
"y_test_preprocess = preprocess(y_test_log, func_dict_in, train_min_log, train_max_log)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Split off a validation set"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 131,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Hold out 10% of the preprocessed training data for validation; the fixed seed makes\n",
|
|
"# the split (and therefore the training run) reproducible.\n",
|
|
"# NOTE(review): this rebinds X_train / y_train, so re-running the log_scale cell after\n",
|
|
"# this one would transform already-preprocessed data.\n",
|
|
"X_train, X_val, y_train, y_val = sk.train_test_split(\n",
|
|
"    X_train_preprocess, y_train_preprocess, test_size=0.1, random_state=42\n",
|
|
")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Custom loss function"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def custom_loss_H20(df_design_log, df_result_log, data_min_log, data_max_log, func_dict_out, postprocess):\n",
|
|
"    \"\"\"Huber loss plus a penalty on the predicted H/O ratio deviating from 2.\n",
|
|
"\n",
|
|
"    Keras calls a loss as loss(y_true, y_pred), so through loss_wrapper the first\n",
|
|
"    argument is the target and df_result_log the prediction (scaled log space).\n",
|
|
"    The prediction is back-transformed with postprocess and sum((H/O - 2)**2) is added.\n",
|
|
"\n",
|
|
"    NOTE(review): postprocess looks columns up by name ('H', 'O'), so this only works on\n",
|
|
"    DataFrame-like inputs, not on the symbolic tensors Keras passes during fit -\n",
|
|
"    confirm before using loss_wrapper(...) in model.compile.\n",
|
|
"    \"\"\"\n",
|
|
"    # keras.losses.Huber is a loss class: instantiate it, then call it on (y_true, y_pred)\n",
|
|
"    huber = keras.losses.Huber()(df_design_log, df_result_log)\n",
|
|
"    df_result = postprocess(df_result_log, func_dict_out, data_min_log, data_max_log)\n",
|
|
"    ratio_penalty = np.sum((df_result['H'] / df_result['O'] - 2) ** 2)\n",
|
|
"    return huber + ratio_penalty\n",
|
|
"\n",
|
|
"\n",
|
|
"def loss_wrapper(data_min_log, data_max_log, func_dict_out, postprocess):\n",
|
|
"    \"\"\"Bind the scaling parameters and return a two-argument loss(y_true, y_pred).\"\"\"\n",
|
|
"    def loss(df_design_log, df_result_log):\n",
|
|
"        return custom_loss_H20(df_design_log, df_result_log, data_min_log, data_max_log, func_dict_out, postprocess)\n",
|
|
"    return loss"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Train the model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 144,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch 1/5\n",
|
|
"\u001b[1m6954/6954\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 2ms/step - loss: 0.0014 - val_loss: 1.9722e-05\n",
|
|
"Epoch 2/5\n",
|
|
"\u001b[1m6954/6954\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 2ms/step - loss: 1.8605e-05 - val_loss: 1.6460e-05\n",
|
|
"Epoch 3/5\n",
|
|
"\u001b[1m6954/6954\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 2ms/step - loss: 1.7344e-05 - val_loss: 1.8609e-05\n",
|
|
"Epoch 4/5\n",
|
|
"\u001b[1m6954/6954\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 2ms/step - loss: 1.6938e-05 - val_loss: 1.6669e-05\n",
|
|
"Epoch 5/5\n",
|
|
"\u001b[1m6954/6954\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 2ms/step - loss: 1.6373e-05 - val_loss: 1.5985e-05\n",
|
|
"Training took 63.22352385520935 seconds\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Train on every column except the last one (assumed to be 'Class' - confirm column order).\n",
|
|
"# NOTE(review): the evaluation and save cells below use model_large / model, not model_simple.\n",
|
|
"# measure time with a monotonic clock (time.time() can jump with system clock changes)\n",
|
|
"start = time.perf_counter()\n",
|
|
"\n",
|
|
"history = model_simple.fit(\n",
|
|
"    X_train.iloc[:, :-1],\n",
|
|
"    y_train.iloc[:, :-1],\n",
|
|
"    batch_size=batch_size,\n",
|
|
"    epochs=5,\n",
|
|
"    validation_data=(X_val.iloc[:, :-1], y_val.iloc[:, :-1]),\n",
|
|
")\n",
|
|
"\n",
|
|
"end = time.perf_counter()\n",
|
|
"\n",
|
|
"print(\"Training took {} seconds\".format(end - start))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 145,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Loss per epoch (1-based, matching the Keras 'Epoch i/n' log); validation loss if recorded.\n",
|
|
"fig, ax = plt.subplots()\n",
|
|
"epoch_numbers = range(1, len(history.history[\"loss\"]) + 1)\n",
|
|
"ax.plot(epoch_numbers, history.history[\"loss\"], \"o-\", label=\"Training Loss\")\n",
|
|
"if \"val_loss\" in history.history:\n",
|
|
"    ax.plot(epoch_numbers, history.history[\"val_loss\"], \"s--\", label=\"Validation Loss\")\n",
|
|
"ax.set_xlabel(\"Epoch\")\n",
|
|
"ax.set_ylabel(\"Loss (Huber)\")\n",
|
|
"ax.set_title(\"Training history\")\n",
|
|
"ax.grid(True)\n",
|
|
"ax.legend()\n",
|
|
"\n",
|
|
"fig.savefig(\"loss_all.png\", dpi=300)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 146,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Same curves without epoch 1, whose much larger loss dominates the y-axis.\n",
|
|
"# x values are the real (1-based) epoch numbers, so the first point is epoch 2.\n",
|
|
"fig, ax = plt.subplots()\n",
|
|
"train_loss = history.history[\"loss\"]\n",
|
|
"ax.plot(range(2, len(train_loss) + 1), train_loss[1:], \"o-\", label=\"Training Loss\")\n",
|
|
"if \"val_loss\" in history.history:\n",
|
|
"    val_loss = history.history[\"val_loss\"]\n",
|
|
"    ax.plot(range(2, len(val_loss) + 1), val_loss[1:], \"s--\", label=\"Validation Loss\")\n",
|
|
"ax.set_xlabel(\"Epoch\")\n",
|
|
"ax.set_ylabel(\"Loss (Huber)\")\n",
|
|
"ax.set_title(\"Training history (from epoch 2)\")\n",
|
|
"ax.grid(True)\n",
|
|
"ax.legend()\n",
|
|
"fig.savefig(\"loss_1_to_end.png\", dpi=300)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Test the model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 152,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\u001b[1m15641/15641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 1ms/step - loss: 0.0904\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0.09046255797147751"
|
|
]
|
|
},
|
|
"execution_count": 152,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# test on all test data (all columns except the last, presumably 'Class')\n",
|
|
"# NOTE(review): this evaluates model_large, but the training cell above fits model_simple - confirm which model is intended.\n",
|
|
"model_large.evaluate(X_test_preprocess.iloc[:,:-1], y_test_preprocess.iloc[:, :-1])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 153,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\u001b[1m15455/15455\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 996us/step - loss: 0.0902\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0.0901983454823494"
|
|
]
|
|
},
|
|
"execution_count": 153,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# test on non-reactive data: rows whose design-side 'Class' is 0 (same mask applied to X and y)\n",
|
|
"# NOTE(review): evaluates model_large, while the training cell above fits model_simple - confirm.\n",
|
|
"model_large.evaluate(X_test_preprocess[X_test_preprocess['Class'] == 0].iloc[:,:-1], y_test_preprocess[X_test_preprocess['Class'] == 0].iloc[:,:-1])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 155,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\u001b[1m186/186\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step - loss: 0.1127\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0.11247223615646362"
|
|
]
|
|
},
|
|
"execution_count": 155,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# test on reactive data: rows whose design-side 'Class' is 1 (same mask applied to X and y)\n",
|
|
"# NOTE(review): evaluates model_large, while the training cell above fits model_simple - confirm.\n",
|
|
"model_large.evaluate(X_test_preprocess[X_test_preprocess['Class'] == 1].iloc[:,:-1], y_test_preprocess[X_test_preprocess['Class'] == 1].iloc[:, :-1])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Save the model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 53,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Save the model\n",
|
|
"# NOTE(review): this saves `model`, while the cells above train model_simple and evaluate model_large - confirm which one belongs in this file.\n",
|
|
"model.save(\"Barite_50_Model_additional_species.keras\")"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "training",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.11"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|