model-training/POET_Training.ipynb
2025-01-15 11:55:24 +01:00

620 lines
52 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## General Information"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook trains a simple neural network model to predict the chemistry in the barite benchmark (50x50 grid). The training data is stored in the repository using **Git Large File Storage (Git LFS)**; once Git LFS is installed, it can be downloaded with the `git lfs pull` command.\n",
"\n",
"It is then recommended to create a Python environment using miniconda. The necessary dependencies are contained in `environment.yml` and can be installed using `conda env create -f environment.yml`.\n",
"\n",
"The data set consists of the iterations of a reference simulation and is divided into a design part and a result part. The design part contains the chemical concentrations at time $t$; the result part contains the concentrations at time $t+1$, which the model learns to predict."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup Libraries"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running Keras in version 3.8.0\n"
]
}
],
"source": [
"import keras\n",
"print(\"Running Keras in version {}\".format(keras.__version__))\n",
"\n",
"import h5py\n",
"import numpy as np\n",
"import pandas as pd\n",
"import time\n",
"import sklearn.model_selection as sk\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define parameters"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Numeric precision and hidden-layer activation intended for the network.\n",
"dtype = \"float32\"\n",
"activation = \"relu\"\n",
"\n",
"# Seed Python, NumPy and the Keras backend in one call so that weight\n",
"# initialisation and batch shuffling repeat after Restart & Run All.\n",
"seed = 42\n",
"keras.utils.set_random_seed(seed)\n",
"\n",
"# Training hyperparameters.\n",
"lr = 0.001  # initial learning rate handed to the decay schedule below\n",
"batch_size = 512\n",
"epochs = 50 # default 400 epochs\n",
"\n",
"# Multiply the learning rate by 0.9 every 2000 optimizer steps;\n",
"# staircase=True applies the decay in discrete jumps instead of continuously.\n",
"lr_schedule = keras.optimizers.schedules.ExponentialDecay(\n",
"    initial_learning_rate=lr,\n",
"    decay_steps=2000,\n",
"    decay_rate=0.9,\n",
"    staircase=True\n",
")\n",
"\n",
"optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)\n",
"loss = keras.losses.Huber()\n",
"\n",
"# NOTE(review): sample_fraction is not referenced by any other cell visible\n",
"# in this notebook - confirm whether it is still needed.\n",
"sample_fraction = 0.8"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup the model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,664</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">16,512</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_2 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">12</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,548</span> │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m1,664\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m16,512\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_2 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m1,548\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">19,724</span> (77.05 KB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m19,724\u001b[0m (77.05 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">19,724</span> (77.05 KB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m19,724\u001b[0m (77.05 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Fully connected regression network: 12 inputs -> 128 -> 128 -> 12 outputs.\n",
"# Layer precision and hidden activation are taken from the `dtype` and\n",
"# `activation` settings of the parameter cell instead of repeated literals,\n",
"# so changing them there actually changes the model.\n",
"model = keras.Sequential(\n",
"    [\n",
"        keras.Input(shape = (12,), dtype = dtype),\n",
"        keras.layers.Dense(units = 128, activation = activation, dtype = dtype),\n",
"        keras.layers.Dense(units = 128, activation = activation, dtype = dtype),\n",
"        # No activation on the last layer: the outputs are plain linear values.\n",
"        keras.layers.Dense(units = 12, dtype = dtype)\n",
"    ]\n",
")\n",
"\n",
"model.compile(optimizer=optimizer, loss = loss)\n",
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define helper functions"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def Safelog(val):\n",
"    \"\"\"Sign-aware base-10 logarithm of a single scalar.\n",
"\n",
"    Returns log10(val) for val > 0, -log10(-val) for val < 0 and 0.0\n",
"    otherwise. NaN fails both comparisons and therefore also maps to 0.0.\n",
"\n",
"    The zero branch returns a float on purpose: np.vectorize without\n",
"    `otypes` takes its output dtype from the first element, so an integer 0\n",
"    there would truncate every later result to an integer.\n",
"\n",
"    NOTE(review): Safeexp chooses its branch from the sign of the\n",
"    transformed value, so the round trip only holds for 0 < |val| < 1 -\n",
"    confirm that the inputs stay inside that range.\n",
"    \"\"\"\n",
"    if val > 0:\n",
"        return np.log10(val)\n",
"    elif val < 0:\n",
"        return -np.log10(-val)\n",
"    else:\n",
"        return 0.0\n",
"\n",
"def Safeexp(val):\n",
"    \"\"\"Map a Safelog-transformed scalar back to its original value.\n",
"\n",
"    Returns -(10 ** -val) for val > 0, 10 ** val for val < 0 and 0.0\n",
"    otherwise. NaN fails both comparisons and therefore also maps to 0.0.\n",
"\n",
"    The base is written as a float so that integer inputs cannot trigger\n",
"    NumPy's error for integers raised to negative integer powers, and the\n",
"    zero branch returns a float so np.vectorize (which infers its output\n",
"    dtype from the first element) never truncates later results.\n",
"\n",
"    NOTE(review): this inverts Safelog only when the original value had\n",
"    0 < |value| < 1 - confirm that the data stays inside that range.\n",
"    \"\"\"\n",
"    if val > 0:\n",
"        return -(10.0 ** -val)\n",
"    elif val < 0:\n",
"        return 10.0 ** val\n",
"    else:\n",
"        return 0.0\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Per-column transforms, keyed by column name.\n",
"# Every column uses log1p on the way in and expm1 on the way out, except\n",
"# \"Charge\", which goes through the sign-aware Safelog / Safeexp pair.\n",
"# Open question: why does Charge use a different logarithm than the other\n",
"# species? Presumably because it can be negative - TODO confirm.\n",
"\n",
"func_dict_in = dict.fromkeys(\n",
"    [\n",
"        \"H\",\n",
"        \"O\",\n",
"        \"Charge\",\n",
"        \"H_0_\",\n",
"        \"O_0_\",\n",
"        \"Ba\",\n",
"        \"Cl\",\n",
"        \"S_2_\",\n",
"        \"S_6_\",\n",
"        \"Sr\",\n",
"        \"Barite\",\n",
"        \"Celestite\",\n",
"    ],\n",
"    np.log1p,\n",
")\n",
"# Overwriting an existing key keeps its position in the dict.\n",
"func_dict_in[\"Charge\"] = Safelog\n",
"\n",
"# Same keys in the same order, each mapped to the inverse transform.\n",
"func_dict_out = dict.fromkeys(func_dict_in, np.expm1)\n",
"func_dict_out[\"Charge\"] = Safeexp\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Read data from `.h5` file and convert it to a `pandas.DataFrame`"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Open read-only inside a context manager so the file handle is released\n",
"# even if building one of the DataFrames raises.\n",
"with h5py.File(\"Barite_50_Data_training.h5\", \"r\") as data_file:\n",
"    design = data_file[\"design\"]\n",
"    results = data_file[\"result\"]\n",
"\n",
"    # The stored matrices are transposed before use and the \"names\" dataset\n",
"    # supplies the column labels.\n",
"    df_design = pd.DataFrame(np.array(design[\"data\"]).transpose(), columns = design[\"names\"].asstr())\n",
"    df_results = pd.DataFrame(np.array(results[\"data\"]).transpose(), columns = results[\"names\"].asstr())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define Scaling and Normalization Functions"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def log_scale(df_design, df_result, func_dict):\n",
"    \"\"\"Apply the per-column transform from `func_dict` to copies of both frames.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    df_design, df_result : pandas.DataFrame\n",
"        Frames to transform. The columns of `df_design` decide which keys are\n",
"        processed; `df_result` must contain the same columns.\n",
"    func_dict : dict\n",
"        Maps each column name to either a NumPy ufunc or a scalar function.\n",
"\n",
"    Returns\n",
"    -------\n",
"    (pandas.DataFrame, pandas.DataFrame)\n",
"        Transformed copies; the inputs are left unmodified.\n",
"    \"\"\"\n",
"    df_design = df_design.copy()\n",
"    df_result = df_result.copy()\n",
"\n",
"    for key in df_design.keys():\n",
"        func = func_dict[key]\n",
"        # ufuncs already work on whole columns. Scalar functions are wrapped\n",
"        # with an explicit float output type, because np.vectorize would\n",
"        # otherwise take the dtype from the first element only (and fail on\n",
"        # empty input).\n",
"        if not isinstance(func, np.ufunc):\n",
"            func = np.vectorize(func, otypes = [float])\n",
"        df_design[key] = func(df_design[key])\n",
"        df_result[key] = func(df_result[key])\n",
"\n",
"    return df_design, df_result\n",
"\n",
"# Per-column extrema taken over the design and result frames together\n",
"def get_min_max(df_design, df_result):\n",
"    \"\"\"Return two dicts (minima, maxima) keyed by column name.\n",
"\n",
"    For every column the smaller of the two frame minima and the larger of\n",
"    the two frame maxima is reported, so both frames share one range.\n",
"    \"\"\"\n",
"    lowest = np.minimum(df_design.min(), df_result.min())\n",
"    highest = np.maximum(df_design.max(), df_result.max())\n",
"\n",
"    return lowest.to_dict(), highest.to_dict()\n",
"\n",
"\n",
"# Transform the raw frames first, then take the per-column extrema of the\n",
"# transformed values.\n",
"df_design_log, df_results_log = log_scale(\n",
"    df_design = df_design,\n",
"    df_result = df_results,\n",
"    func_dict = func_dict_in,\n",
")\n",
"data_min_log, data_max_log = get_min_max(df_design = df_design_log, df_result = df_results_log)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def preprocess(data, func_dict, data_min, data_max):\n",
"    \"\"\"Min-max scale every column of `data` with the given per-column bounds.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    data : pandas.DataFrame\n",
"        Frame to scale; a copy is returned and the input is left unmodified.\n",
"    func_dict : dict\n",
"        Not used by this function; kept so existing calls keep working.\n",
"    data_min, data_max : dict\n",
"        Lower / upper bound per column name.\n",
"\n",
"    Returns\n",
"    -------\n",
"    pandas.DataFrame\n",
"        (value - min) / (max - min) per column. A column whose bounds\n",
"        coincide is set to 0.0 instead of dividing by zero; postprocess then\n",
"        maps it back to its minimum.\n",
"    \"\"\"\n",
"    data = data.copy()\n",
"    for key in data.keys():\n",
"        span = data_max[key] - data_min[key]\n",
"        if span == 0:\n",
"            # Constant column: dividing would yield NaN (0/0) or inf.\n",
"            data[key] = 0.0\n",
"        else:\n",
"            data[key] = (data[key] - data_min[key]) / span\n",
"\n",
"    return data\n",
"\n",
"def postprocess(data, func_dict, data_min, data_max):\n",
"    \"\"\"Undo the min-max scaling, then apply the per-column function in `func_dict`.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    data : pandas.DataFrame\n",
"        Scaled frame; a copy is returned and the input is left unmodified.\n",
"    func_dict : dict\n",
"        Maps each column name to a NumPy ufunc or a scalar function that is\n",
"        applied after the scaling has been reverted.\n",
"    data_min, data_max : dict\n",
"        Lower / upper bound per column name, as used for the scaling.\n",
"\n",
"    Returns\n",
"    -------\n",
"    pandas.DataFrame\n",
"        func(value * (max - min) + min) per column.\n",
"    \"\"\"\n",
"    data = data.copy()\n",
"    for key in data.keys():\n",
"        func = func_dict[key]\n",
"        # ufuncs already work on whole columns. Scalar functions get an\n",
"        # explicit float output type, because np.vectorize would otherwise\n",
"        # take the dtype from the first element only.\n",
"        if not isinstance(func, np.ufunc):\n",
"            func = np.vectorize(func, otypes = [float])\n",
"        data[key] = func(data[key] * (data_max[key] - data_min[key]) + data_min[key])\n",
"    return data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preprocess the data"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Both frames are scaled with the same per-column bounds.\n",
"pp_design = preprocess(\n",
"    data = df_design_log,\n",
"    func_dict = func_dict_in,\n",
"    data_min = data_min_log,\n",
"    data_max = data_max_log,\n",
")\n",
"pp_results = preprocess(\n",
"    data = df_results_log,\n",
"    func_dict = func_dict_in,\n",
"    data_min = data_min_log,\n",
"    data_max = data_max_log,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sample the data"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Fixed seed so the train / validation / test partition is identical on\n",
"# every run and reported metrics can be compared between runs.\n",
"split_seed = 42\n",
"\n",
"# Hold out 20 % for testing, then 10 % of the remainder for validation.\n",
"X_train, X_test, y_train, y_test = sk.train_test_split(pp_design, pp_results, test_size = 0.2, random_state = split_seed)\n",
"X_train, X_val, y_train, y_val = sk.train_test_split(X_train, y_train, test_size = 0.1, random_state = split_seed)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the model"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 831us/step - loss: 0.0016 - val_loss: 8.9642e-07\n",
"Epoch 2/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 777us/step - loss: 1.2063e-06 - val_loss: 9.3257e-07\n",
"Epoch 3/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 808us/step - loss: 1.3414e-06 - val_loss: 7.4446e-07\n",
"Epoch 4/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 994us/step - loss: 9.5866e-07 - val_loss: 6.6027e-07\n",
"Epoch 5/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 957us/step - loss: 1.0071e-06 - val_loss: 6.1673e-07\n",
"Epoch 6/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 938us/step - loss: 8.1617e-07 - val_loss: 6.3258e-07\n",
"Epoch 7/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 945us/step - loss: 7.0918e-07 - val_loss: 6.3168e-07\n",
"Epoch 8/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 817us/step - loss: 7.2066e-07 - val_loss: 5.9542e-07\n",
"Epoch 9/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 830us/step - loss: 5.9725e-07 - val_loss: 5.8001e-07\n",
"Epoch 10/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 920us/step - loss: 7.0796e-07 - val_loss: 6.1479e-07\n",
"Epoch 11/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 1ms/step - loss: 6.1275e-07 - val_loss: 5.6376e-07\n",
"Epoch 12/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 869us/step - loss: 5.5536e-07 - val_loss: 5.7461e-07\n",
"Epoch 13/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 809us/step - loss: 6.4857e-07 - val_loss: 5.9354e-07\n",
"Epoch 14/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 819us/step - loss: 6.9492e-07 - val_loss: 6.1578e-07\n",
"Epoch 15/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 847us/step - loss: 6.0041e-07 - val_loss: 6.6684e-07\n",
"Epoch 16/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 908us/step - loss: 6.9271e-07 - val_loss: 5.5564e-07\n",
"Epoch 17/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 854us/step - loss: 7.0737e-07 - val_loss: 5.6945e-07\n",
"Epoch 18/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 835us/step - loss: 7.5155e-07 - val_loss: 5.5309e-07\n",
"Epoch 19/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 852us/step - loss: 7.4322e-07 - val_loss: 5.5036e-07\n",
"Epoch 20/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 834us/step - loss: 5.8944e-07 - val_loss: 5.6052e-07\n",
"Epoch 21/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 814us/step - loss: 6.7448e-07 - val_loss: 5.4524e-07\n",
"Epoch 22/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 811us/step - loss: 5.7292e-07 - val_loss: 5.9770e-07\n",
"Epoch 23/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 834us/step - loss: 5.7507e-07 - val_loss: 5.4916e-07\n",
"Epoch 24/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 799us/step - loss: 6.3452e-07 - val_loss: 5.4885e-07\n",
"Epoch 25/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 808us/step - loss: 5.9518e-07 - val_loss: 5.4678e-07\n",
"Epoch 26/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 822us/step - loss: 6.6424e-07 - val_loss: 5.4674e-07\n",
"Epoch 27/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 830us/step - loss: 5.9008e-07 - val_loss: 5.4434e-07\n",
"Epoch 28/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 801us/step - loss: 5.4859e-07 - val_loss: 5.4596e-07\n",
"Epoch 29/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 805us/step - loss: 4.9844e-07 - val_loss: 5.4456e-07\n",
"Epoch 30/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 815us/step - loss: 6.4763e-07 - val_loss: 5.4440e-07\n",
"Epoch 31/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 838us/step - loss: 7.3888e-07 - val_loss: 5.4584e-07\n",
"Epoch 32/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 862us/step - loss: 5.2331e-07 - val_loss: 5.4407e-07\n",
"Epoch 33/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 856us/step - loss: 6.9340e-07 - val_loss: 5.4382e-07\n",
"Epoch 34/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 857us/step - loss: 5.5593e-07 - val_loss: 5.4424e-07\n",
"Epoch 35/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 900us/step - loss: 6.2465e-07 - val_loss: 5.4352e-07\n",
"Epoch 36/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 866us/step - loss: 6.0392e-07 - val_loss: 5.4369e-07\n",
"Epoch 37/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 834us/step - loss: 6.3388e-07 - val_loss: 5.4619e-07\n",
"Epoch 38/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 813us/step - loss: 5.6506e-07 - val_loss: 5.4372e-07\n",
"Epoch 39/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 813us/step - loss: 6.9649e-07 - val_loss: 5.4339e-07\n",
"Epoch 40/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 812us/step - loss: 5.0897e-07 - val_loss: 5.4338e-07\n",
"Epoch 41/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 804us/step - loss: 6.1986e-07 - val_loss: 5.4396e-07\n",
"Epoch 42/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 825us/step - loss: 5.5556e-07 - val_loss: 5.4339e-07\n",
"Epoch 43/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 817us/step - loss: 5.9327e-07 - val_loss: 5.4372e-07\n",
"Epoch 44/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 821us/step - loss: 6.8013e-07 - val_loss: 5.4331e-07\n",
"Epoch 45/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 810us/step - loss: 5.3385e-07 - val_loss: 5.4331e-07\n",
"Epoch 46/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 817us/step - loss: 5.8341e-07 - val_loss: 5.4332e-07\n",
"Epoch 47/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 815us/step - loss: 5.8649e-07 - val_loss: 5.4331e-07\n",
"Epoch 48/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 805us/step - loss: 5.4243e-07 - val_loss: 5.4334e-07\n",
"Epoch 49/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 814us/step - loss: 6.0889e-07 - val_loss: 5.4330e-07\n",
"Epoch 50/50\n",
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 804us/step - loss: 5.9065e-07 - val_loss: 5.4327e-07\n",
"Training took 150.442538022995 seconds\n"
]
}
],
"source": [
"# Take wall-clock timestamps right before and right after the fit call.\n",
"start = time.time()\n",
"history = model.fit(\n",
"    X_train,\n",
"    y_train,\n",
"    validation_data = (X_val, y_val),\n",
"    epochs = epochs,\n",
"    batch_size = batch_size,\n",
")\n",
"end = time.time()\n",
"\n",
"print(\"Training took {} seconds\".format(end - start))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Training and validation loss per epoch. The first epoch is orders of\n",
"# magnitude above the rest, so a log axis keeps the later epochs readable.\n",
"plt.plot(history.history[\"loss\"], \"o-\", label = \"Training Loss\")\n",
"plt.plot(history.history[\"val_loss\"], \"s-\", label = \"Validation Loss\")\n",
"plt.yscale(\"log\")\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Loss (Huber)\")\n",
"plt.title(\"Training history\")\n",
"plt.legend()  # without this the curve labels are never shown\n",
"plt.grid(True);\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test the model"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m15641/15641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 232us/step - loss: 1.0855e-06\n"
]
},
{
"data": {
"text/plain": [
"1.0228196742900764e-06"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Loss on the test split that was set aside before training.\n",
"model.evaluate(x = X_test, y = y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save the model"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"# Write the trained model to a single `.keras` file in the current working directory.\n",
"model.save(\"Barite_50_Model_additional_species.keras\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "ai",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}