{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## General Information" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook is used to train a simple neural network model to predict the chemistry in the barite benchmark (50x50 grid). The training data is stored in the repository using **git large file storage** and can be downloaded after the installation of git lfs using the `git lfs pull` command.\n", "\n", "It is then recommended to create a Python environment using miniconda. The necessary dependencies are contained in `environment.yml` and can be installed using `conda env create -f environment.yml`.\n", "\n", "The data set is divided into a design and result part and consists of the iterations of a reference simulation. The design part of the data set contains the chemical concentrations at time $t$ and the result part at time $t+1$, which are to be learned by the model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup Libraries" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import keras\n", "import h5py\n", "import numpy as np\n", "import pandas as pd\n", "import time\n", "import sklearn.model_selection as sk\n", "import matplotlib.pyplot as plt\n", "from sklearn.cluster import KMeans\n", "from imblearn.over_sampling import SMOTE\n", "from imblearn.under_sampling import RandomUnderSampler\n", "from imblearn.over_sampling import RandomOverSampler\n", "from collections import Counter\n", "import os\n", "from preprocessing import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define parameters" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "dtype = \"float32\"\n", "activation = \"relu\"\n", "\n", "lr = 0.001\n", "batch_size = 512\n", "epochs = 50 # default 400 epochs\n", "\n", "lr_schedule = keras.optimizers.schedules.ExponentialDecay(\n", " initial_learning_rate=lr,\n", " decay_steps=2000,\n", " decay_rate=0.9,\n", " staircase=True\n", ")\n", "\n", "optimizer_simple = keras.optimizers.Adam(learning_rate=lr_schedule)\n", "optimizer_large = keras.optimizers.Adam(learning_rate=lr_schedule)\n", "\n", "loss = keras.losses.Huber()\n", "\n", "sample_fraction = 0.8" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup the model" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Model: \"sequential\"\n",
       "
\n" ], "text/plain": [ "\u001b[1mModel: \"sequential\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
       "│ dense (Dense)                   │ (None, 128)            │         1,664 │\n",
       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
       "│ dense_1 (Dense)                 │ (None, 128)            │        16,512 │\n",
       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
       "│ dense_2 (Dense)                 │ (None, 12)             │         1,548 │\n",
       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
       "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m1,664\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m16,512\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense_2 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m1,548\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Total params: 19,724 (77.05 KB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m19,724\u001b[0m (77.05 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Trainable params: 19,724 (77.05 KB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m19,724\u001b[0m (77.05 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Non-trainable params: 0 (0.00 B)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model_simple = keras.Sequential(\n", " [\n", " keras.Input(shape = (12,), dtype = \"float32\"),\n", " keras.layers.Dense(units = 128, activation = \"relu\", dtype = \"float32\"),\n", " keras.layers.Dense(units = 128, activation = \"relu\", dtype = \"float32\"),\n", " keras.layers.Dense(units = 12, dtype = \"float32\")\n", " ]\n", ")\n", "\n", "model_simple.compile(optimizer=optimizer_simple, loss = loss)\n", "model_simple.summary()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Model: \"sequential_1\"\n",
       "
\n" ], "text/plain": [ "\u001b[1mModel: \"sequential_1\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
       "│ dense_3 (Dense)                 │ (None, 512)            │         6,656 │\n",
       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
       "│ dense_4 (Dense)                 │ (None, 1024)           │       525,312 │\n",
       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
       "│ dense_5 (Dense)                 │ (None, 512)            │       524,800 │\n",
       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
       "│ dense_6 (Dense)                 │ (None, 12)             │         6,156 │\n",
       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
       "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ dense_3 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m6,656\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense_4 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1024\u001b[0m) │ \u001b[38;5;34m525,312\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense_5 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m524,800\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense_6 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m6,156\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Total params: 1,062,924 (4.05 MB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,062,924\u001b[0m (4.05 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Trainable params: 1,062,924 (4.05 MB)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,062,924\u001b[0m (4.05 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 Non-trainable params: 0 (0.00 B)\n",
       "
\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model_large = keras.Sequential(\n", " [keras.layers.Input(shape=(12,), dtype=dtype),\n", " keras.layers.Dense(512, activation='relu', dtype=dtype),\n", " keras.layers.Dense(1024, activation='relu', dtype=dtype),\n", " keras.layers.Dense(512, activation='relu', dtype=dtype),\n", " keras.layers.Dense(12, dtype=dtype)\n", " ])\n", "\n", "model_large.compile(optimizer=optimizer_large, loss = loss)\n", "model_large.summary()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define some functions and helper classes" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def Safelog(val):\n", " # get range of vector\n", " if val > 0:\n", " return np.log10(val)\n", " elif val < 0:\n", " return -np.log10(-val)\n", " else:\n", " return 0\n", "\n", "def Safeexp(val):\n", " if val > 0:\n", " return -10 ** -val\n", " elif val < 0:\n", " return 10 ** val\n", " else:\n", " return 0" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# ? Why does the charge is using another logarithm than the other species\n", "\n", "func_dict_in = {\n", " \"H\" : np.log1p,\n", " \"O\" : np.log1p,\n", " \"Charge\" : Safelog,\n", " \"H_0_\" : np.log1p,\n", " \"O_0_\" : np.log1p,\n", " \"Ba\" : np.log1p,\n", " \"Cl\" : np.log1p,\n", " \"S_2_\" : np.log1p,\n", " \"S_6_\" : np.log1p,\n", " \"Sr\" : np.log1p,\n", " \"Barite\" : np.log1p,\n", " \"Celestite\" : np.log1p,\n", "}\n", "\n", "func_dict_out = {\n", " \"H\" : np.expm1,\n", " \"O\" : np.expm1,\n", " \"Charge\" : Safeexp,\n", " \"H_0_\" : np.expm1,\n", " \"O_0_\" : np.expm1,\n", " \"Ba\" : np.expm1,\n", " \"Cl\" : np.expm1,\n", " \"S_2_\" : np.expm1,\n", " \"S_6_\" : np.expm1,\n", " \"Sr\" : np.expm1,\n", " \"Barite\" : np.expm1,\n", " \"Celestite\" : np.expm1,\n", "}\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Read data from `.h5` file and convert it to a `pandas.DataFrame`" ] }, { "cell_type": "code", "execution_count": 99, "metadata": {}, "outputs": [], "source": [ "# os.chdir('/mnt/beegfs/home/signer/projects/model-training')\n", "data_file = h5py.File(\"Barite_50_Data_training.h5\")\n", "\n", "design = data_file[\"design\"]\n", "results = data_file[\"result\"]\n", "\n", "df_design = pd.DataFrame(np.array(design[\"data\"]).transpose(), columns = design[\"names\"].asstr())\n", "df_results = pd.DataFrame(np.array(results[\"data\"]).transpose(), columns = results[\"names\"].asstr())\n", "\n", "data_file.close()" ] }, { "cell_type": "code", "execution_count": 100, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
HOChargeH_0_O_0_BaClS_2_S_6_SrBariteCelestite
0111.01243455.508192-7.779554e-092.697041e-262.210590e-152.041069e-024.082138e-020.000000e+000.0004940.0004940.0011.000000
1111.01243455.508427-4.736083e-091.446346e-262.473481e-151.094567e-022.189133e-020.000000e+000.0005530.0005530.0011.000000
2111.01243455.508691-1.311169e-093.889826e-282.769320e-152.943745e-045.887491e-040.000000e+000.0006190.0006190.0011.000000
3111.01243455.508698-1.220023e-091.442658e-292.777193e-151.091776e-052.183551e-050.000000e+000.0006200.0006200.0011.000000
4111.01243455.508699-1.216643e-095.350528e-312.777485e-154.049176e-078.098352e-070.000000e+000.0006200.0006200.0011.000000
.......................................
2502495111.01243455.5074883.573728e-095.424062e-1451.375204e-109.953520e-072.266555e-035.509534e-1490.0003180.0014500.0011.000014
2502496111.01243455.5075013.494007e-092.011675e-1461.377139e-109.817216e-072.217997e-032.043375e-1500.0003210.0014290.0011.000010
2502497111.01243455.5075123.429764e-097.460897e-1481.377819e-109.706451e-072.179066e-037.578467e-1520.0003240.0014120.0011.000006
2502498111.01243455.5075203.381745e-092.767237e-1491.371144e-109.621074e-072.149820e-032.810844e-1530.0003260.0014000.0011.000004
2502499111.01243455.5075253.348864e-095.321610e-1511.376026e-109.564401e-072.129912e-035.405468e-1550.0003270.0013910.0011.000001
\n", "

2502500 rows × 12 columns

\n", "
" ], "text/plain": [ " H O Charge H_0_ O_0_ \\\n", "0 111.012434 55.508192 -7.779554e-09 2.697041e-26 2.210590e-15 \n", "1 111.012434 55.508427 -4.736083e-09 1.446346e-26 2.473481e-15 \n", "2 111.012434 55.508691 -1.311169e-09 3.889826e-28 2.769320e-15 \n", "3 111.012434 55.508698 -1.220023e-09 1.442658e-29 2.777193e-15 \n", "4 111.012434 55.508699 -1.216643e-09 5.350528e-31 2.777485e-15 \n", "... ... ... ... ... ... \n", "2502495 111.012434 55.507488 3.573728e-09 5.424062e-145 1.375204e-10 \n", "2502496 111.012434 55.507501 3.494007e-09 2.011675e-146 1.377139e-10 \n", "2502497 111.012434 55.507512 3.429764e-09 7.460897e-148 1.377819e-10 \n", "2502498 111.012434 55.507520 3.381745e-09 2.767237e-149 1.371144e-10 \n", "2502499 111.012434 55.507525 3.348864e-09 5.321610e-151 1.376026e-10 \n", "\n", " Ba Cl S_2_ S_6_ Sr \\\n", "0 2.041069e-02 4.082138e-02 0.000000e+00 0.000494 0.000494 \n", "1 1.094567e-02 2.189133e-02 0.000000e+00 0.000553 0.000553 \n", "2 2.943745e-04 5.887491e-04 0.000000e+00 0.000619 0.000619 \n", "3 1.091776e-05 2.183551e-05 0.000000e+00 0.000620 0.000620 \n", "4 4.049176e-07 8.098352e-07 0.000000e+00 0.000620 0.000620 \n", "... ... ... ... ... ... \n", "2502495 9.953520e-07 2.266555e-03 5.509534e-149 0.000318 0.001450 \n", "2502496 9.817216e-07 2.217997e-03 2.043375e-150 0.000321 0.001429 \n", "2502497 9.706451e-07 2.179066e-03 7.578467e-152 0.000324 0.001412 \n", "2502498 9.621074e-07 2.149820e-03 2.810844e-153 0.000326 0.001400 \n", "2502499 9.564401e-07 2.129912e-03 5.405468e-155 0.000327 0.001391 \n", "\n", " Barite Celestite \n", "0 0.001 1.000000 \n", "1 0.001 1.000000 \n", "2 0.001 1.000000 \n", "3 0.001 1.000000 \n", "4 0.001 1.000000 \n", "... ... ... \n", "2502495 0.001 1.000014 \n", "2502496 0.001 1.000010 \n", "2502497 0.001 1.000006 \n", "2502498 0.001 1.000004 \n", "2502499 0.001 1.000001 \n", "\n", "[2502500 rows x 12 columns]" ] }, "execution_count": 100, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_design" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Classify each cell with kmeans" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/hannessigner/miniconda3/envs/ai/lib/python3.11/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (2). Possibly due to duplicate points in X.\n", " return fit_method(estimator, *args, **kwargs)\n" ] } ], "source": [ "# widget with slider for the index\n", "\n", "class_label_design = np.array([])\n", "class_label_result = np.array([])\n", "\n", "\n", "i = len(df_design) / 2500\n", "for i in range(0,252):\n", " field_design = np.array(df_design['Barite'][(i*2500):(i*2500+2500)]).reshape(50,50)\n", " field_result = np.array(df_results['Barite'][(i*2500):(i*2500+2500)]).reshape(50,50)\n", " \n", " kmeans_design = KMeans(n_clusters=2, random_state=0).fit(field_design.reshape(-1,1))\n", " kmeans_result = KMeans(n_clusters=2, random_state=0).fit(field_result.reshape(-1,1))\n", " \n", " class_label_design = np.append(class_label_design.astype(int), kmeans_design.labels_)\n", " class_label_result = np.append(class_label_result.astype(int), kmeans_result.labels_)\n", " \n", "\n", "\n", "class_label_design = pd.DataFrame(class_label_design, columns = [\"Class\"])\n", "class_label_result = pd.DataFrame(class_label_result, columns = [\"Class\"])\n" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "if(\"Class\" in df_design.columns and \"Class\" in df_results.columns):\n", " print(\"Class column already exists\")\n", "else:\n", " df_design = pd.concat([df_design, class_label_design], axis=1)\n", " df_results = pd.concat([df_results, class_label_design], axis=1)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Amount class 0: 0.9520126984126984\n", "Amount class 1: 0.047987301587301585\n" ] } ], "source": [ "counter = Counter(df_design.iloc[:,-1])\n", "print(\"Amount class 0:\", counter[0] / (counter[0] + counter[1]) )\n", "print(\"Amount class 1:\", counter[1] / (counter[0] + counter[1]) )\n" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "i=251\n", "\n", "plt.imshow(np.array(df_results['Barite'][(i*2500):(i*2500+2500)]).reshape(50,50), interpolation='bicubic', origin='lower')\n", "plt.contour(np.array(df_results['Class'][(i*2500):(i*2500+2500)]).reshape(50,50), levels=[0.1], colors='red', origin='lower')" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "df_design['Class threshold'] = df_design['Barite'] > 0.49\n" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.9991298042059463" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_design['Class'][df_design[\"Class threshold\"] == True].sum() / df_design['Class threshold'][df_design[\"Class threshold\"] == True].sum()" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "i = 251\n", "plt.imshow(np.array(df_design['Class threshold'][(i*2500):(i*2500+2500)]).reshape(50,50), interpolation='bicubic', origin='lower')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Split into Training and Testing datsets" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = sk.train_test_split(df_design, df_results, test_size = 0.2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Perform Over and Under Sampling on dataset to balance classes" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "def balancer(design, target, strategy, sample_fraction=0.5):\n", " counter = Counter(design.iloc[:,-1])\n", " print(\"Amount class 0 before:\", counter[0] / (counter[0] + counter[1]) )\n", " print(\"Amount class 1 before:\", counter[1] / (counter[0] + counter[1]) )\n", " \n", " number_features = (df_design.columns != \"Class\").sum()\n", " if(\"Class\" not in design.columns):\n", " if(\"Class\" in target.columns):\n", " classes = target['Class']\n", " else:\n", " raise(\"No class column found\")\n", " else:\n", " classes = design['Class']\n", " df = pd.concat([design.loc[:,design.columns != \"Class\"], target.loc[:, design.columns != \"Class\"], classes], axis=1)\n", " \n", " if strategy == 'smote':\n", " print(\"Using SMOTE strategy\")\n", " smote = SMOTE(sampling_strategy=sample_fraction)\n", " df_resampled, classes_resampled = smote.fit_resample(df.loc[:, df.columns != \"Class\"], df.loc[:, df.columns == \"Class\"])\n", " \n", " elif strategy == 'over':\n", " print(\"Using Oversampling\")\n", " over = RandomOverSampler()\n", " df_resampled, classes_resampled = over.fit_resample(df.loc[:, df.columns != \"Class\"], df.loc[:, df.columns == \"Class\"])\n", " \n", " elif strategy == 'under':\n", " print(\"Using Undersampling\")\n", " under = RandomUnderSampler()\n", " df_resampled, classes_resampled = under.fit_resample(df.loc[:, df.columns != \"Class\"], df.loc[:, df.columns == \"Class\"])\n", "\n", " counter = Counter(classes_resampled[\"Class\"])\n", " print(\"Amount class 0 after:\", counter[0] / (counter[0] + counter[1]) )\n", " print(\"Amount class 1 after:\", counter[1] / (counter[0] + counter[1]) )\n", " \n", " design_resampled = pd.concat([df_resampled.iloc[:,0:number_features], classes_resampled], axis=1)\n", " target_resampled = pd.concat([df_resampled.iloc[:,number_features:], classes_resampled], axis=1)\n", " \n", " return design_resampled, target_resampled " ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Amount class 0 before: 0.9563730158730158\n", "Amount class 1 before: 0.043626984126984125\n" ] }, { "ename": "IndexError", "evalue": "Boolean index has wrong length: 11 instead of 10", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[32], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m X_train, y_train \u001b[38;5;241m=\u001b[39m \u001b[43mbalancer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mover\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n", "Cell \u001b[0;32mIn[31], line 14\u001b[0m, in \u001b[0;36mbalancer\u001b[0;34m(design, target, strategy, sample_fraction)\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 13\u001b[0m classes \u001b[38;5;241m=\u001b[39m design[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mClass\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[0;32m---> 14\u001b[0m df \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mconcat([design\u001b[38;5;241m.\u001b[39mloc[:,design\u001b[38;5;241m.\u001b[39mcolumns \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mClass\u001b[39m\u001b[38;5;124m\"\u001b[39m], \u001b[43mtarget\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloc\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdesign\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcolumns\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m!=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mClass\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m, classes], axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m strategy \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msmote\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUsing SMOTE strategy\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", "File \u001b[0;32m~/miniconda3/envs/ai/lib/python3.11/site-packages/pandas/core/indexing.py:1184\u001b[0m, in \u001b[0;36m_LocationIndexer.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 1182\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_is_scalar_access(key):\n\u001b[1;32m 1183\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mobj\u001b[38;5;241m.\u001b[39m_get_value(\u001b[38;5;241m*\u001b[39mkey, takeable\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_takeable)\n\u001b[0;32m-> 1184\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_getitem_tuple\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1185\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1186\u001b[0m \u001b[38;5;66;03m# we by definition only have the 0th axis\u001b[39;00m\n\u001b[1;32m 1187\u001b[0m axis \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maxis \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;241m0\u001b[39m\n", "File \u001b[0;32m~/miniconda3/envs/ai/lib/python3.11/site-packages/pandas/core/indexing.py:1377\u001b[0m, in \u001b[0;36m_LocIndexer._getitem_tuple\u001b[0;34m(self, tup)\u001b[0m\n\u001b[1;32m 1374\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_multi_take_opportunity(tup):\n\u001b[1;32m 1375\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_multi_take(tup)\n\u001b[0;32m-> 1377\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_getitem_tuple_same_dim\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtup\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/miniconda3/envs/ai/lib/python3.11/site-packages/pandas/core/indexing.py:1020\u001b[0m, in \u001b[0;36m_LocationIndexer._getitem_tuple_same_dim\u001b[0;34m(self, tup)\u001b[0m\n\u001b[1;32m 1017\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m com\u001b[38;5;241m.\u001b[39mis_null_slice(key):\n\u001b[1;32m 1018\u001b[0m \u001b[38;5;28;01mcontinue\u001b[39;00m\n\u001b[0;32m-> 1020\u001b[0m retval \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mgetattr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mretval\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_getitem_axis\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mi\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1021\u001b[0m \u001b[38;5;66;03m# We should never have retval.ndim < self.ndim, as that should\u001b[39;00m\n\u001b[1;32m 1022\u001b[0m \u001b[38;5;66;03m# be handled by the _getitem_lowerdim call above.\u001b[39;00m\n\u001b[1;32m 1023\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m retval\u001b[38;5;241m.\u001b[39mndim \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mndim\n", "File \u001b[0;32m~/miniconda3/envs/ai/lib/python3.11/site-packages/pandas/core/indexing.py:1413\u001b[0m, in \u001b[0;36m_LocIndexer._getitem_axis\u001b[0;34m(self, key, axis)\u001b[0m\n\u001b[1;32m 1411\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_slice_axis(key, axis\u001b[38;5;241m=\u001b[39maxis)\n\u001b[1;32m 1412\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m com\u001b[38;5;241m.\u001b[39mis_bool_indexer(key):\n\u001b[0;32m-> 1413\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_getbool_axis\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maxis\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1414\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m is_list_like_indexer(key):\n\u001b[1;32m 1415\u001b[0m \u001b[38;5;66;03m# an iterable multi-selection\u001b[39;00m\n\u001b[1;32m 1416\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28misinstance\u001b[39m(key, \u001b[38;5;28mtuple\u001b[39m) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(labels, MultiIndex)):\n", "File \u001b[0;32m~/miniconda3/envs/ai/lib/python3.11/site-packages/pandas/core/indexing.py:1209\u001b[0m, in \u001b[0;36m_LocationIndexer._getbool_axis\u001b[0;34m(self, key, axis)\u001b[0m\n\u001b[1;32m 1205\u001b[0m \u001b[38;5;129m@final\u001b[39m\n\u001b[1;32m 1206\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_getbool_axis\u001b[39m(\u001b[38;5;28mself\u001b[39m, key, axis: AxisInt):\n\u001b[1;32m 1207\u001b[0m \u001b[38;5;66;03m# caller is responsible for ensuring non-None axis\u001b[39;00m\n\u001b[1;32m 1208\u001b[0m labels \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mobj\u001b[38;5;241m.\u001b[39m_get_axis(axis)\n\u001b[0;32m-> 1209\u001b[0m key \u001b[38;5;241m=\u001b[39m \u001b[43mcheck_bool_indexer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1210\u001b[0m inds \u001b[38;5;241m=\u001b[39m key\u001b[38;5;241m.\u001b[39mnonzero()[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1211\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mobj\u001b[38;5;241m.\u001b[39m_take_with_is_copy(inds, axis\u001b[38;5;241m=\u001b[39maxis)\n", "File \u001b[0;32m~/miniconda3/envs/ai/lib/python3.11/site-packages/pandas/core/indexing.py:2681\u001b[0m, in \u001b[0;36mcheck_bool_indexer\u001b[0;34m(index, key)\u001b[0m\n\u001b[1;32m 2677\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_array_like(result):\n\u001b[1;32m 2678\u001b[0m \u001b[38;5;66;03m# GH 33924\u001b[39;00m\n\u001b[1;32m 2679\u001b[0m \u001b[38;5;66;03m# key may contain nan elements, check_array_indexer needs bool array\u001b[39;00m\n\u001b[1;32m 2680\u001b[0m result \u001b[38;5;241m=\u001b[39m pd_array(result, dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mbool\u001b[39m)\n\u001b[0;32m-> 2681\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcheck_array_indexer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mresult\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/miniconda3/envs/ai/lib/python3.11/site-packages/pandas/core/indexers/utils.py:539\u001b[0m, in \u001b[0;36mcheck_array_indexer\u001b[0;34m(array, indexer)\u001b[0m\n\u001b[1;32m 537\u001b[0m \u001b[38;5;66;03m# GH26658\u001b[39;00m\n\u001b[1;32m 538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(indexer) \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mlen\u001b[39m(array):\n\u001b[0;32m--> 539\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mIndexError\u001b[39;00m(\n\u001b[1;32m 540\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBoolean index has wrong length: \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 541\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(indexer)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m instead of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(array)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 542\u001b[0m )\n\u001b[1;32m 543\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m is_integer_dtype(dtype):\n\u001b[1;32m 544\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n", "\u001b[0;31mIndexError\u001b[0m: Boolean index has wrong length: 11 instead of 10" ] } ], "source": [ "X_train, y_train = balancer(X_train, y_train, 'over')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define Scaling and Normalization Functions" ] }, { "cell_type": "code", "execution_count": 88, "metadata": {}, "outputs": [], "source": [ "def log_scale(df_design, df_result, func_dict):\n", " \n", " df_design = df_design.copy()\n", " df_result = df_result.copy()\n", " \n", " for key in df_design.keys():\n", " if key != \"Class\":\n", " df_design[key] = np.vectorize(func_dict[key])(df_design[key])\n", " df_result[key] = np.vectorize(func_dict[key])(df_result[key])\n", " \n", " return df_design, df_result\n", "\n", "# Get minimum and maximum values for each column\n", "def get_min_max(df_design, df_result):\n", " \n", " min_vals_des = df_design.min()\n", " max_vals_des = df_design.max()\n", " \n", " min_vals_res = df_result.min()\n", " max_vals_res = df_result.max()\n", "\n", " # minimum of input and output data to get global minimum/maximum\n", " data_min = np.minimum(min_vals_des, min_vals_res).to_dict()\n", " data_max = np.maximum(max_vals_des, max_vals_res).to_dict()\n", "\n", " return data_min, data_max\n" ] }, { "cell_type": "code", "execution_count": 101, "metadata": {}, "outputs": [], "source": [ "df_design_log, df_results_log = log_scale(df_design, df_results, func_dict_in)\n", "data_min_log, data_max_log = get_min_max(df_design_log, df_results_log)" ] }, { "cell_type": "code", "execution_count": 102, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
HOChargeH_0_O_0_BaClS_2_S_6_SrBariteCelestite
0111.01243455.508192-7.779554e-092.697041e-262.210590e-152.041069e-024.082138e-020.000000e+000.0004940.0004940.0011.000000
1111.01243455.508427-4.736083e-091.446346e-262.473481e-151.094567e-022.189133e-020.000000e+000.0005530.0005530.0011.000000
2111.01243455.508691-1.311169e-093.889826e-282.769320e-152.943745e-045.887491e-040.000000e+000.0006190.0006190.0011.000000
3111.01243455.508698-1.220023e-091.442658e-292.777193e-151.091776e-052.183551e-050.000000e+000.0006200.0006200.0011.000000
4111.01243455.508699-1.216643e-095.350528e-312.777485e-154.049176e-078.098352e-070.000000e+000.0006200.0006200.0011.000000
.......................................
2502495111.01243455.5074883.573728e-095.424062e-1451.375204e-109.953520e-072.266555e-035.509534e-1490.0003180.0014500.0011.000014
2502496111.01243455.5075013.494007e-092.011675e-1461.377139e-109.817216e-072.217997e-032.043375e-1500.0003210.0014290.0011.000010
2502497111.01243455.5075123.429764e-097.460897e-1481.377819e-109.706451e-072.179066e-037.578467e-1520.0003240.0014120.0011.000006
2502498111.01243455.5075203.381745e-092.767237e-1491.371144e-109.621074e-072.149820e-032.810844e-1530.0003260.0014000.0011.000004
2502499111.01243455.5075253.348864e-095.321610e-1511.376026e-109.564401e-072.129912e-035.405468e-1550.0003270.0013910.0011.000001
\n", "

2502500 rows × 12 columns

\n", "
" ], "text/plain": [ " H O Charge H_0_ O_0_ \\\n", "0 111.012434 55.508192 -7.779554e-09 2.697041e-26 2.210590e-15 \n", "1 111.012434 55.508427 -4.736083e-09 1.446346e-26 2.473481e-15 \n", "2 111.012434 55.508691 -1.311169e-09 3.889826e-28 2.769320e-15 \n", "3 111.012434 55.508698 -1.220023e-09 1.442658e-29 2.777193e-15 \n", "4 111.012434 55.508699 -1.216643e-09 5.350528e-31 2.777485e-15 \n", "... ... ... ... ... ... \n", "2502495 111.012434 55.507488 3.573728e-09 5.424062e-145 1.375204e-10 \n", "2502496 111.012434 55.507501 3.494007e-09 2.011675e-146 1.377139e-10 \n", "2502497 111.012434 55.507512 3.429764e-09 7.460897e-148 1.377819e-10 \n", "2502498 111.012434 55.507520 3.381745e-09 2.767237e-149 1.371144e-10 \n", "2502499 111.012434 55.507525 3.348864e-09 5.321610e-151 1.376026e-10 \n", "\n", " Ba Cl S_2_ S_6_ Sr \\\n", "0 2.041069e-02 4.082138e-02 0.000000e+00 0.000494 0.000494 \n", "1 1.094567e-02 2.189133e-02 0.000000e+00 0.000553 0.000553 \n", "2 2.943745e-04 5.887491e-04 0.000000e+00 0.000619 0.000619 \n", "3 1.091776e-05 2.183551e-05 0.000000e+00 0.000620 0.000620 \n", "4 4.049176e-07 8.098352e-07 0.000000e+00 0.000620 0.000620 \n", "... ... ... ... ... ... \n", "2502495 9.953520e-07 2.266555e-03 5.509534e-149 0.000318 0.001450 \n", "2502496 9.817216e-07 2.217997e-03 2.043375e-150 0.000321 0.001429 \n", "2502497 9.706451e-07 2.179066e-03 7.578467e-152 0.000324 0.001412 \n", "2502498 9.621074e-07 2.149820e-03 2.810844e-153 0.000326 0.001400 \n", "2502499 9.564401e-07 2.129912e-03 5.405468e-155 0.000327 0.001391 \n", "\n", " Barite Celestite \n", "0 0.001 1.000000 \n", "1 0.001 1.000000 \n", "2 0.001 1.000000 \n", "3 0.001 1.000000 \n", "4 0.001 1.000000 \n", "... ... ... \n", "2502495 0.001 1.000014 \n", "2502496 0.001 1.000010 \n", "2502497 0.001 1.000006 \n", "2502498 0.001 1.000004 \n", "2502499 0.001 1.000001 \n", "\n", "[2502500 rows x 12 columns]" ] }, "execution_count": 102, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_design" ] }, { "cell_type": "code", "execution_count": 103, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
HOChargeH_0_O_0_BaClS_2_S_6_SrBariteCelestite
04.718614.0343868.1090452.697041e-262.210590e-152.020518e-024.001019e-020.000000e+000.0004940.0004940.0010000.693147
14.718614.0343908.3245811.446346e-262.473481e-151.088620e-022.165516e-020.000000e+000.0005520.0005520.0010000.693147
24.718614.0343948.8823413.889826e-282.769320e-152.943312e-045.885758e-040.000000e+000.0006180.0006180.0010000.693147
34.718614.0343958.9136321.442658e-292.777193e-151.091770e-052.183528e-050.000000e+000.0006200.0006200.0010000.693147
44.718614.0343958.9148375.350528e-312.777485e-154.049175e-078.098349e-070.000000e+000.0006200.0006200.0010000.693147
.......................................
25024954.718614.034373-8.4468785.424062e-1451.375204e-109.953515e-072.263990e-035.509534e-1490.0003180.0014490.0009990.693154
25024964.718614.034373-8.4566762.011675e-1461.377139e-109.817211e-072.215541e-032.043375e-1500.0003210.0014280.0009990.693152
25024974.718614.034374-8.4647367.460897e-1481.377819e-109.706446e-072.176695e-037.578467e-1520.0003240.0014110.0009990.693150
25024984.718614.034374-8.4708592.767237e-1491.371144e-109.621070e-072.147512e-032.810844e-1530.0003260.0013990.0009990.693149
25024994.718614.034374-8.4751025.321610e-1511.376026e-109.564396e-072.127647e-035.405468e-1550.0003270.0013900.0009990.693148
\n", "

2502500 rows × 12 columns

\n", "
" ], "text/plain": [ " H O Charge H_0_ O_0_ \\\n", "0 4.71861 4.034386 8.109045 2.697041e-26 2.210590e-15 \n", "1 4.71861 4.034390 8.324581 1.446346e-26 2.473481e-15 \n", "2 4.71861 4.034394 8.882341 3.889826e-28 2.769320e-15 \n", "3 4.71861 4.034395 8.913632 1.442658e-29 2.777193e-15 \n", "4 4.71861 4.034395 8.914837 5.350528e-31 2.777485e-15 \n", "... ... ... ... ... ... \n", "2502495 4.71861 4.034373 -8.446878 5.424062e-145 1.375204e-10 \n", "2502496 4.71861 4.034373 -8.456676 2.011675e-146 1.377139e-10 \n", "2502497 4.71861 4.034374 -8.464736 7.460897e-148 1.377819e-10 \n", "2502498 4.71861 4.034374 -8.470859 2.767237e-149 1.371144e-10 \n", "2502499 4.71861 4.034374 -8.475102 5.321610e-151 1.376026e-10 \n", "\n", " Ba Cl S_2_ S_6_ Sr \\\n", "0 2.020518e-02 4.001019e-02 0.000000e+00 0.000494 0.000494 \n", "1 1.088620e-02 2.165516e-02 0.000000e+00 0.000552 0.000552 \n", "2 2.943312e-04 5.885758e-04 0.000000e+00 0.000618 0.000618 \n", "3 1.091770e-05 2.183528e-05 0.000000e+00 0.000620 0.000620 \n", "4 4.049175e-07 8.098349e-07 0.000000e+00 0.000620 0.000620 \n", "... ... ... ... ... ... \n", "2502495 9.953515e-07 2.263990e-03 5.509534e-149 0.000318 0.001449 \n", "2502496 9.817211e-07 2.215541e-03 2.043375e-150 0.000321 0.001428 \n", "2502497 9.706446e-07 2.176695e-03 7.578467e-152 0.000324 0.001411 \n", "2502498 9.621070e-07 2.147512e-03 2.810844e-153 0.000326 0.001399 \n", "2502499 9.564396e-07 2.127647e-03 5.405468e-155 0.000327 0.001390 \n", "\n", " Barite Celestite \n", "0 0.001000 0.693147 \n", "1 0.001000 0.693147 \n", "2 0.001000 0.693147 \n", "3 0.001000 0.693147 \n", "4 0.001000 0.693147 \n", "... ... ... \n", "2502495 0.000999 0.693154 \n", "2502496 0.000999 0.693152 \n", "2502497 0.000999 0.693150 \n", "2502498 0.000999 0.693149 \n", "2502499 0.000999 0.693148 \n", "\n", "[2502500 rows x 12 columns]" ] }, "execution_count": 103, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_design_log" ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [], "source": [ "X_train_log, y_train_log = log_scale(X_train, y_train, func_dict_in)\n", "X_test_log, y_test_log = log_scale(X_test, y_test, func_dict_in)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "train_min_log, train_max_log = get_min_max(X_train_log, y_train_log)\n", "test_min_log, test_max_log = get_min_max(X_test_log, y_test_log)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "def preprocess(data, func_dict, data_min, data_max):\n", " data = data.copy()\n", " for key in data.keys():\n", " if key != \"Class\":\n", " data[key] = (data[key] - data_min[key]) / (data_max[key] - data_min[key])\n", "\n", " return data\n", "\n", "def postprocess(data, func_dict, data_min, data_max):\n", " data = data.copy()\n", " for key in data.keys():\n", " if key != \"Class\":\n", " data[key] = data[key] * (data_max[key] - data_min[key]) + data_min[key]\n", " data[key] = np.vectorize(func_dict[key])(data[key])\n", " return data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Preprocess the data" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "pp_design = preprocess(df_design_log, func_dict_in, data_min_log, data_max_log)\n", "pp_results = preprocess(df_results_log, func_dict_in, data_min_log, data_max_log)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "X_train_preprocess = preprocess(X_train_log, func_dict_in, train_min_log, train_max_log)\n", "y_train_preprocess = preprocess(y_train_log, func_dict_in, train_min_log, train_max_log)\n", "\n", "X_test_preprocess = preprocess(X_test_log, func_dict_in, test_min_log, test_max_log)\n", "y_test_preprocess = preprocess(y_test_log, func_dict_in, test_min_log, test_max_log)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sample the data" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [], "source": [ "X_train, X_val, y_train, y_val = sk.train_test_split(X_train_preprocess, y_train_preprocess, test_size = 0.1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Custom Loss function" ] }, { "cell_type": "code", "execution_count": 164, "metadata": {}, "outputs": [], "source": [ "def custom_loss_H20(df_design_log, df_result_log, data_min_log, data_max_log, func_dict_out, postprocess):\n", " df_result = postprocess(df_result_log, func_dict_out, data_min_log, data_max_log) \n", " return keras.losses.Huber + np.sum(((df_result['H'] / df_result['O']) - 2)**2)\n", "\n", "def loss_wrapper(data_min_log, data_max_log, func_dict_out, postprocess):\n", " def loss(df_design_log, df_result_log):\n", " return custom_loss_H20(df_design_log, df_result_log, data_min_log, data_max_log, func_dict_out, postprocess)\n", " return loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train the model" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n" ] }, { "ename": "ValueError", "evalue": "Attr 'Toutput_types' of 'OptionalFromValue' Op passed list of length 0 less than minimum 1.", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[71], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# measure time\u001b[39;00m\n\u001b[1;32m 2\u001b[0m start \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[0;32m----> 4\u001b[0m history \u001b[38;5;241m=\u001b[39m \u001b[43mmodel_simple\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_train\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43miloc\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43miloc\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m5\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalidation_data\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_val\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43miloc\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m:\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_val\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43miloc\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[43m)\u001b[49m\n\u001b[1;32m 11\u001b[0m end \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTraining took \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m seconds\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(end \u001b[38;5;241m-\u001b[39m start))\n", "File \u001b[0;32m~/miniconda3/envs/ai/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py:122\u001b[0m, in \u001b[0;36mfilter_traceback..error_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 119\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n\u001b[1;32m 120\u001b[0m \u001b[38;5;66;03m# To get the full stack trace, call:\u001b[39;00m\n\u001b[1;32m 121\u001b[0m \u001b[38;5;66;03m# `keras.config.disable_traceback_filtering()`\u001b[39;00m\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\u001b[38;5;241m.\u001b[39mwith_traceback(filtered_tb) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 123\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 124\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m filtered_tb\n", "File \u001b[0;32m~/miniconda3/envs/ai/lib/python3.11/site-packages/keras/src/backend/tensorflow/trainer.py:131\u001b[0m, in \u001b[0;36mTensorFlowTrainer._make_function..multi_step_on_iterator\u001b[0;34m(iterator)\u001b[0m\n\u001b[1;32m 128\u001b[0m \u001b[38;5;129m@tf\u001b[39m\u001b[38;5;241m.\u001b[39mautograph\u001b[38;5;241m.\u001b[39mexperimental\u001b[38;5;241m.\u001b[39mdo_not_convert\n\u001b[1;32m 129\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmulti_step_on_iterator\u001b[39m(iterator):\n\u001b[1;32m 130\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msteps_per_execution \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[0;32m--> 131\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexperimental\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mOptional\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_value\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 132\u001b[0m \u001b[43m \u001b[49m\u001b[43mone_step_on_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43miterator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_next\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 133\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 135\u001b[0m \u001b[38;5;66;03m# the spec is set lazily during the tracing of `tf.while_loop`\u001b[39;00m\n\u001b[1;32m 136\u001b[0m empty_outputs \u001b[38;5;241m=\u001b[39m tf\u001b[38;5;241m.\u001b[39mexperimental\u001b[38;5;241m.\u001b[39mOptional\u001b[38;5;241m.\u001b[39mempty(\u001b[38;5;28;01mNone\u001b[39;00m)\n", "\u001b[0;31mValueError\u001b[0m: Attr 'Toutput_types' of 'OptionalFromValue' Op passed list of length 0 less than minimum 1." ] } ], "source": [ "# measure time\n", "start = time.time()\n", "\n", "history = model_simple.fit(X_train.iloc[:, :-1], \n", " y_train.iloc[:, :-1], \n", " batch_size = batch_size, \n", " epochs = 5, \n", " validation_data = (X_val.iloc[:,:-1], y_val.iloc[:, :-1])\n", ")\n", "\n", "end = time.time()\n", "\n", "print(\"Training took {} seconds\".format(end - start))" ] }, { "cell_type": "code", "execution_count": 146, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(history.history[\"loss\"], \"o-\", label = \"Training Loss\")\n", "plt.xlabel(\"Epoch\")\n", "# plt.yscale('log')\n", "plt.ylabel(\"Loss (Huber)\")\n", "plt.grid('on')\n", "\n", "plt.savefig(\"loss_all.png\", dpi=300)\n" ] }, { "cell_type": "code", "execution_count": 147, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(history.history[\"loss\"][1:], \"o-\", label = \"Training Loss\")\n", "plt.xlabel(\"Epoch\")\n", "# plt.yscale('log')\n", "plt.ylabel(\"Loss (Huber)\")\n", "plt.grid('on')\n", "plt.savefig(\"loss_1_to_end.png\", dpi=300)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test the model" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m15641/15641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 233us/step - loss: 0.0324\n" ] }, { "data": { "text/plain": [ "0.032423071563243866" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# test on all test data\n", "model_simple.evaluate(X_test_preprocess.iloc[:,:-1], y_test_preprocess.iloc[:, :-1])" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m15451/15451\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 234us/step - loss: 0.0313\n" ] }, { "data": { "text/plain": [ "0.031290605664253235" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# test on non-reactive data\n", "model_simple.evaluate(X_test_preprocess[X_test_preprocess['Class'] == 0].iloc[:,:-1], y_test_preprocess[X_test_preprocess['Class'] == 0].iloc[:,:-1])" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m190/190\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 295us/step - loss: 0.1246\n" ] }, { "data": { "text/plain": [ "0.12462512403726578" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# test on reactive data\n", "model_simple.evaluate(X_test_preprocess[X_test_preprocess['Class'] == 1].iloc[:,:-1], y_test_preprocess[X_test_preprocess['Class'] == 1].iloc[:, :-1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Save the model" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "# Save the model\n", "model.save(\"Barite_50_Model_additional_species.keras\")" ] } ], "metadata": { "kernelspec": { "display_name": "ai", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 2 }