mirror of
https://git.gfz-potsdam.de/naaice/model-training.git
synced 2025-12-15 20:48:21 +01:00
1046 lines
99 KiB
Plaintext
1046 lines
99 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## General Information"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"This notebook is used to train a simple neural network model to predict the chemistry in the barite benchmark (50x50 grid). The training data is stored in the repository using **git large file storage** and can be downloaded after the installation of git lfs using the `git lfs pull` command.\n",
|
||
"\n",
|
||
"It is then recommended to create a Python environment using miniconda. The necessary dependencies are contained in `environment.yml` and can be installed using `conda env create -f environment.yml`.\n",
|
||
"\n",
|
||
"The data set is divided into a design and result part and consists of the iterations of a reference simulation. The design part of the data set contains the chemical concentrations at time $t$ and the result part at time $t+1$, which are to be learned by the model."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Setup Libraries"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"2025-01-14 17:26:34.798886: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
|
||
"2025-01-14 17:26:34.825591: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
|
||
"To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Running Keras in version 3.6.0\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import keras\n",
|
||
"print(\"Running Keras in version {}\".format(keras.__version__))\n",
|
||
"\n",
|
||
"import h5py\n",
|
||
"import numpy as np\n",
|
||
"import pandas as pd\n",
|
||
"import time\n",
|
||
"import sklearn.model_selection as sk\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"from sklearn.cluster import KMeans"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Define parameters"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"dtype = \"float32\"\n",
|
||
"activation = \"relu\"\n",
|
||
"\n",
|
||
"lr = 0.001\n",
|
||
"batch_size = 512\n",
|
||
"epochs = 50 # default 400 epochs\n",
|
||
"\n",
|
||
"lr_schedule = keras.optimizers.schedules.ExponentialDecay(\n",
|
||
" initial_learning_rate=lr,\n",
|
||
" decay_steps=2000,\n",
|
||
" decay_rate=0.9,\n",
|
||
" staircase=True\n",
|
||
")\n",
|
||
"\n",
|
||
"optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)\n",
|
||
"loss = keras.losses.Huber()\n",
|
||
"\n",
|
||
"sample_fraction = 0.8"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Setup the model"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential\"</span>\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\u001b[1mModel: \"sequential\"\u001b[0m\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
|
||
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃\n",
|
||
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
|
||
"│ dense (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,664</span> │\n",
|
||
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
|
||
"│ dense_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">128</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">16,512</span> │\n",
|
||
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
|
||
"│ dense_2 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">12</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">1,548</span> │\n",
|
||
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
|
||
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
|
||
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
|
||
"│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m1,664\u001b[0m │\n",
|
||
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
|
||
"│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m16,512\u001b[0m │\n",
|
||
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
|
||
"│ dense_2 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m1,548\u001b[0m │\n",
|
||
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">19,724</span> (77.05 KB)\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m19,724\u001b[0m (77.05 KB)\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">19,724</span> (77.05 KB)\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m19,724\u001b[0m (77.05 KB)\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
|
||
"</pre>\n"
|
||
],
|
||
"text/plain": [
|
||
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"model = keras.Sequential(\n",
|
||
" [\n",
|
||
" keras.Input(shape = (12,), dtype = \"float32\"),\n",
|
||
" keras.layers.Dense(units = 128, activation = \"relu\", dtype = \"float32\"),\n",
|
||
" keras.layers.Dense(units = 128, activation = \"relu\", dtype = \"float32\"),\n",
|
||
" keras.layers.Dense(units = 12, dtype = \"float32\")\n",
|
||
" ]\n",
|
||
")\n",
|
||
"\n",
|
||
"model.compile(optimizer=optimizer, loss = loss)\n",
|
||
"model.summary()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Define some functions and helper classes"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def Safelog(val):\n",
|
||
" # get range of vector\n",
|
||
" if val > 0:\n",
|
||
" return np.log10(val)\n",
|
||
" elif val < 0:\n",
|
||
" return -np.log10(-val)\n",
|
||
" else:\n",
|
||
" return 0\n",
|
||
"\n",
|
||
"def Safeexp(val):\n",
|
||
" if val > 0:\n",
|
||
" return -10 ** -val\n",
|
||
" elif val < 0:\n",
|
||
" return 10 ** val\n",
|
||
" else:\n",
|
||
" return 0\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 55,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# ? Why does the charge is using another logarithm than the other species\n",
|
||
"\n",
|
||
"func_dict_in = {\n",
|
||
" \"H\" : np.log1p,\n",
|
||
" \"O\" : np.log1p,\n",
|
||
" \"Charge\" : Safelog,\n",
|
||
" \"H_0_\" : np.log1p,\n",
|
||
" \"O_0_\" : np.log1p,\n",
|
||
" \"Ba\" : np.log1p,\n",
|
||
" \"Cl\" : np.log1p,\n",
|
||
" \"S_2_\" : np.log1p,\n",
|
||
" \"S_6_\" : np.log1p,\n",
|
||
" \"Sr\" : np.log1p,\n",
|
||
" \"Barite\" : np.log1p,\n",
|
||
" \"Celestite\" : np.log1p,\n",
|
||
"}\n",
|
||
"\n",
|
||
"func_dict_out = {\n",
|
||
" \"H\" : np.expm1,\n",
|
||
" \"O\" : np.expm1,\n",
|
||
" \"Charge\" : Safeexp,\n",
|
||
" \"H_0_\" : np.expm1,\n",
|
||
" \"O_0_\" : np.expm1,\n",
|
||
" \"Ba\" : np.expm1,\n",
|
||
" \"Cl\" : np.expm1,\n",
|
||
" \"S_2_\" : np.expm1,\n",
|
||
" \"S_6_\" : np.expm1,\n",
|
||
" \"Sr\" : np.expm1,\n",
|
||
" \"Barite\" : np.expm1,\n",
|
||
" \"Celestite\" : np.expm1,\n",
|
||
"}\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Read data from `.h5` file and convert it to a `pandas.DataFrame`"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 49,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"data_file = h5py.File(\"Barite_50_Data_training.h5\")\n",
|
||
"\n",
|
||
"design = data_file[\"design\"]\n",
|
||
"results = data_file[\"result\"]\n",
|
||
"\n",
|
||
"df_design = pd.DataFrame(np.array(design[\"data\"]).transpose(), columns = design[\"names\"].asstr())\n",
|
||
"df_results = pd.DataFrame(np.array(results[\"data\"]).transpose(), columns = results[\"names\"].asstr())\n",
|
||
"\n",
|
||
"data_file.close()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Classify each cell with kmeans"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 50,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"/home/signer/bin/miniconda3/envs/training/lib/python3.11/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (2). Possibly due to duplicate points in X.\n",
|
||
" return fit_method(estimator, *args, **kwargs)\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# widget with slider for the index\n",
|
||
"\n",
|
||
"class_label = np.array([])\n",
|
||
"i = 1000\n",
|
||
"for i in range(0,1001):\n",
|
||
" field = np.array(df_design['Barite'][(i*2500):(i*2500+2500)]).reshape(50,50)\n",
|
||
" kmeans = KMeans(n_clusters=2, random_state=0).fit(field.reshape(-1,1))\n",
|
||
" class_label = np.append(class_label.astype(int), kmeans.labels_)\n",
|
||
"\n",
|
||
"\n",
|
||
"class_label = pd.DataFrame(class_label, columns = [\"Class\"])\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 30,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"<matplotlib.image.AxesImage at 0x7b603017f510>"
|
||
]
|
||
},
|
||
"execution_count": 30,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGfCAYAAAD22G0fAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAFkFJREFUeJzt3V9s1uXd+PHPjcAtaNvJnC0NzAdjHzfHg78IjkBUmEp/Icbo48kyjGHxRAUMDQcoeiDbQYuYEF2YLLrFmSw+7GD+O5iGJmrZQkwKQiSY+MsShk2k61ywrYhF8Pod+HjPDoYUip8Cr1fyPbiv7/e+e3mJ99ur9x8qpZQSAJBgXPYEADh/iRAAaUQIgDQiBEAaEQIgjQgBkEaEAEgjQgCkESEA0ogQAGnGn6kHfuqpp+Lxxx+P/fv3xw9+8IN44okn4oYbbvja+33++efxwQcfRF1dXVQqlTM1PQDOkFJKDA4ORnNzc4wb9zV7nXIGbN68uUyYMKE888wz5d133y0rV64sF110Udm3b9/X3renp6dEhMPhcDjO8qOnp+drn/MrpYz+F5jOnTs3rr322ti0aVNt7Pvf/37ccccd0dHRccL79vf3x7e+9a3Y9/Z/RP3F5/5vC//7P/8rewoAo+pIfBZ/jj/GRx99FA0NDSe8dtR/HXf48OHYsWNHPPTQQ8PGW1tbY9u2bcdcPzQ0FENDQ7Xbg4ODERFRf/G4qK879yM0vjIhewoAo+t/tzYn85LKqD/Lf/jhh3H06NFobGwcNt7Y2Bi9vb3HXN/R0RENDQ21Y/r06aM9JQDGqDO21fjXApZSjlvFNWvWRH9/f+3o6ek5U1MCYIwZ9V/HXXrppXHBBRccs+vp6+s7ZncUEVGtVqNarY72NAA4C4z6TmjixIkxe/bs6OzsHDbe2dkZ8+fPH+0fB8BZ7Ix8TmjVqlVx9913x5w5c2LevHnx9NNPx/vvvx/33XffmfhxAJylzkiEfvzjH8c//vGP+PnPfx779++PmTNnxh//+Me4/PLLz8SPA+AsdUY+J3Q6BgYGoqGhIQ78vyvOi7do/9/m/5M9BYBRdaR8Fm/Gy9Hf3x/19fUnvPbcf5YHYMwSIQDSiBAAaUQIgDQiBEAaEQIgjQgBkEaEAEgjQgCkESEA0ogQAGnOyBeYMpzvhwM4PjshANKIEABpRAiANCIEQBoRAiCNCAGQRoQASONzQqPA54AATo2dEABpRAiANCIEQBoRAiCNCAGQRoQASCNCAKQRIQDSiBAAaUQIgDQiBEAaEQIgjQgBkMa3aH+Fb8MG+GbZCQGQRoQASCNCAKQRIQDSiBAAaUQIgDQiBECa8+pzQj4HBDC22AkBkEaEAEgjQgCkESEA0ogQAGlECIA0Y/Yt2v/9n/8V4ysTsqcBwBlkJwRAGhECII0IAZBGhABII0IApBEhANKIEABpRAiANCIEQBoRAiCNCAGQRoQASCNCAKQRIQDSjDhCW7dujdtuuy2am5ujUqnESy+9NOx8KSXWrl0bzc3NMWnSpFi4cGHs2bNntOYLwDlkxBE6ePBgXHPNNbFx48bjnl+/fn1s2LAhNm7cGN3d3dHU1BSLFi2KwcHB054sAOeWEf+ldosXL47Fixcf91wpJZ544ol45JFH4s4774yIiOeeey4aGxvj+eefj3vvvff0ZgvAOWVUXxPau3dv9Pb2Rmtra22sWq3GggULYtu2bce9z9DQUAwMDAw7ADg/jGqEent7IyKisbFx2HhjY2Pt3L/q6OiIhoaG2jF9+vTRnBIAY9gZeXdcpVIZdruUcszYl9asWRP9/f21o6en50xMCYAxaMSvCZ1IU1NTRHyxI5o6dWptvK+v75jd0Zeq1WpUq9XRnAYAZ4lR3QnNmDEjmpqaorOzszZ2+PDh6Orqivnz54/mjwLgHDDindDHH38cf/nLX2q39+7dG7t27YopU6bEd7/73Whra4v29vZoaWmJlpaWaG
9vj8mTJ8eSJUtGdeIAnP1GHKHt27fHj370o9rtVatWRUTE0qVL47e//W2sXr06Dh06FMuWLYsDBw7E3LlzY8uWLVFXVzd6swbgnFAppZTsSXzVwMBANDQ0xMK4PcZXJmRPB4AROlI+izfj5ejv74/6+voTXuu74wBII0IApBEhANKIEABpRAiANCIEQBoRAiCNCAGQRoQASCNCAKQRIQDSiBAAaUQIgDQiBEAaEQIgjQgBkEaEAEgjQgCkESEA0ogQAGlECIA0IgRAGhECII0IAZBGhABII0IApBEhANKIEABpRAiANCIEQBoRAiCNCAGQRoQASCNCAKQRIQDSiBAAaUQIgDQiBEAaEQIgjQgBkEaEAEgjQgCkESEA0ogQAGlECIA0IgRAGhECII0IAZBGhABII0IApBEhANKIEABpRAiANCIEQBoRAiCNCAGQRoQASCNCAKQRIQDSiBAAaUQIgDQiBEAaEQIgzYgi1NHREdddd13U1dXFZZddFnfccUe89957w64ppcTatWujubk5Jk2aFAsXLow9e/aM6qQBODeMKEJdXV2xfPnyeOutt6KzszOOHDkSra2tcfDgwdo169evjw0bNsTGjRuju7s7mpqaYtGiRTE4ODjqkwfg7FYppZRTvfPf//73uOyyy6KrqytuvPHGKKVEc3NztLW1xYMPPhgREUNDQ9HY2BiPPfZY3HvvvV/7mAMDA9HQ0BAL4/YYX5lwqlMDIMmR8lm8GS9Hf39/1NfXn/Da03pNqL+/PyIipkyZEhERe/fujd7e3mhtba1dU61WY8GCBbFt27bjPsbQ0FAMDAwMOwA4P5xyhEopsWrVqrj++utj5syZERHR29sbERGNjY3Drm1sbKyd+1cdHR3R0NBQO6ZPn36qUwLgLHPKEVqxYkW888478T//8z/HnKtUKsNul1KOGfvSmjVror+/v3b09PSc6pQAOMuMP5U7PfDAA/HKK6/E1q1bY9q0abXxpqamiPhiRzR16tTaeF9f3zG7oy9Vq9WoVqunMg0AznIj2gmVUmLFihXxwgsvxOuvvx4zZswYdn7GjBnR1NQUnZ2dtbHDhw9HV1dXzJ8/f3RmDMA5Y0Q7oeXLl8fzzz8fL7/8ctTV1dVe52loaIhJkyZFpVKJtra2aG9vj5aWlmhpaYn29vaYPHlyLFmy5Iz8AwBw9hpRhDZt2hQREQsXLhw2/uyzz8ZPf/rTiIhYvXp1HDp0KJYtWxYHDhyIuXPnxpYtW6Kurm5UJgzAueO0Pid0JvicEMDZ7Rv7nBAAnA4RAiCNCAGQRoQASCNCAKQRIQDSiBAAaUQIgDQiBEAaEQIgjQgBkEaEAEgjQgCkESEA0ogQAGlECIA0IgRAGhECII0IAZBGhABII0IApBEhANKIEABpRAiANCIEQBoRAiCNCAGQRoQASCNCAKQRIQDSiBAAaUQIgDQiBEAaEQIgjQgBkEaEAEgjQgCkESEA0ogQAGlECIA0IgRAGhECII0IAZBGhABII0IApBEhANKIEABpRAiANCIEQBoRAiCNCAGQRoQASCNCAKQRIQDSiBAAaUQIgDQiBEAaEQIgjQgBkEaEAEgjQgCkESEA0ogQAGlGFKFNmzbFrFmzor6+Purr62PevHnx6quv1s6XUmLt2rXR3NwckyZNioULF8aePXtGfdIAnBtGFKFp06bFunXrYvv27bF9+/a46aab4vbbb6+FZv369bFhw4bYuHFjdHd3R1NTUyxatCgGBwfPyOQBOLtVSinldB5gypQp8fjjj8c999wTzc3N0dbWFg8++GBERAwNDUVjY2M89thjce+9957U4w0MDERDQ0MsjNtjfGXC6UwNgARHymfxZrwc/f39UV9ff8JrT/k1oaNHj8bmzZvj4MGDMW/evNi7d2/09vZGa2tr7ZpqtRoLFiyIbdu2/dvHGRoaioGBgWEHAOeHEUdo9+7dcfHFF0e1Wo377rsvXnzxxbj66qujt7c3IiIaGx
uHXd/Y2Fg7dzwdHR3R0NBQO6ZPnz7SKQFwlhpxhK666qrYtWtXvPXWW3H//ffH0qVL4913362dr1Qqw64vpRwz9lVr1qyJ/v7+2tHT0zPSKQFwlho/0jtMnDgxrrzyyoiImDNnTnR3d8eTTz5Zex2ot7c3pk6dWru+r6/vmN3RV1Wr1ahWqyOdBgDngNP+nFApJYaGhmLGjBnR1NQUnZ2dtXOHDx+Orq6umD9//un+GADOQSPaCT388MOxePHimD59egwODsbmzZvjzTffjNdeey0qlUq0tbVFe3t7tLS0REtLS7S3t8fkyZNjyZIlZ2r+AJzFRhShv/3tb3H33XfH/v37o6GhIWbNmhWvvfZaLFq0KCIiVq9eHYcOHYply5bFgQMHYu7cubFly5aoq6s7I5MH4Ox22p8TGm0+JwRwdvtGPicEAKdLhABII0IApBEhANKIEABpRAiANCIEQBoRAiCNCAGQRoQASCNCAKQRIQDSiBAAaUQIgDQiBEAaEQIgjQgBkEaEAEgjQgCkESEA0ogQAGlECIA0IgRAGhECII0IAZBGhABII0IApBEhANKIEABpRAiANCIEQBoRAiCNCAGQRoQASCNCAKQRIQDSiBAAaUQIgDQiBEAaEQIgjQgBkEaEAEgjQgCkESEA0ogQAGlECIA0IgRAGhECII0IAZBGhABII0IApBEhANKIEABpRAiANCIEQBoRAiCNCAGQRoQASCNCAKQRIQDSiBAAaUQIgDQiBECa04pQR0dHVCqVaGtrq42VUmLt2rXR3NwckyZNioULF8aePXtOd54AnINOOULd3d3x9NNPx6xZs4aNr1+/PjZs2BAbN26M7u7uaGpqikWLFsXg4OBpTxaAc8spRejjjz+Ou+66K5555pm45JJLauOllHjiiSfikUceiTvvvDNmzpwZzz33XHzyySfx/PPPj9qkATg3nFKEli9fHrfeemvccsstw8b37t0bvb290draWhurVquxYMGC2LZt23Efa2hoKAYGBoYdAJwfxo/0Dps3b4633347uru7jznX29sbERGNjY3DxhsbG2Pfvn3HfbyOjo742c9+NtJpAHAOGNFOqKenJ1auXBm/+93v4sILL/y311UqlWG3SynHjH1pzZo10d/fXzt6enpGMiUAzmIj2gnt2LEj+vr6Yvbs2bWxo0ePxtatW2Pjxo3x3nvvRcQXO6KpU6fWrunr6ztmd/SlarUa1Wr1VOYOwFluRDuhm2++OXbv3h27du2qHXPmzIm77rordu3aFVdccUU0NTVFZ2dn7T6HDx+Orq6umD9//qhPHoCz24h2QnV1dTFz5sxhYxdddFF8+9vfro23tbVFe3t7tLS0REtLS7S3t8fkyZNjyZIlozdrAM4JI35jwtdZvXp1HDp0KJYtWxYHDhyIuXPnxpYtW6Kurm60fxQAZ7lKKaVkT+KrBgYGoqGhIRbG7TG+MiF7OgCM0JHyWbwZL0d/f3/U19ef8FrfHQdAGhECII0IAZBGhABII0IApBEhANKIEABpRAiANCIEQBoRAiCNCAGQRoQASCNCAKQRIQDSiBAAaUQIgDQiBEAaEQIgjQgBkEaEAEgjQgCkESEA0ogQAGlECIA0IgRAGhECII0IAZBGhABII0IApBEhANKIEABpRAiANCIEQBoRAiCNCAGQRoQASCNCAKQRIQDSiBAAaUQIgDQiBEAaEQIgjQgBkEaEAEgjQgCkESEA0ogQAGlECIA0IgRAGhECII0IAZBGhABII0IApBEhANKIEABpRAiANCIEQBoRAiCNCAGQRoQASCNCAKQRIQDSjM+ewL8qpURExJH4LKIkTwaAETsSn0XEP5/PT2TMRWhwcDAiIv4cf0yeCQCnY3BwMBoaGk54TaWcTKq+QZ9//nl88MEHUVdXF5VKJQYGBmL69OnR09MT9fX12dMbs6zTybFOJ8c6nRzrdHyllBgcHIzm5uYYN+7Er/qMuZ3QuHHjYtq0aceM19
fX+5d8EqzTybFOJ8c6nRzrdKyv2wF9yRsTAEgjQgCkGfMRqlar8eijj0a1Ws2eyphmnU6OdTo51unkWKfTN+bemADA+WPM74QAOHeJEABpRAiANCIEQJoxH6GnnnoqZsyYERdeeGHMnj07/vSnP2VPKdXWrVvjtttui+bm5qhUKvHSSy8NO19KibVr10Zzc3NMmjQpFi5cGHv27MmZbJKOjo647rrroq6uLi677LK444474r333ht2jXWK2LRpU8yaNav2Qct58+bFq6++WjtvjY6vo6MjKpVKtLW11cas1akb0xH6/e9/H21tbfHII4/Ezp0744YbbojFixfH+++/nz21NAcPHoxrrrkmNm7ceNzz69evjw0bNsTGjRuju7s7mpqaYtGiRbXv5DsfdHV1xfLly+Ott96Kzs7OOHLkSLS2tsbBgwdr11iniGnTpsW6deti+/btsX379rjpppvi9ttvrz15WqNjdXd3x9NPPx2zZs0aNm6tTkMZw374wx+W++67b9jY9773vfLQQw8lzWhsiYjy4osv1m5//vnnpampqaxbt6429umnn5aGhobyq1/9KmGGY0NfX1+JiNLV1VVKsU4ncskll5Rf//rX1ug4BgcHS0tLS+ns7CwLFiwoK1euLKX483S6xuxO6PDhw7Fjx45obW0dNt7a2hrbtm1LmtXYtnfv3ujt7R22ZtVqNRYsWHBer1l/f39EREyZMiUirNPxHD16NDZv3hwHDx6MefPmWaPjWL58edx6661xyy23DBu3VqdnzH2B6Zc+/PDDOHr0aDQ2Ng4bb2xsjN7e3qRZjW1frsvx1mzfvn0ZU0pXSolVq1bF9ddfHzNnzowI6/RVu3fvjnnz5sWnn34aF198cbz44otx9dVX1548rdEXNm/eHG+//XZ0d3cfc86fp9MzZiP0pUqlMux2KeWYMYazZv+0YsWKeOedd+LPf/7zMeesU8RVV10Vu3btio8++ij+8Ic/xNKlS6Orq6t23hpF9PT0xMqVK2PLli1x4YUX/tvrrNWpGbO/jrv00kvjggsuOGbX09fXd8z/cfCFpqamiAhr9r8eeOCBeOWVV+KNN94Y9teDWKd/mjhxYlx55ZUxZ86c6OjoiGuuuSaefPJJa/QVO3bsiL6+vpg9e3aMHz8+xo8fH11dXfGLX/wixo8fX1sPa3VqxmyEJk6cGLNnz47Ozs5h452dnTF//vykWY1tM2bMiKampmFrdvjw4ejq6jqv1qyUEitWrIgXXnghXn/99ZgxY8aw89bp3yulxNDQkDX6iptvvjl2794du3btqh1z5syJu+66K3bt2hVXXHGFtTodee+J+HqbN28uEyZMKL/5zW/Ku+++W9ra2spFF11U/vrXv2ZPLc3g4GDZuXNn2blzZ4mIsmHDhrJz586yb9++Ukop69atKw0NDeWFF14ou3fvLj/5yU/K1KlTy8DAQPLMvzn3339/aWhoKG+++WbZv39/7fjkk09q11inUtasWVO2bt1a9u7dW955553y8MMPl3HjxpUtW7aUUqzRiXz13XGlWKvTMaYjVEopv/zlL8vll19eJk6cWK699tra22zPV2+88UaJiGOOpUuXllK+eLvoo48+Wpqamkq1Wi033nhj2b17d+6kv2HHW5+IKM8++2ztGutUyj333FP7b+s73/lOufnmm2sBKsUanci/RshanTp/lQMAacbsa0IAnPtECIA0IgRAGhECII0IAZBGhABII0IApBEhANKIEABpRAiANCIEQBoRAiDN/wf760AGi+dbEwAAAABJRU5ErkJggg==",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"i = 1000\n",
|
||
"plt.imshow(class_label[(i*2500):(i*2500+2500)].reshape(50,50)) "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 52,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Class column already exists\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"if(\"Class\" in df_design.columns):\n",
|
||
" print(\"Class column already exists\")\n",
|
||
"else:\n",
|
||
" df_design = pd.concat([df_design, class_label], axis=1)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 53,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>H</th>\n",
|
||
" <th>O</th>\n",
|
||
" <th>Charge</th>\n",
|
||
" <th>H_0_</th>\n",
|
||
" <th>O_0_</th>\n",
|
||
" <th>Ba</th>\n",
|
||
" <th>Cl</th>\n",
|
||
" <th>S_2_</th>\n",
|
||
" <th>S_6_</th>\n",
|
||
" <th>Sr</th>\n",
|
||
" <th>Barite</th>\n",
|
||
" <th>Celestite</th>\n",
|
||
" <th>Class</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>0</th>\n",
|
||
" <td>111.012434</td>\n",
|
||
" <td>55.508192</td>\n",
|
||
" <td>-7.779554e-09</td>\n",
|
||
" <td>2.697041e-26</td>\n",
|
||
" <td>2.210590e-15</td>\n",
|
||
" <td>2.041069e-02</td>\n",
|
||
" <td>4.082138e-02</td>\n",
|
||
" <td>0.000000e+00</td>\n",
|
||
" <td>0.000494</td>\n",
|
||
" <td>0.000494</td>\n",
|
||
" <td>0.001</td>\n",
|
||
" <td>1.000000</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1</th>\n",
|
||
" <td>111.012434</td>\n",
|
||
" <td>55.508427</td>\n",
|
||
" <td>-4.736083e-09</td>\n",
|
||
" <td>1.446346e-26</td>\n",
|
||
" <td>2.473481e-15</td>\n",
|
||
" <td>1.094567e-02</td>\n",
|
||
" <td>2.189133e-02</td>\n",
|
||
" <td>0.000000e+00</td>\n",
|
||
" <td>0.000553</td>\n",
|
||
" <td>0.000553</td>\n",
|
||
" <td>0.001</td>\n",
|
||
" <td>1.000000</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2</th>\n",
|
||
" <td>111.012434</td>\n",
|
||
" <td>55.508691</td>\n",
|
||
" <td>-1.311169e-09</td>\n",
|
||
" <td>3.889826e-28</td>\n",
|
||
" <td>2.769320e-15</td>\n",
|
||
" <td>2.943745e-04</td>\n",
|
||
" <td>5.887491e-04</td>\n",
|
||
" <td>0.000000e+00</td>\n",
|
||
" <td>0.000619</td>\n",
|
||
" <td>0.000619</td>\n",
|
||
" <td>0.001</td>\n",
|
||
" <td>1.000000</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>3</th>\n",
|
||
" <td>111.012434</td>\n",
|
||
" <td>55.508698</td>\n",
|
||
" <td>-1.220023e-09</td>\n",
|
||
" <td>1.442658e-29</td>\n",
|
||
" <td>2.777193e-15</td>\n",
|
||
" <td>1.091776e-05</td>\n",
|
||
" <td>2.183551e-05</td>\n",
|
||
" <td>0.000000e+00</td>\n",
|
||
" <td>0.000620</td>\n",
|
||
" <td>0.000620</td>\n",
|
||
" <td>0.001</td>\n",
|
||
" <td>1.000000</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>4</th>\n",
|
||
" <td>111.012434</td>\n",
|
||
" <td>55.508699</td>\n",
|
||
" <td>-1.216643e-09</td>\n",
|
||
" <td>5.350528e-31</td>\n",
|
||
" <td>2.777485e-15</td>\n",
|
||
" <td>4.049176e-07</td>\n",
|
||
" <td>8.098352e-07</td>\n",
|
||
" <td>0.000000e+00</td>\n",
|
||
" <td>0.000620</td>\n",
|
||
" <td>0.000620</td>\n",
|
||
" <td>0.001</td>\n",
|
||
" <td>1.000000</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2502495</th>\n",
|
||
" <td>111.012434</td>\n",
|
||
" <td>55.507488</td>\n",
|
||
" <td>3.573728e-09</td>\n",
|
||
" <td>5.424062e-145</td>\n",
|
||
" <td>1.375204e-10</td>\n",
|
||
" <td>9.953520e-07</td>\n",
|
||
" <td>2.266555e-03</td>\n",
|
||
" <td>5.509534e-149</td>\n",
|
||
" <td>0.000318</td>\n",
|
||
" <td>0.001450</td>\n",
|
||
" <td>0.001</td>\n",
|
||
" <td>1.000014</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2502496</th>\n",
|
||
" <td>111.012434</td>\n",
|
||
" <td>55.507501</td>\n",
|
||
" <td>3.494007e-09</td>\n",
|
||
" <td>2.011675e-146</td>\n",
|
||
" <td>1.377139e-10</td>\n",
|
||
" <td>9.817216e-07</td>\n",
|
||
" <td>2.217997e-03</td>\n",
|
||
" <td>2.043375e-150</td>\n",
|
||
" <td>0.000321</td>\n",
|
||
" <td>0.001429</td>\n",
|
||
" <td>0.001</td>\n",
|
||
" <td>1.000010</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2502497</th>\n",
|
||
" <td>111.012434</td>\n",
|
||
" <td>55.507512</td>\n",
|
||
" <td>3.429764e-09</td>\n",
|
||
" <td>7.460897e-148</td>\n",
|
||
" <td>1.377819e-10</td>\n",
|
||
" <td>9.706451e-07</td>\n",
|
||
" <td>2.179066e-03</td>\n",
|
||
" <td>7.578467e-152</td>\n",
|
||
" <td>0.000324</td>\n",
|
||
" <td>0.001412</td>\n",
|
||
" <td>0.001</td>\n",
|
||
" <td>1.000006</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2502498</th>\n",
|
||
" <td>111.012434</td>\n",
|
||
" <td>55.507520</td>\n",
|
||
" <td>3.381745e-09</td>\n",
|
||
" <td>2.767237e-149</td>\n",
|
||
" <td>1.371144e-10</td>\n",
|
||
" <td>9.621074e-07</td>\n",
|
||
" <td>2.149820e-03</td>\n",
|
||
" <td>2.810844e-153</td>\n",
|
||
" <td>0.000326</td>\n",
|
||
" <td>0.001400</td>\n",
|
||
" <td>0.001</td>\n",
|
||
" <td>1.000004</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2502499</th>\n",
|
||
" <td>111.012434</td>\n",
|
||
" <td>55.507525</td>\n",
|
||
" <td>3.348864e-09</td>\n",
|
||
" <td>5.321610e-151</td>\n",
|
||
" <td>1.376026e-10</td>\n",
|
||
" <td>9.564401e-07</td>\n",
|
||
" <td>2.129912e-03</td>\n",
|
||
" <td>5.405468e-155</td>\n",
|
||
" <td>0.000327</td>\n",
|
||
" <td>0.001391</td>\n",
|
||
" <td>0.001</td>\n",
|
||
" <td>1.000001</td>\n",
|
||
" <td>0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>2502500 rows × 13 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" H O Charge H_0_ O_0_ \\\n",
|
||
"0 111.012434 55.508192 -7.779554e-09 2.697041e-26 2.210590e-15 \n",
|
||
"1 111.012434 55.508427 -4.736083e-09 1.446346e-26 2.473481e-15 \n",
|
||
"2 111.012434 55.508691 -1.311169e-09 3.889826e-28 2.769320e-15 \n",
|
||
"3 111.012434 55.508698 -1.220023e-09 1.442658e-29 2.777193e-15 \n",
|
||
"4 111.012434 55.508699 -1.216643e-09 5.350528e-31 2.777485e-15 \n",
|
||
"... ... ... ... ... ... \n",
|
||
"2502495 111.012434 55.507488 3.573728e-09 5.424062e-145 1.375204e-10 \n",
|
||
"2502496 111.012434 55.507501 3.494007e-09 2.011675e-146 1.377139e-10 \n",
|
||
"2502497 111.012434 55.507512 3.429764e-09 7.460897e-148 1.377819e-10 \n",
|
||
"2502498 111.012434 55.507520 3.381745e-09 2.767237e-149 1.371144e-10 \n",
|
||
"2502499 111.012434 55.507525 3.348864e-09 5.321610e-151 1.376026e-10 \n",
|
||
"\n",
|
||
" Ba Cl S_2_ S_6_ Sr \\\n",
|
||
"0 2.041069e-02 4.082138e-02 0.000000e+00 0.000494 0.000494 \n",
|
||
"1 1.094567e-02 2.189133e-02 0.000000e+00 0.000553 0.000553 \n",
|
||
"2 2.943745e-04 5.887491e-04 0.000000e+00 0.000619 0.000619 \n",
|
||
"3 1.091776e-05 2.183551e-05 0.000000e+00 0.000620 0.000620 \n",
|
||
"4 4.049176e-07 8.098352e-07 0.000000e+00 0.000620 0.000620 \n",
|
||
"... ... ... ... ... ... \n",
|
||
"2502495 9.953520e-07 2.266555e-03 5.509534e-149 0.000318 0.001450 \n",
|
||
"2502496 9.817216e-07 2.217997e-03 2.043375e-150 0.000321 0.001429 \n",
|
||
"2502497 9.706451e-07 2.179066e-03 7.578467e-152 0.000324 0.001412 \n",
|
||
"2502498 9.621074e-07 2.149820e-03 2.810844e-153 0.000326 0.001400 \n",
|
||
"2502499 9.564401e-07 2.129912e-03 5.405468e-155 0.000327 0.001391 \n",
|
||
"\n",
|
||
" Barite Celestite Class \n",
|
||
"0 0.001 1.000000 0 \n",
|
||
"1 0.001 1.000000 0 \n",
|
||
"2 0.001 1.000000 0 \n",
|
||
"3 0.001 1.000000 0 \n",
|
||
"4 0.001 1.000000 0 \n",
|
||
"... ... ... ... \n",
|
||
"2502495 0.001 1.000014 0 \n",
|
||
"2502496 0.001 1.000010 0 \n",
|
||
"2502497 0.001 1.000006 0 \n",
|
||
"2502498 0.001 1.000004 0 \n",
|
||
"2502499 0.001 1.000001 0 \n",
|
||
"\n",
|
||
"[2502500 rows x 13 columns]"
|
||
]
|
||
},
|
||
"execution_count": 53,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"df_design"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Define Scaling and Normalization Functions"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 58,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def log_scale(df_design, df_result, func_dict):\n",
|
||
" \n",
|
||
" df_design = df_design.copy()\n",
|
||
" df_result = df_result.copy()\n",
|
||
" \n",
|
||
" for key in df_design.keys():\n",
|
||
" if key != \"Class\":\n",
|
||
" df_design[key] = np.vectorize(func_dict[key])(df_design[key])\n",
|
||
" df_result[key] = np.vectorize(func_dict[key])(df_result[key])\n",
|
||
" \n",
|
||
" return df_result, df_design\n",
|
||
"\n",
|
||
"# Get minimum and maximum values for each column\n",
|
||
"def get_min_max(df_design, df_result):\n",
|
||
" \n",
|
||
" min_vals_des = df_design.min()\n",
|
||
" max_vals_des = df_design.max()\n",
|
||
" \n",
|
||
" min_vals_res = df_result.min()\n",
|
||
" max_vals_res = df_result.max()\n",
|
||
"\n",
|
||
" # minimum of input and output data to get global minimum/maximum\n",
|
||
" data_min = np.minimum(min_vals_des, min_vals_res).to_dict()\n",
|
||
" data_max = np.maximum(max_vals_des, max_vals_res).to_dict()\n",
|
||
"\n",
|
||
" return data_min, data_max\n",
|
||
"\n",
|
||
"\n",
|
||
"df_design_log, df_results_log = log_scale(df_design, df_results, func_dict_in)\n",
|
||
"data_min_log, data_max_log = get_min_max(df_design_log, df_results_log)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 61,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def preprocess(data, func_dict, data_min, data_max):\n",
|
||
" data = data.copy()\n",
|
||
" for key in data.keys():\n",
|
||
" if key != \"Class\":\n",
|
||
" data[key] = (data[key] - data_min[key]) / (data_max[key] - data_min[key])\n",
|
||
"\n",
|
||
" return data\n",
|
||
"\n",
|
||
"def postprocess(data, func_dict, data_min, data_max):\n",
|
||
" data = data.copy()\n",
|
||
" for key in data.keys():\n",
|
||
" if key != \"Class\":\n",
|
||
" data[key] = data[key] * (data_max[key] - data_min[key]) + data_min[key]\n",
|
||
" data[key] = np.vectorize(func_dict[key])(data[key])\n",
|
||
" return data"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Preprocess the data"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 62,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"pp_design = preprocess(df_design_log, func_dict_in, data_min_log, data_max_log)\n",
|
||
"pp_results = preprocess(df_results_log, func_dict_in, data_min_log, data_max_log)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Split the data into training, validation and test sets"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"X_train, X_test, y_train, y_test = sk.train_test_split(pp_design, pp_results, test_size = 0.2)\n",
|
||
"X_train, X_val, y_train, y_val = sk.train_test_split(X_train, y_train, test_size = 0.1)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Custom Loss Function"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Experimental loss: Huber loss plus a penalty on the H/O ratio of the\n",
|
||
"# back-transformed result data.\n",
|
||
"# NOTE(review): as written this cannot work as a Keras loss -- TODO fix:\n",
|
||
"#   * `keras.losses.Huber` is the class itself; it is neither instantiated nor\n",
|
||
"#     called on (y_true, y_pred), so `Huber + <number>` presumably raises TypeError.\n",
|
||
"#   * Keras passes tensors, but `postprocess` iterates `.keys()` / uses\n",
|
||
"#     `np.vectorize`, and `df_result['H']` indexes by column name, which\n",
|
||
"#     presumably requires DataFrame-like input; confirm and port to tensor ops.\n",
|
||
"def custom_loss_H20(df_design_log, df_result_log, data_min_log, data_max_log, func_dict_out, postprocess):\n",
|
||
" # Back-transform the second argument via postprocess (df_design_log is unused).\n",
|
||
" df_result = postprocess(df_result_log, func_dict_out, data_min_log, data_max_log) \n",
|
||
" # Penalty: squared deviation of H/O from 2, summed over rows -- presumably the\n",
|
||
" # 2:1 hydrogen/oxygen ratio of water ('H20' in the name likely means H2O).\n",
|
||
" return keras.losses.Huber + np.sum(((df_result['H'] / df_result['O']) - 2)**2)\n",
|
||
"\n",
|
||
"# Closure binding the scaling bounds and output transforms, so the returned\n",
|
||
"# `loss(a, b)` has the two-argument form Keras uses for losses.\n",
|
||
"# NOTE(review): Keras calls losses as loss(y_true, y_pred); the penalty is thus\n",
|
||
"# computed on y_pred, and y_true (named df_design_log here) is ignored. Confirm.\n",
|
||
"def loss_wrapper(data_min_log, data_max_log, func_dict_out, postprocess):\n",
|
||
" def loss(df_design_log, df_result_log):\n",
|
||
" return custom_loss_H20(df_design_log, df_result_log, data_min_log, data_max_log, func_dict_out, postprocess)\n",
|
||
" return loss"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Train the model"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 1/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 0.0015 - val_loss: 1.2993e-06\n",
|
||
"Epoch 2/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 1.3182e-06 - val_loss: 1.1714e-06\n",
|
||
"Epoch 3/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 1.4322e-06 - val_loss: 1.4424e-06\n",
|
||
"Epoch 4/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 1.1811e-06 - val_loss: 1.1027e-06\n",
|
||
"Epoch 5/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 1.0509e-06 - val_loss: 1.1202e-06\n",
|
||
"Epoch 6/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 7.9101e-07 - val_loss: 1.0344e-06\n",
|
||
"Epoch 7/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 8.5978e-07 - val_loss: 1.0202e-06\n",
|
||
"Epoch 8/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 7.6363e-07 - val_loss: 1.5508e-06\n",
|
||
"Epoch 9/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 8.2612e-07 - val_loss: 1.0281e-06\n",
|
||
"Epoch 10/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 7.8237e-07 - val_loss: 9.6918e-07\n",
|
||
"Epoch 11/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 7.8727e-07 - val_loss: 9.8902e-07\n",
|
||
"Epoch 12/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 7.2731e-07 - val_loss: 9.4628e-07\n",
|
||
"Epoch 13/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 6.2018e-07 - val_loss: 1.0144e-06\n",
|
||
"Epoch 14/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 6.0086e-07 - val_loss: 9.9860e-07\n",
|
||
"Epoch 15/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 6.6483e-07 - val_loss: 9.5001e-07\n",
|
||
"Epoch 16/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 6.8847e-07 - val_loss: 9.4421e-07\n",
|
||
"Epoch 17/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 6.6030e-07 - val_loss: 9.3255e-07\n",
|
||
"Epoch 18/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 6.4765e-07 - val_loss: 9.2782e-07\n",
|
||
"Epoch 19/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 7.0107e-07 - val_loss: 9.2918e-07\n",
|
||
"Epoch 20/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 5.7916e-07 - val_loss: 9.3070e-07\n",
|
||
"Epoch 21/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 7.1965e-07 - val_loss: 9.3583e-07\n",
|
||
"Epoch 22/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 6.1729e-07 - val_loss: 9.2800e-07\n",
|
||
"Epoch 23/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 5.8376e-07 - val_loss: 9.2606e-07\n",
|
||
"Epoch 24/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 7.1949e-07 - val_loss: 9.2550e-07\n",
|
||
"Epoch 25/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 7.0228e-07 - val_loss: 9.2386e-07\n",
|
||
"Epoch 26/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 6.4762e-07 - val_loss: 9.2222e-07\n",
|
||
"Epoch 27/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 6.3545e-07 - val_loss: 9.2336e-07\n",
|
||
"Epoch 28/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 6.1678e-07 - val_loss: 9.2510e-07\n",
|
||
"Epoch 29/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 7.2552e-07 - val_loss: 9.2267e-07\n",
|
||
"Epoch 30/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 6.7044e-07 - val_loss: 9.2244e-07\n",
|
||
"Epoch 31/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 6.4412e-07 - val_loss: 9.2193e-07\n",
|
||
"Epoch 32/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 7.9198e-07 - val_loss: 9.2181e-07\n",
|
||
"Epoch 33/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 8.8825e-07 - val_loss: 9.2173e-07\n",
|
||
"Epoch 34/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 6.1502e-07 - val_loss: 9.2309e-07\n",
|
||
"Epoch 35/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 6.5551e-07 - val_loss: 9.2157e-07\n",
|
||
"Epoch 36/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 6.3050e-07 - val_loss: 9.2172e-07\n",
|
||
"Epoch 37/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 6.8292e-07 - val_loss: 9.2127e-07\n",
|
||
"Epoch 38/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 5.7185e-07 - val_loss: 9.2111e-07\n",
|
||
"Epoch 39/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 6.1807e-07 - val_loss: 9.2119e-07\n",
|
||
"Epoch 40/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 5.7785e-07 - val_loss: 9.2112e-07\n",
|
||
"Epoch 41/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 6.6563e-07 - val_loss: 9.2108e-07\n",
|
||
"Epoch 42/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 7.1370e-07 - val_loss: 9.2109e-07\n",
|
||
"Epoch 43/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 7.2470e-07 - val_loss: 9.2105e-07\n",
|
||
"Epoch 44/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 7.2408e-07 - val_loss: 9.2102e-07\n",
|
||
"Epoch 45/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 2ms/step - loss: 6.6530e-07 - val_loss: 9.2098e-07\n",
|
||
"Epoch 46/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 6.7502e-07 - val_loss: 9.2098e-07\n",
|
||
"Epoch 47/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 6.3829e-07 - val_loss: 9.2094e-07\n",
|
||
"Epoch 48/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 5.8739e-07 - val_loss: 9.2096e-07\n",
|
||
"Epoch 49/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 7.0502e-07 - val_loss: 9.2095e-07\n",
|
||
"Epoch 50/50\n",
|
||
"\u001b[1m3520/3520\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 2ms/step - loss: 6.5994e-07 - val_loss: 9.2094e-07\n",
|
||
"Training took 317.1207675933838 seconds\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# measure time\n",
|
||
"start = time.time()\n",
|
||
"\n",
|
||
"history = model.fit(X_train, \n",
|
||
" y_train, \n",
|
||
" batch_size = batch_size, \n",
|
||
" epochs = epochs, \n",
|
||
" validation_data = (X_val, y_val)\n",
|
||
")\n",
|
||
"\n",
|
||
"end = time.time()\n",
|
||
"\n",
|
||
"print(\"Training took {} seconds\".format(end - start))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"plt.plot(history.history[\"loss\"], \"o-\", label = \"Training Loss\")\n",
|
||
"plt.xlabel(\"Epoch\")\n",
|
||
"# plt.yscale('log')\n",
|
||
"plt.ylabel(\"Loss (Huber)\")\n",
|
||
"plt.grid('on')\n",
|
||
"\n",
|
||
"plt.savefig(\"loss_all.png\", dpi=300)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"plt.plot(history.history[\"loss\"][1:], \"o-\", label = \"Training Loss\")\n",
|
||
"plt.xlabel(\"Epoch\")\n",
|
||
"# plt.yscale('log')\n",
|
||
"plt.ylabel(\"Loss (Huber)\")\n",
|
||
"plt.grid('on')\n",
|
||
"plt.savefig(\"loss_1_to_end.png\", dpi=300)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Test the model"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 20,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\u001b[1m15641/15641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 337us/step - loss: 7.0854e-07\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"6.561523377968115e-07"
|
||
]
|
||
},
|
||
"execution_count": 20,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"model.evaluate(X_test, y_test)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Save the model"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 53,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Save the model\n",
|
||
"model.save(\"Barite_50_Model_additional_species.keras\")"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "training",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.11.11"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|