model-training/POET_Training.ipynb
2025-01-15 11:50:05 +01:00

1235 lines
125 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## General Information"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook is used to train a simple neural network model to predict the chemistry in the barite benchmark (50x50 grid). The training data is stored in the repository using **git large file storage** and can be downloaded after the installation of git lfs using the `git lfs pull` command.\n",
"\n",
"It is then recommended to create a Python environment using miniconda. The necessary dependencies are contained in `environment.yml` and can be installed using `conda env create -f environment.yml`.\n",
"\n",
"The data set is divided into a design and result part and consists of the iterations of a reference simulation. The design part of the data set contains the chemical concentrations at time $t$ and the result part at time $t+1$, which are to be learned by the model."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup Libraries"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running Keras in version 3.6.0\n"
]
}
],
"source": [
"import keras\n",
"print(\"Running Keras in version {}\".format(keras.__version__))\n",
"\n",
"import h5py\n",
"import numpy as np\n",
"import pandas as pd\n",
"import time\n",
"import sklearn.model_selection as sk\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.cluster import KMeans\n",
"from imblearn.over_sampling import SMOTE\n",
"from collections import Counter"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define parameters"
]
},
{
"cell_type": "code",
"execution_count": 116,
"metadata": {},
"outputs": [],
"source": [
"# Network settings\n",
"dtype = \"float32\"\n",
"activation = \"relu\"\n",
"\n",
"# Training settings\n",
"epochs = 50  # default 400 epochs\n",
"batch_size = 512\n",
"lr = 0.001\n",
"sample_fraction = 0.8\n",
"\n",
"# Staircase schedule: starting from lr, the rate is scaled by 0.9 every 2000 optimizer steps.\n",
"lr_schedule = keras.optimizers.schedules.ExponentialDecay(\n",
"    lr,\n",
"    decay_steps=2000,\n",
"    decay_rate=0.9,\n",
"    staircase=True,\n",
")\n",
"\n",
"# Loss and optimizer handed to model.compile() further down.\n",
"loss = keras.losses.Huber()\n",
"optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup the model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,664</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">16,512</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_2 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">12</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,548</span> │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m1,664\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m16,512\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_2 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m1,548\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">19,724</span> (77.05 KB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m19,724\u001b[0m (77.05 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">19,724</span> (77.05 KB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m19,724\u001b[0m (77.05 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Small network: two hidden layers of 128 units mapping the 12 inputs to 12 outputs.\n",
"# dtype and activation come from the parameter cell so both models are configured in one place.\n",
"model_simple = keras.Sequential(\n",
"    [\n",
"        keras.Input(shape=(12,), dtype=dtype),\n",
"        keras.layers.Dense(units=128, activation=activation, dtype=dtype),\n",
"        keras.layers.Dense(units=128, activation=activation, dtype=dtype),\n",
"        keras.layers.Dense(units=12, dtype=dtype),\n",
"    ]\n",
")\n",
"\n",
"# Give this model its own optimizer instance: a Keras 3 optimizer is tied to the\n",
"# variables of the first model it trains, so one instance cannot serve two models.\n",
"model_simple.compile(\n",
"    optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),\n",
"    loss=loss,\n",
")\n",
"model_simple.summary()"
]
},
{
"cell_type": "code",
"execution_count": 117,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential_4\"</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential_4\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_15 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">512</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">6,656</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_16 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1024</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">525,312</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_17 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">512</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">524,800</span> │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_18 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">12</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">6,156</span> │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_15 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m6,656\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_16 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1024\u001b[0m) │ \u001b[38;5;34m525,312\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_17 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m524,800\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_18 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m6,156\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,062,924</span> (4.05 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,062,924\u001b[0m (4.05 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">1,062,924</span> (4.05 MB)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,062,924\u001b[0m (4.05 MB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Larger network: hidden layers of 512, 1024 and 512 units mapping the 12 inputs to 12 outputs.\n",
"model_large = keras.Sequential(\n",
"    [\n",
"        keras.layers.Input(shape=(12,), dtype=dtype),\n",
"        keras.layers.Dense(512, activation=activation, dtype=dtype),\n",
"        keras.layers.Dense(1024, activation=activation, dtype=dtype),\n",
"        keras.layers.Dense(512, activation=activation, dtype=dtype),\n",
"        keras.layers.Dense(12, dtype=dtype),\n",
"    ]\n",
")\n",
"\n",
"# Give this model its own optimizer instance: a Keras 3 optimizer is tied to the\n",
"# variables of the first model it trains, so one instance cannot serve two models.\n",
"model_large.compile(\n",
"    optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),\n",
"    loss=loss,\n",
")\n",
"model_large.summary()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define some functions and helper classes"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def Safelog(val):\n",
"    \"\"\"Sign-preserving base-10 logarithm of a scalar.\n",
"\n",
"    Positive values map to ``log10(val)``, negative values to ``-log10(-val)``\n",
"    and zero maps to ``0.0``.\n",
"\n",
"    The result is always a float. ``np.vectorize`` infers its output dtype from\n",
"    the first value it computes, so an integer zero would truncate every later\n",
"    result to an integer.\n",
"\n",
"    Note: the mapping is not one-to-one over all reals (10 and -0.1 both give 1).\n",
"    \"\"\"\n",
"    if val > 0:\n",
"        return np.log10(val)\n",
"    if val < 0:\n",
"        return -np.log10(-val)\n",
"    return 0.0\n",
"\n",
"\n",
"def Safeexp(val):\n",
"    \"\"\"Map a sign-encoded exponent back to a signed magnitude.\n",
"\n",
"    Positive values map to ``-10 ** (-val)``, negative values to ``10 ** val``\n",
"    and zero maps to ``0.0``. This reverses ``Safelog`` for inputs with\n",
"    ``0 < |x| < 1``.\n",
"\n",
"    The result is always a float, for the same ``np.vectorize`` dtype-inference\n",
"    reason as in ``Safelog``; the float base also keeps integer inputs valid.\n",
"    \"\"\"\n",
"    if val > 0:\n",
"        return -(10.0 ** -val)\n",
"    if val < 0:\n",
"        return 10.0 ** val\n",
"    return 0.0\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Per-column forward transforms: every column uses log1p except \"Charge\",\n",
"# which uses the sign-preserving Safelog.\n",
"# TODO: clarify why the charge uses a different logarithm than the other species.\n",
"func_dict_in = dict.fromkeys(\n",
"    (\"H\", \"O\", \"Charge\", \"H_0_\", \"O_0_\", \"Ba\", \"Cl\", \"S_2_\", \"S_6_\", \"Sr\", \"Barite\", \"Celestite\"),\n",
"    np.log1p,\n",
")\n",
"func_dict_in[\"Charge\"] = Safelog\n",
"\n",
"# Per-column transforms for the way back (expm1 undoes log1p), keyed identically.\n",
"func_dict_out = dict.fromkeys(func_dict_in, np.expm1)\n",
"func_dict_out[\"Charge\"] = Safeexp\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Read data from `.h5` file and convert it to a `pandas.DataFrame`"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [],
"source": [
"# Open read-only inside a context manager so the file is closed even if reading fails.\n",
"with h5py.File(\"Barite_50_Data_training.h5\", \"r\") as data_file:\n",
"    design = data_file[\"design\"]\n",
"    results = data_file[\"result\"]\n",
"\n",
"    # The stored arrays are transposed so that every named quantity becomes a column.\n",
"    df_design = pd.DataFrame(np.array(design[\"data\"]).transpose(), columns=design[\"names\"].asstr())\n",
"    df_results = pd.DataFrame(np.array(results[\"data\"]).transpose(), columns=results[\"names\"].asstr())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Classify each cell with kmeans"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/signer/bin/miniconda3/envs/training/lib/python3.11/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (2). Possibly due to duplicate points in X.\n",
" return fit_method(estimator, *args, **kwargs)\n"
]
}
],
"source": [
"# Split the cells of every 50x50 snapshot into two K-means clusters of their Barite value.\n",
"n_iterations = 1001          # number of snapshots processed\n",
"cells_per_iteration = 2500   # one 50x50 grid per snapshot\n",
"\n",
"labels_design = []\n",
"labels_result = []\n",
"\n",
"for i in range(n_iterations):\n",
"    rows = slice(i * cells_per_iteration, (i + 1) * cells_per_iteration)\n",
"\n",
"    # KMeans expects a 2-D (n_samples, n_features) array; there is a single feature here.\n",
"    field_design = np.array(df_design[\"Barite\"][rows]).reshape(-1, 1)\n",
"    field_result = np.array(df_results[\"Barite\"][rows]).reshape(-1, 1)\n",
"\n",
"    kmeans_design = KMeans(n_clusters=2, random_state=0).fit(field_design)\n",
"    kmeans_result = KMeans(n_clusters=2, random_state=0).fit(field_result)\n",
"\n",
"    labels_design.append(kmeans_design.labels_)\n",
"    labels_result.append(kmeans_result.labels_)\n",
"\n",
"# NOTE(review): cluster indices are assigned per fit, so label 0/1 presumably does not\n",
"# denote the same physical regime in every snapshot - confirm before relying on it.\n",
"\n",
"# Concatenate once at the end instead of re-allocating the full array in every iteration.\n",
"class_label_design = pd.DataFrame(np.concatenate(labels_design).astype(int), columns=[\"Class\"])\n",
"class_label_result = pd.DataFrame(np.concatenate(labels_result).astype(int), columns=[\"Class\"])\n"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [],
"source": [
"# Append the cluster labels as a trailing \"Class\" column unless both frames already have one.\n",
"# NOTE(review): both frames receive class_label_design; class_label_result is not used\n",
"# here - confirm this is intended.\n",
"if \"Class\" not in df_design.columns or \"Class\" not in df_results.columns:\n",
"    df_design = pd.concat([df_design, class_label_design], axis=1)\n",
"    df_results = pd.concat([df_results, class_label_design], axis=1)\n",
"else:\n",
"    print(\"Class column already exists\")"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.contour.QuadContourSet at 0x7aa0eb5a9e90>"
]
},
"execution_count": 89,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"i = 1000\n",
"rows = slice(i * 2500, i * 2500 + 2500)\n",
"\n",
"# Barite values and class labels of the selected snapshot, arranged on the 50x50 grid.\n",
"barite_field = np.array(df_results[\"Barite\"][rows]).reshape(50, 50)\n",
"class_field = np.array(df_results[\"Class\"][rows]).reshape(50, 50)\n",
"\n",
"# The red contour at level 0.1 traces the border between class 0 and class 1 cells.\n",
"plt.imshow(barite_field, interpolation=\"bicubic\", origin=\"lower\")\n",
"plt.contour(class_field, levels=[0.1], colors=\"red\", origin=\"lower\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Split into Training and Testing datasets"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [],
"source": [
"# 80/20 split; the fixed random_state makes the split reproducible across kernel restarts.\n",
"X_train, X_test, y_train, y_test = sk.train_test_split(\n",
"    df_design,\n",
"    df_results,\n",
"    test_size=0.2,\n",
"    random_state=0,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Perform SMOTE sampling on the dataset to balance classes"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [],
"source": [
"# Show the class imbalance that SMOTE is meant to correct.\n",
"print(\"Class counts before resampling:\", Counter(X_train.iloc[:, -1]))\n",
"\n",
"oversample = SMOTE(random_state=0)\n",
"\n",
"# Resample design and target rows jointly. SMOTE creates its synthetic rows at random,\n",
"# so two independent calls would produce inputs and outputs that no longer belong to\n",
"# the same sample. The last column of each frame is the class label and is left out\n",
"# of the feature matrix.\n",
"design_features = X_train.iloc[:, :-1]\n",
"target_features = y_train.iloc[:, :-1]\n",
"n_design_features = design_features.shape[1]\n",
"\n",
"joint_features = np.hstack([design_features.to_numpy(), target_features.to_numpy()])\n",
"# NOTE(review): the labels of X_train drive the resampling; this assumes y_train carries\n",
"# the same labels row by row - confirm.\n",
"joint_resampled, classes_resampled = oversample.fit_resample(joint_features, X_train.iloc[:, -1].to_numpy())\n",
"\n",
"design_resampled = pd.DataFrame(joint_resampled[:, :n_design_features], columns=design_features.columns)\n",
"target_resampled = pd.DataFrame(joint_resampled[:, n_design_features:], columns=target_features.columns)\n",
"design_classes_resampled = pd.Series(classes_resampled, name=X_train.columns[-1])\n",
"target_classes_resampled = pd.Series(classes_resampled, name=y_train.columns[-1])\n"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [],
"source": [
"# Put the resampled class labels back as the last column of each training frame.\n",
"X_train = pd.concat((design_resampled, design_classes_resampled), axis=\"columns\")\n",
"y_train = pd.concat((target_resampled, target_classes_resampled), axis=\"columns\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define Scaling and Normalization Functions"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [],
"source": [
"def log_scale(df_design, df_result, func_dict):\n",
"    \"\"\"Apply the per-column transform from ``func_dict`` to both frames.\n",
"\n",
"    Args:\n",
"        df_design: frame whose column names select the transforms.\n",
"        df_result: frame holding the same column names as ``df_design``.\n",
"        func_dict: maps a column name to a NumPy ufunc or a scalar function.\n",
"\n",
"    Returns:\n",
"        Transformed copies ``(df_design, df_result)``; the inputs are not modified\n",
"        and a column named \"Class\" is passed through unchanged.\n",
"    \"\"\"\n",
"    df_design = df_design.copy()\n",
"    df_result = df_result.copy()\n",
"\n",
"    for key in df_design.keys():\n",
"        if key == \"Class\":\n",
"            continue\n",
"        func = func_dict[key]\n",
"        if not isinstance(func, np.ufunc):\n",
"            # Scalar functions need np.vectorize. Fixing otypes stops NumPy from inferring\n",
"            # an integer dtype from the first result and allows empty input.\n",
"            func = np.vectorize(func, otypes=[float])\n",
"        df_design[key] = func(df_design[key].to_numpy())\n",
"        df_result[key] = func(df_result[key].to_numpy())\n",
"\n",
"    return df_design, df_result\n",
"\n",
"\n",
"def get_min_max(df_design, df_result):\n",
"    \"\"\"Return the per-column minimum and maximum taken over both frames.\n",
"\n",
"    Returns:\n",
"        ``(data_min, data_max)``, two dicts keyed by column name.\n",
"    \"\"\"\n",
"    # Element-wise min/max of the two per-column extremes gives the joint range.\n",
"    data_min = np.minimum(df_design.min(), df_result.min())\n",
"    data_max = np.maximum(df_design.max(), df_result.max())\n",
"\n",
"    return data_min.to_dict(), data_max.to_dict()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Transformed copies of the complete data set and their joint per-column value range.\n",
"df_design_log, df_results_log = log_scale(df_design=df_design, df_result=df_results, func_dict=func_dict_in)\n",
"data_min_log, data_max_log = get_min_max(df_design=df_design_log, df_result=df_results_log)"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [],
"source": [
"# Apply the forward transforms to the training and the test split.\n",
"X_train_log, y_train_log = log_scale(df_design=X_train, df_result=y_train, func_dict=func_dict_in)\n",
"X_test_log, y_test_log = log_scale(df_design=X_test, df_result=y_test, func_dict=func_dict_in)"
]
},
{
"cell_type": "code",
"execution_count": 102,
"metadata": {},
"outputs": [],
"source": [
"# Per-column value ranges of the transformed training and test splits.\n",
"train_min_log, train_max_log = get_min_max(df_design=X_train_log, df_result=y_train_log)\n",
"test_min_log, test_max_log = get_min_max(df_design=X_test_log, df_result=y_test_log)"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [],
"source": [
"def preprocess(data, func_dict, data_min, data_max):\n",
"    \"\"\"Min-max scale every column except \"Class\".\n",
"\n",
"    Args:\n",
"        data: frame to scale; it is copied, not modified.\n",
"        func_dict: not used by this function; kept for a signature parallel to ``postprocess``.\n",
"        data_min: per-column minimum, keyed by column name.\n",
"        data_max: per-column maximum, keyed by column name.\n",
"\n",
"    Returns:\n",
"        A scaled copy of ``data``.\n",
"    \"\"\"\n",
"    data = data.copy()\n",
"    for key in data.keys():\n",
"        if key == \"Class\":\n",
"            continue\n",
"        span = data_max[key] - data_min[key]\n",
"        if span == 0:\n",
"            # Constant column: only shift it, dividing by the zero range would give NaN/inf.\n",
"            data[key] = data[key] - data_min[key]\n",
"        else:\n",
"            data[key] = (data[key] - data_min[key]) / span\n",
"\n",
"    return data\n",
"\n",
"\n",
"def postprocess(data, func_dict, data_min, data_max):\n",
"    \"\"\"Undo the min-max scaling, then apply the per-column transform from ``func_dict``.\n",
"\n",
"    Args:\n",
"        data: scaled frame; it is copied, not modified.\n",
"        func_dict: maps a column name to a NumPy ufunc or a scalar function.\n",
"        data_min: per-column minimum, keyed by column name.\n",
"        data_max: per-column maximum, keyed by column name.\n",
"\n",
"    Returns:\n",
"        A copy of ``data`` with every column except \"Class\" mapped back.\n",
"    \"\"\"\n",
"    data = data.copy()\n",
"    for key in data.keys():\n",
"        if key == \"Class\":\n",
"            continue\n",
"        func = func_dict[key]\n",
"        if not isinstance(func, np.ufunc):\n",
"            # Fixing otypes stops np.vectorize from inferring an integer dtype from the\n",
"            # first result and allows empty input.\n",
"            func = np.vectorize(func, otypes=[float])\n",
"        rescaled = data[key] * (data_max[key] - data_min[key]) + data_min[key]\n",
"        data[key] = func(rescaled.to_numpy())\n",
"    return data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preprocess the data"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Min-max scale the complete transformed data set with its joint value range.\n",
"pp_design = preprocess(data=df_design_log, func_dict=func_dict_in, data_min=data_min_log, data_max=data_max_log)\n",
"pp_results = preprocess(data=df_results_log, func_dict=func_dict_in, data_min=data_min_log, data_max=data_max_log)"
]
},
{
"cell_type": "code",
"execution_count": 103,
"metadata": {},
"outputs": [],
"source": [
"# Training split, scaled with the value range of the training split.\n",
"X_train_preprocess = preprocess(data=X_train_log, func_dict=func_dict_in, data_min=train_min_log, data_max=train_max_log)\n",
"y_train_preprocess = preprocess(data=y_train_log, func_dict=func_dict_in, data_min=train_min_log, data_max=train_max_log)\n",
"\n",
"# Test split, scaled with the value range of the test split.\n",
"# NOTE(review): the test data is normalised with its own min/max rather than the training\n",
"# range, so both splits live on slightly different scales - confirm this is intended.\n",
"X_test_preprocess = preprocess(data=X_test_log, func_dict=func_dict_in, data_min=test_min_log, data_max=test_max_log)\n",
"y_test_preprocess = preprocess(data=y_test_log, func_dict=func_dict_in, data_min=test_min_log, data_max=test_max_log)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sample the data"
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {},
"outputs": [],
"source": [
"# Hold back 10% of the preprocessed training data for validation; the fixed random_state\n",
"# makes the split reproducible across kernel restarts.\n",
"X_train, X_val, y_train, y_val = sk.train_test_split(\n",
"    X_train_preprocess,\n",
"    y_train_preprocess,\n",
"    test_size=0.1,\n",
"    random_state=0,\n",
")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Custom Loss function"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def custom_loss_H20(df_design_log, df_result_log, data_min_log, data_max_log, func_dict_out, postprocess):\n",
"    \"\"\"Huber loss plus a penalty on deviations of the H/O ratio from 2.\n",
"\n",
"    Args:\n",
"        df_design_log: first argument handed to the Huber loss.\n",
"        df_result_log: second argument handed to the Huber loss; also passed to ``postprocess``.\n",
"        data_min_log: per-column minima forwarded to ``postprocess``.\n",
"        data_max_log: per-column maxima forwarded to ``postprocess``.\n",
"        func_dict_out: per-column transforms forwarded to ``postprocess``.\n",
"        postprocess: callable ``(data, func_dict, data_min, data_max)`` whose result\n",
"            supports lookup of the columns \"H\" and \"O\".\n",
"\n",
"    Returns:\n",
"        Huber loss of the two inputs plus the summed squared deviation of H/O from 2.\n",
"\n",
"    NOTE(review): the column lookup by name and ``np.sum`` presumably only work for\n",
"    DataFrame-like inputs; the tensors Keras passes during ``fit`` would need a\n",
"    tensor-aware ``postprocess`` - confirm before compiling a model with this loss.\n",
"    \"\"\"\n",
"    df_result = postprocess(df_result_log, func_dict_out, data_min_log, data_max_log)\n",
"    ratio_penalty = np.sum(((df_result[\"H\"] / df_result[\"O\"]) - 2) ** 2)\n",
"    # keras.losses.Huber is a class: it has to be instantiated and called on the data.\n",
"    # Adding the class object itself to a number raises a TypeError.\n",
"    return keras.losses.Huber()(df_design_log, df_result_log) + ratio_penalty\n",
"\n",
"\n",
"def loss_wrapper(data_min_log, data_max_log, func_dict_out, postprocess):\n",
"    \"\"\"Bind the scaling parameters and return a two-argument loss function.\"\"\"\n",
"    def loss(df_design_log, df_result_log):\n",
"        return custom_loss_H20(df_design_log, df_result_log, data_min_log, data_max_log, func_dict_out, postprocess)\n",
"    return loss"
]
},
{
"cell_type": "code",
"execution_count": 106,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(3559968, 12)"
]
},
"execution_count": 106,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Shape of the training inputs without their last column.\n",
"n_feature_columns = X_train.shape[1] - 1\n",
"X_train.iloc[:, :n_feature_columns].shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the model"
]
},
{
"cell_type": "code",
"execution_count": 118,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/50\n",
"\u001b[1m6954/6954\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m91s\u001b[0m 13ms/step - loss: 0.0070 - val_loss: 0.0066\n",
"Epoch 2/50\n",
"\u001b[1m6954/6954\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m104s\u001b[0m 15ms/step - loss: 0.0066 - val_loss: 0.0066\n",
"Epoch 3/50\n",
"\u001b[1m4644/6954\u001b[0m \u001b[32m━━━━━━━━━━━━━\u001b[0m\u001b[37m━━━━━━━\u001b[0m \u001b[1m35s\u001b[0m 15ms/step - loss: 0.0066"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[118], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# measure time\u001b[39;00m\n\u001b[1;32m 2\u001b[0m start \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[0;32m----> 4\u001b[0m history \u001b[38;5;241m=\u001b[39m model_large\u001b[38;5;241m.\u001b[39mfit(X_train\u001b[38;5;241m.\u001b[39miloc[:, :\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m], \n\u001b[1;32m 5\u001b[0m y_train\u001b[38;5;241m.\u001b[39miloc[:, :\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m], \n\u001b[1;32m 6\u001b[0m batch_size \u001b[38;5;241m=\u001b[39m batch_size, \n\u001b[1;32m 7\u001b[0m epochs \u001b[38;5;241m=\u001b[39m epochs, \n\u001b[1;32m 8\u001b[0m validation_data \u001b[38;5;241m=\u001b[39m (X_val\u001b[38;5;241m.\u001b[39miloc[:,:\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m], y_val\u001b[38;5;241m.\u001b[39miloc[:, :\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m 9\u001b[0m )\n\u001b[1;32m 11\u001b[0m end \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTraining took \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m seconds\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(end \u001b[38;5;241m-\u001b[39m start))\n",
"File \u001b[0;32m~/bin/miniconda3/envs/training/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py:117\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 115\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 116\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 117\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m fn(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 119\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n",
"File \u001b[0;32m~/bin/miniconda3/envs/training/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py:320\u001b[0m, in \u001b[0;36mTensorFlowTrainer.fit\u001b[0;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq)\u001b[0m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m step, iterator \u001b[38;5;129;01min\u001b[39;00m epoch_iterator\u001b[38;5;241m.\u001b[39menumerate_epoch():\n\u001b[1;32m 319\u001b[0m callbacks\u001b[38;5;241m.\u001b[39mon_train_batch_begin(step)\n\u001b[0;32m--> 320\u001b[0m logs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_function(iterator)\n\u001b[1;32m 321\u001b[0m callbacks\u001b[38;5;241m.\u001b[39mon_train_batch_end(step, logs)\n\u001b[1;32m 322\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstop_training:\n",
"File \u001b[0;32m~/bin/miniconda3/envs/training/lib/python3.11/site-packages/tensorflow/python/util/traceback_utils.py:150\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 148\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 149\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 150\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m fn(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 151\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 152\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n",
"File \u001b[0;32m~/bin/miniconda3/envs/training/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:833\u001b[0m, in \u001b[0;36mFunction.__call__\u001b[0;34m(self, *args, **kwds)\u001b[0m\n\u001b[1;32m 830\u001b[0m compiler \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mxla\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_jit_compile \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnonXla\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 832\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m OptionalXlaContext(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_jit_compile):\n\u001b[0;32m--> 833\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds)\n\u001b[1;32m 835\u001b[0m new_tracing_count \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mexperimental_get_tracing_count()\n\u001b[1;32m 836\u001b[0m without_tracing \u001b[38;5;241m=\u001b[39m (tracing_count \u001b[38;5;241m==\u001b[39m new_tracing_count)\n",
"File \u001b[0;32m~/bin/miniconda3/envs/training/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:878\u001b[0m, in \u001b[0;36mFunction._call\u001b[0;34m(self, *args, **kwds)\u001b[0m\n\u001b[1;32m 875\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_lock\u001b[38;5;241m.\u001b[39mrelease()\n\u001b[1;32m 876\u001b[0m \u001b[38;5;66;03m# In this case we have not created variables on the first call. So we can\u001b[39;00m\n\u001b[1;32m 877\u001b[0m \u001b[38;5;66;03m# run the first trace but we should fail if variables are created.\u001b[39;00m\n\u001b[0;32m--> 878\u001b[0m results \u001b[38;5;241m=\u001b[39m tracing_compilation\u001b[38;5;241m.\u001b[39mcall_function(\n\u001b[1;32m 879\u001b[0m args, kwds, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_variable_creation_config\n\u001b[1;32m 880\u001b[0m )\n\u001b[1;32m 881\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_created_variables:\n\u001b[1;32m 882\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCreating variables on a non-first call to a function\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 883\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m decorated with tf.function.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/bin/miniconda3/envs/training/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:139\u001b[0m, in \u001b[0;36mcall_function\u001b[0;34m(args, kwargs, tracing_options)\u001b[0m\n\u001b[1;32m 137\u001b[0m bound_args \u001b[38;5;241m=\u001b[39m function\u001b[38;5;241m.\u001b[39mfunction_type\u001b[38;5;241m.\u001b[39mbind(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 138\u001b[0m flat_inputs \u001b[38;5;241m=\u001b[39m function\u001b[38;5;241m.\u001b[39mfunction_type\u001b[38;5;241m.\u001b[39munpack_inputs(bound_args)\n\u001b[0;32m--> 139\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m function\u001b[38;5;241m.\u001b[39m_call_flat( \u001b[38;5;66;03m# pylint: disable=protected-access\u001b[39;00m\n\u001b[1;32m 140\u001b[0m flat_inputs, captured_inputs\u001b[38;5;241m=\u001b[39mfunction\u001b[38;5;241m.\u001b[39mcaptured_inputs\n\u001b[1;32m 141\u001b[0m )\n",
"File \u001b[0;32m~/bin/miniconda3/envs/training/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/concrete_function.py:1322\u001b[0m, in \u001b[0;36mConcreteFunction._call_flat\u001b[0;34m(self, tensor_inputs, captured_inputs)\u001b[0m\n\u001b[1;32m 1318\u001b[0m possible_gradient_type \u001b[38;5;241m=\u001b[39m gradients_util\u001b[38;5;241m.\u001b[39mPossibleTapeGradientTypes(args)\n\u001b[1;32m 1319\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (possible_gradient_type \u001b[38;5;241m==\u001b[39m gradients_util\u001b[38;5;241m.\u001b[39mPOSSIBLE_GRADIENT_TYPES_NONE\n\u001b[1;32m 1320\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m executing_eagerly):\n\u001b[1;32m 1321\u001b[0m \u001b[38;5;66;03m# No tape is watching; skip to running the function.\u001b[39;00m\n\u001b[0;32m-> 1322\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_inference_function\u001b[38;5;241m.\u001b[39mcall_preflattened(args)\n\u001b[1;32m 1323\u001b[0m forward_backward \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_select_forward_and_backward_functions(\n\u001b[1;32m 1324\u001b[0m args,\n\u001b[1;32m 1325\u001b[0m possible_gradient_type,\n\u001b[1;32m 1326\u001b[0m executing_eagerly)\n\u001b[1;32m 1327\u001b[0m forward_function, args_with_tangents \u001b[38;5;241m=\u001b[39m forward_backward\u001b[38;5;241m.\u001b[39mforward()\n",
"File \u001b[0;32m~/bin/miniconda3/envs/training/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/atomic_function.py:216\u001b[0m, in \u001b[0;36mAtomicFunction.call_preflattened\u001b[0;34m(self, args)\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcall_preflattened\u001b[39m(\u001b[38;5;28mself\u001b[39m, args: Sequence[core\u001b[38;5;241m.\u001b[39mTensor]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m 215\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Calls with flattened tensor inputs and returns the structured output.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 216\u001b[0m flat_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcall_flat(\u001b[38;5;241m*\u001b[39margs)\n\u001b[1;32m 217\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfunction_type\u001b[38;5;241m.\u001b[39mpack_output(flat_outputs)\n",
"File \u001b[0;32m~/bin/miniconda3/envs/training/lib/python3.11/site-packages/tensorflow/python/eager/polymorphic_function/atomic_function.py:251\u001b[0m, in \u001b[0;36mAtomicFunction.call_flat\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m 249\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m record\u001b[38;5;241m.\u001b[39mstop_recording():\n\u001b[1;32m 250\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_bound_context\u001b[38;5;241m.\u001b[39mexecuting_eagerly():\n\u001b[0;32m--> 251\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_bound_context\u001b[38;5;241m.\u001b[39mcall_function(\n\u001b[1;32m 252\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mname,\n\u001b[1;32m 253\u001b[0m \u001b[38;5;28mlist\u001b[39m(args),\n\u001b[1;32m 254\u001b[0m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfunction_type\u001b[38;5;241m.\u001b[39mflat_outputs),\n\u001b[1;32m 255\u001b[0m )\n\u001b[1;32m 256\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 257\u001b[0m outputs \u001b[38;5;241m=\u001b[39m make_call_op_in_graph(\n\u001b[1;32m 258\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 259\u001b[0m \u001b[38;5;28mlist\u001b[39m(args),\n\u001b[1;32m 260\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_bound_context\u001b[38;5;241m.\u001b[39mfunction_call_options\u001b[38;5;241m.\u001b[39mas_attrs(),\n\u001b[1;32m 261\u001b[0m )\n",
"File \u001b[0;32m~/bin/miniconda3/envs/training/lib/python3.11/site-packages/tensorflow/python/eager/context.py:1552\u001b[0m, in \u001b[0;36mContext.call_function\u001b[0;34m(self, name, tensor_inputs, num_outputs)\u001b[0m\n\u001b[1;32m 1550\u001b[0m cancellation_context \u001b[38;5;241m=\u001b[39m cancellation\u001b[38;5;241m.\u001b[39mcontext()\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m cancellation_context \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 1552\u001b[0m outputs \u001b[38;5;241m=\u001b[39m execute\u001b[38;5;241m.\u001b[39mexecute(\n\u001b[1;32m 1553\u001b[0m name\u001b[38;5;241m.\u001b[39mdecode(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mutf-8\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 1554\u001b[0m num_outputs\u001b[38;5;241m=\u001b[39mnum_outputs,\n\u001b[1;32m 1555\u001b[0m inputs\u001b[38;5;241m=\u001b[39mtensor_inputs,\n\u001b[1;32m 1556\u001b[0m attrs\u001b[38;5;241m=\u001b[39mattrs,\n\u001b[1;32m 1557\u001b[0m ctx\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 1558\u001b[0m )\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1560\u001b[0m outputs \u001b[38;5;241m=\u001b[39m execute\u001b[38;5;241m.\u001b[39mexecute_with_cancellation(\n\u001b[1;32m 1561\u001b[0m name\u001b[38;5;241m.\u001b[39mdecode(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mutf-8\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 1562\u001b[0m num_outputs\u001b[38;5;241m=\u001b[39mnum_outputs,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1566\u001b[0m cancellation_manager\u001b[38;5;241m=\u001b[39mcancellation_context,\n\u001b[1;32m 1567\u001b[0m )\n",
"File \u001b[0;32m~/bin/miniconda3/envs/training/lib/python3.11/site-packages/tensorflow/python/eager/execute.py:53\u001b[0m, in \u001b[0;36mquick_execute\u001b[0;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 52\u001b[0m ctx\u001b[38;5;241m.\u001b[39mensure_initialized()\n\u001b[0;32m---> 53\u001b[0m tensors \u001b[38;5;241m=\u001b[39m pywrap_tfe\u001b[38;5;241m.\u001b[39mTFE_Py_Execute(ctx\u001b[38;5;241m.\u001b[39m_handle, device_name, op_name,\n\u001b[1;32m 54\u001b[0m inputs, attrs, num_outputs)\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m core\u001b[38;5;241m.\u001b[39m_NotOkStatusException \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"# Fit `model_large` and report how long the training took.\n",
"# `time.perf_counter()` is a monotonic, high-resolution clock; `time.time()` can jump\n",
"# when the system clock is adjusted and would then report a wrong duration.\n",
"start = time.perf_counter()\n",
"\n",
"try:\n",
"    # The trailing column of each frame is dropped before it is handed to Keras\n",
"    # (presumably the `Class` label -- TODO confirm against the data-preparation cells).\n",
"    history = model_large.fit(\n",
"        X_train.iloc[:, :-1],\n",
"        y_train.iloc[:, :-1],\n",
"        batch_size=batch_size,\n",
"        epochs=epochs,\n",
"        validation_data=(X_val.iloc[:, :-1], y_val.iloc[:, :-1]),\n",
"    )\n",
"finally:\n",
"    # Runs even when training is interrupted or fails, so the elapsed time is always printed.\n",
"    end = time.perf_counter()\n",
"    print(\"Training took {} seconds\".format(end - start))"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Training loss per epoch for the whole run; the figure is written to loss_all.png.\n",
"plt.plot(history.history[\"loss\"], \"o-\", label=\"Training Loss\")\n",
"plt.xlabel(\"Epoch\")\n",
"# plt.yscale('log')  # enable for a logarithmic loss axis\n",
"plt.ylabel(\"Loss (Huber)\")\n",
"# `grid` expects a boolean; the string 'on' only worked because a non-empty string is truthy.\n",
"plt.grid(True)\n",
"# The curve label is only rendered once a legend is requested.\n",
"plt.legend()\n",
"\n",
"plt.savefig(\"loss_all.png\", dpi=300)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Training loss without its first entry (presumably skipped because the initial loss\n",
"# dominates the y-scale -- TODO confirm); the figure is written to loss_1_to_end.png.\n",
"train_loss = history.history[\"loss\"]\n",
"# Pass explicit x values: plotting the sliced list alone would number the remaining points\n",
"# from 0 again and shift the whole curve one epoch to the left.\n",
"epoch_index = range(1, len(train_loss))\n",
"plt.plot(epoch_index, train_loss[1:], \"o-\", label=\"Training Loss\")\n",
"plt.xlabel(\"Epoch\")\n",
"# plt.yscale('log')  # enable for a logarithmic loss axis\n",
"plt.ylabel(\"Loss (Huber)\")\n",
"# `grid` expects a boolean; the string 'on' only worked because a non-empty string is truthy.\n",
"plt.grid(True)\n",
"# The curve label is only rendered once a legend is requested.\n",
"plt.legend()\n",
"plt.savefig(\"loss_1_to_end.png\", dpi=300)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>H</th>\n",
" <th>O</th>\n",
" <th>Charge</th>\n",
" <th>H_0_</th>\n",
" <th>O_0_</th>\n",
" <th>Ba</th>\n",
" <th>Cl</th>\n",
" <th>S_2_</th>\n",
" <th>S_6_</th>\n",
" <th>Sr</th>\n",
" <th>Barite</th>\n",
" <th>Celestite</th>\n",
" <th>Class</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>111.012434</td>\n",
" <td>55.508192</td>\n",
" <td>-7.779554e-09</td>\n",
" <td>2.697041e-26</td>\n",
" <td>2.210590e-15</td>\n",
" <td>2.041069e-02</td>\n",
" <td>4.082138e-02</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000494</td>\n",
" <td>0.000494</td>\n",
" <td>0.001</td>\n",
" <td>1.000000</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>111.012434</td>\n",
" <td>55.508427</td>\n",
" <td>-4.736083e-09</td>\n",
" <td>1.446346e-26</td>\n",
" <td>2.473481e-15</td>\n",
" <td>1.094567e-02</td>\n",
" <td>2.189133e-02</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000553</td>\n",
" <td>0.000553</td>\n",
" <td>0.001</td>\n",
" <td>1.000000</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>111.012434</td>\n",
" <td>55.508691</td>\n",
" <td>-1.311169e-09</td>\n",
" <td>3.889826e-28</td>\n",
" <td>2.769320e-15</td>\n",
" <td>2.943745e-04</td>\n",
" <td>5.887491e-04</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000619</td>\n",
" <td>0.000619</td>\n",
" <td>0.001</td>\n",
" <td>1.000000</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>111.012434</td>\n",
" <td>55.508698</td>\n",
" <td>-1.220023e-09</td>\n",
" <td>1.442658e-29</td>\n",
" <td>2.777193e-15</td>\n",
" <td>1.091776e-05</td>\n",
" <td>2.183551e-05</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000620</td>\n",
" <td>0.000620</td>\n",
" <td>0.001</td>\n",
" <td>1.000000</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>111.012434</td>\n",
" <td>55.508699</td>\n",
" <td>-1.216643e-09</td>\n",
" <td>5.350528e-31</td>\n",
" <td>2.777485e-15</td>\n",
" <td>4.049176e-07</td>\n",
" <td>8.098352e-07</td>\n",
" <td>0.000000e+00</td>\n",
" <td>0.000620</td>\n",
" <td>0.000620</td>\n",
" <td>0.001</td>\n",
" <td>1.000000</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2502495</th>\n",
" <td>111.012434</td>\n",
" <td>55.507488</td>\n",
" <td>3.573728e-09</td>\n",
" <td>5.424062e-145</td>\n",
" <td>1.375204e-10</td>\n",
" <td>9.953520e-07</td>\n",
" <td>2.266555e-03</td>\n",
" <td>5.509534e-149</td>\n",
" <td>0.000318</td>\n",
" <td>0.001450</td>\n",
" <td>0.001</td>\n",
" <td>1.000014</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2502496</th>\n",
" <td>111.012434</td>\n",
" <td>55.507501</td>\n",
" <td>3.494007e-09</td>\n",
" <td>2.011675e-146</td>\n",
" <td>1.377139e-10</td>\n",
" <td>9.817216e-07</td>\n",
" <td>2.217997e-03</td>\n",
" <td>2.043375e-150</td>\n",
" <td>0.000321</td>\n",
" <td>0.001429</td>\n",
" <td>0.001</td>\n",
" <td>1.000010</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2502497</th>\n",
" <td>111.012434</td>\n",
" <td>55.507512</td>\n",
" <td>3.429764e-09</td>\n",
" <td>7.460897e-148</td>\n",
" <td>1.377819e-10</td>\n",
" <td>9.706451e-07</td>\n",
" <td>2.179066e-03</td>\n",
" <td>7.578467e-152</td>\n",
" <td>0.000324</td>\n",
" <td>0.001412</td>\n",
" <td>0.001</td>\n",
" <td>1.000006</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2502498</th>\n",
" <td>111.012434</td>\n",
" <td>55.507520</td>\n",
" <td>3.381745e-09</td>\n",
" <td>2.767237e-149</td>\n",
" <td>1.371144e-10</td>\n",
" <td>9.621074e-07</td>\n",
" <td>2.149820e-03</td>\n",
" <td>2.810844e-153</td>\n",
" <td>0.000326</td>\n",
" <td>0.001400</td>\n",
" <td>0.001</td>\n",
" <td>1.000004</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2502499</th>\n",
" <td>111.012434</td>\n",
" <td>55.507525</td>\n",
" <td>3.348864e-09</td>\n",
" <td>5.321610e-151</td>\n",
" <td>1.376026e-10</td>\n",
" <td>9.564401e-07</td>\n",
" <td>2.129912e-03</td>\n",
" <td>5.405468e-155</td>\n",
" <td>0.000327</td>\n",
" <td>0.001391</td>\n",
" <td>0.001</td>\n",
" <td>1.000001</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2502500 rows × 13 columns</p>\n",
"</div>"
],
"text/plain": [
" H O Charge H_0_ O_0_ \\\n",
"0 111.012434 55.508192 -7.779554e-09 2.697041e-26 2.210590e-15 \n",
"1 111.012434 55.508427 -4.736083e-09 1.446346e-26 2.473481e-15 \n",
"2 111.012434 55.508691 -1.311169e-09 3.889826e-28 2.769320e-15 \n",
"3 111.012434 55.508698 -1.220023e-09 1.442658e-29 2.777193e-15 \n",
"4 111.012434 55.508699 -1.216643e-09 5.350528e-31 2.777485e-15 \n",
"... ... ... ... ... ... \n",
"2502495 111.012434 55.507488 3.573728e-09 5.424062e-145 1.375204e-10 \n",
"2502496 111.012434 55.507501 3.494007e-09 2.011675e-146 1.377139e-10 \n",
"2502497 111.012434 55.507512 3.429764e-09 7.460897e-148 1.377819e-10 \n",
"2502498 111.012434 55.507520 3.381745e-09 2.767237e-149 1.371144e-10 \n",
"2502499 111.012434 55.507525 3.348864e-09 5.321610e-151 1.376026e-10 \n",
"\n",
" Ba Cl S_2_ S_6_ Sr \\\n",
"0 2.041069e-02 4.082138e-02 0.000000e+00 0.000494 0.000494 \n",
"1 1.094567e-02 2.189133e-02 0.000000e+00 0.000553 0.000553 \n",
"2 2.943745e-04 5.887491e-04 0.000000e+00 0.000619 0.000619 \n",
"3 1.091776e-05 2.183551e-05 0.000000e+00 0.000620 0.000620 \n",
"4 4.049176e-07 8.098352e-07 0.000000e+00 0.000620 0.000620 \n",
"... ... ... ... ... ... \n",
"2502495 9.953520e-07 2.266555e-03 5.509534e-149 0.000318 0.001450 \n",
"2502496 9.817216e-07 2.217997e-03 2.043375e-150 0.000321 0.001429 \n",
"2502497 9.706451e-07 2.179066e-03 7.578467e-152 0.000324 0.001412 \n",
"2502498 9.621074e-07 2.149820e-03 2.810844e-153 0.000326 0.001400 \n",
"2502499 9.564401e-07 2.129912e-03 5.405468e-155 0.000327 0.001391 \n",
"\n",
" Barite Celestite Class \n",
"0 0.001 1.000000 0 \n",
"1 0.001 1.000000 0 \n",
"2 0.001 1.000000 0 \n",
"3 0.001 1.000000 0 \n",
"4 0.001 1.000000 0 \n",
"... ... ... ... \n",
"2502495 0.001 1.000014 0 \n",
"2502496 0.001 1.000010 0 \n",
"2502497 0.001 1.000006 0 \n",
"2502498 0.001 1.000004 0 \n",
"2502499 0.001 1.000001 0 \n",
"\n",
"[2502500 rows x 13 columns]"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_design"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test the model"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m15641/15641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 336us/step - loss: 6.6414e-07\n"
]
},
{
"data": {
"text/plain": [
"8.585521982240607e-07"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Score the model on the complete test split.\n",
"# NOTE(review): the training cell of this notebook fits `model_large` and slices the targets\n",
"# with `.iloc[:, :-1]`, whereas this cell evaluates `model` against the unsliced `y_test` --\n",
"# confirm that this model/target pairing is intended.\n",
"X_test_features = X_test.iloc[:, :-1]  # every column except the last one\n",
"model.evaluate(X_test_features, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m15454/15454\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 331us/step - loss: 2.7927e-07\n"
]
},
{
"data": {
"text/plain": [
"3.939527175589319e-07"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Score the model on the non-reactive part of the test split, i.e. the rows with Class == 0.\n",
"# The boolean row mask is built once and reused for both the inputs and the targets.\n",
"non_reactive_rows = X_test['Class'] == 0\n",
"model.evaluate(X_test[non_reactive_rows].iloc[:, :-1], y_test[non_reactive_rows])"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1m188/188\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 400us/step - loss: 3.3173e-05\n"
]
},
{
"data": {
"text/plain": [
"3.921399184037e-05"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Score the model on the reactive part of the test split, i.e. the rows with Class == 1.\n",
"# The boolean row mask is built once and reused for both the inputs and the targets.\n",
"reactive_rows = X_test['Class'] == 1\n",
"model.evaluate(X_test[reactive_rows].iloc[:, :-1], y_test[reactive_rows])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save the model"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"# Write the network held in `model` to disk; the target file uses the `.keras` extension.\n",
"model_path = \"Barite_50_Model_additional_species.keras\"\n",
"model.save(model_path)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "training",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}