{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## General Information" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "This notebook is used to train a simple neural network model to predict the chemistry in the barite benchmark (50x50 grid). The training data is stored in the repository using **git large file storage** and can be downloaded after the installation of git lfs using the `git lfs pull` command.\n", "\n", "It is then recommended to create a Python environment using miniconda. The necessary dependencies are contained in `environment.yml` and can be installed using `conda env create -f environment.yml`.\n", "\n", "The data set is divided into a design and result part and consists of the iterations of a reference simulation. The design part of the data set contains the chemical concentrations at time $t$ and the result part at time $t+1$, which are to be learned by the model." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Setup Libraries" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running Keras in version 3.8.0\n" ] } ], "source": [ "# Standard library\n", "import os\n", "import time\n", "from collections import Counter\n", "\n", "# Third-party\n", "import keras\n", "import h5py\n", "import numpy as np\n", "import pandas as pd\n", "import sklearn.model_selection as sk\n", "import matplotlib.pyplot as plt\n", "from sklearn.cluster import KMeans\n", "from sklearn.pipeline import Pipeline, make_pipeline\n", "from sklearn.preprocessing import StandardScaler, MinMaxScaler\n", "from imblearn.over_sampling import SMOTE, RandomOverSampler\n", "from imblearn.under_sampling import RandomUnderSampler\n", "\n", "# Local helpers (presumably the source of preprocessing_training used below).\n", "# NOTE(review): prefer importing explicit names; `import *` is kept because the full set of names used later is not visible here.\n", "# NOTE(review): the recorded Keras-version output line is not printed by any statement in this cell - presumably it is emitted when `preprocessing` is imported; confirm.\n", "from preprocessing import *\n", "\n", "# Imported after the star import so these names cannot be shadowed by it.\n", "from sklearn import set_config\n", "from importlib import reload\n", "\n", "# Make sklearn transformers return pandas DataFrames instead of NumPy arrays.\n", "set_config(transform_output = \"pandas\")" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Define parameters" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Numeric type and hidden-layer activation shared by the models below.\n", "dtype = \"float32\"\n", "activation = \"relu\"\n", "\n", "# Seed Python, NumPy and the Keras backend so weight initialisation and data shuffling repeat between runs.\n", "seed = 42\n", "keras.utils.set_random_seed(seed)\n", "\n", "# Training settings.\n", "lr = 0.001\n", "batch_size = 512\n", "epochs = 50  # default 400 epochs\n", "\n", "# Multiply the learning rate by 0.9 every 2000 optimizer steps (staircase: discrete jumps, not a smooth decay).\n", "lr_schedule = keras.optimizers.schedules.ExponentialDecay(\n", "    initial_learning_rate=lr,\n", "    decay_steps=2000,\n", "    decay_rate=0.9,\n", "    staircase=True\n", ")\n", "\n", "# One optimizer instance per model, since an optimizer keeps state for the variables it trains.\n", "optimizer_simple = keras.optimizers.Adam(learning_rate=lr_schedule)\n", "optimizer_large = keras.optimizers.Adam(learning_rate=lr_schedule)\n", "\n", "loss = keras.losses.MeanSquaredError()\n", "\n", "# NOTE(review): sample_fraction is not used in the visible cells - confirm it is still needed.\n", "sample_fraction = 0.8" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Setup the model" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Model: \"sequential\"\n",
"\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense (Dense) │ (None, 128) │ 1,664 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_1 (Dense) │ (None, 128) │ 16,512 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_2 (Dense) │ (None, 12) │ 1,548 │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m1,664\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m16,512\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_2 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m1,548\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Total params: 19,724 (77.05 KB)\n", "\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m19,724\u001b[0m (77.05 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Trainable params: 19,724 (77.05 KB)\n", "\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m19,724\u001b[0m (77.05 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Non-trainable params: 0 (0.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Small model: fully connected 12 -> 128 -> 128 -> 12.\n", "# dtype and the hidden activation come from the parameter cell instead of repeated literals.\n", "model_simple = keras.Sequential(\n", "    [\n", "        keras.Input(shape=(12,), dtype=dtype),\n", "        keras.layers.Dense(units=128, activation=activation, dtype=dtype),\n", "        keras.layers.Dense(units=128, activation=activation, dtype=dtype),\n", "        keras.layers.Dense(units=12, dtype=dtype),  # linear output, one unit per target column\n", "    ]\n", ")\n", "\n", "model_simple.compile(optimizer=optimizer_simple, loss=loss)\n", "model_simple.summary()" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Model: \"sequential_1\"\n",
"\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential_1\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_3 (Dense) │ (None, 512) │ 6,656 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_4 (Dense) │ (None, 1024) │ 525,312 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_5 (Dense) │ (None, 512) │ 524,800 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_6 (Dense) │ (None, 12) │ 6,156 │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_3 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m6,656\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_4 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1024\u001b[0m) │ \u001b[38;5;34m525,312\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_5 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m524,800\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_6 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m6,156\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Total params: 1,062,924 (4.05 MB)\n", "\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,062,924\u001b[0m (4.05 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Trainable params: 1,062,924 (4.05 MB)\n", "\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,062,924\u001b[0m (4.05 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Non-trainable params: 0 (0.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Large model: fully connected 12 -> 512 -> 1024 -> 512 -> 12.\n", "model_large = keras.Sequential(\n", "    [\n", "        keras.layers.Input(shape=(12,), dtype=dtype),\n", "        keras.layers.Dense(512, activation=activation, dtype=dtype),\n", "        keras.layers.Dense(1024, activation=activation, dtype=dtype),\n", "        keras.layers.Dense(512, activation=activation, dtype=dtype),\n", "        keras.layers.Dense(12, dtype=dtype),  # linear output layer\n", "    ]\n", ")\n", "\n", "model_large.compile(optimizer=optimizer_large, loss=loss)\n", "model_large.summary()\n" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# TODO: model from paper (placeholder, not implemented yet)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Define transformer functions" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def Safelog(val):\n", "    \"\"\"Signed base-10 logarithm of a scalar.\n", "\n", "    Returns log10(val) for val > 0, -log10(-val) for val < 0 and 0 for val == 0.\n", "    Only invertible (by Safeexp) for |val| < 1: e.g. Safelog(10) == Safelog(-0.1) == 1.\n", "    \"\"\"\n", "    if val > 0:\n", "        return np.log10(val)\n", "    elif val < 0:\n", "        return -np.log10(-val)\n", "    else:\n", "        return 0\n", "\n", "\n", "def Safeexp(val):\n", "    \"\"\"Inverse of Safelog, exact for original values x with |x| < 1.\n", "\n", "    For |x| < 1 Safelog maps positive x to negative logs and negative x to positive logs,\n", "    so a positive input decodes to -10**-val and a negative one to 10**val; 0 stays 0.\n", "    \"\"\"\n", "    if val > 0:\n", "        return -10 ** -val\n", "    elif val < 0:\n", "        return 10 ** val\n", "    else:\n", "        return 0" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Per-column forward transforms (func_dict_in) and their inverses (func_dict_out).\n", "# NOTE(review): Charge presumably gets the sign-aware Safelog/Safeexp pair because it can be\n", "# negative and is tiny (|Charge| ~ 1e-9 to 1e-8 in the data), where log1p(x) ~ x would not\n", "# rescale it at all; every other column uses log1p/expm1.\n", "\n", "func_dict_in = {\n", "    \"H\": np.log1p,\n", "    \"O\": np.log1p,\n", "    \"Charge\": Safelog,\n", "    \"H_0_\": np.log1p,\n", "    \"O_0_\": np.log1p,\n", "    \"Ba\": np.log1p,\n", "    \"Cl\": np.log1p,\n", "    \"S_2_\": np.log1p,\n", "    \"S_6_\": np.log1p,\n", "    \"Sr\": np.log1p,\n", "    \"Barite\": np.log1p,\n", "    \"Celestite\": np.log1p,\n", "}\n", "\n", "func_dict_out = {\n", "    \"H\": np.expm1,\n", "    \"O\": np.expm1,\n", "    \"Charge\": Safeexp,\n", "    \"H_0_\": np.expm1,\n", "    \"O_0_\": np.expm1,\n", "    \"Ba\": np.expm1,\n", "    \"Cl\": np.expm1,\n", "    \"S_2_\": np.expm1,\n", "    \"S_6_\": np.expm1,\n", "    \"Sr\": np.expm1,\n", "    \"Barite\": np.expm1,\n", "    \"Celestite\": np.expm1,\n", "}\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Read data from `.h5` file and convert it to a `pandas.DataFrame`" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Training data file (fetched via git lfs, see the introduction), relative to the working directory.\n", "data_path = \"Barite_50_Data_training.h5\"\n", "\n", "# Each group stores a \"data\" matrix laid out as (columns, records) plus the column \"names\";\n", "# transposing gives one DataFrame row per record. The context manager closes the file even if reading fails.\n", "with h5py.File(data_path, \"r\") as data_file:\n", "    design = data_file[\"design\"]\n", "    results = data_file[\"result\"]\n", "\n", "    df_design = pd.DataFrame(np.array(design[\"data\"]).transpose(), columns=np.array(design[\"names\"].asstr()))\n", "    df_results = pd.DataFrame(np.array(results[\"data\"]).transpose(), columns=np.array(results[\"names\"].asstr()))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Preprocess Data\n", "\n", "The data are preprocessed in the following way:\n", "\n", "1. Label the data points in the `design` dataset as `reactive` or `non-reactive` using k-means clustering.\n", "2. Transform the `design` and `results` data sets into log-scaled data.\n", "3. Split the data into training and test sets.\n", "4. Fit the scaler on the training data for `design` and `results` together (option `global`) or individually (option `individual`).\n", "5. 
Transform training and test data.\n", "6. Split training data into training and validation dataset." ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/hannessigner/miniforge3/envs/ai/lib/python3.12/site-packages/sklearn/base.py:1474: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (2). Possibly due to duplicate points in X.\n", " return fit_method(estimator, *args, **kwargs)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Amount class 0 before: 0.9879169719169719\n", "Amount class 1 before: 0.012083028083028084\n", "Using Oversampling\n", "Amount class 0 after: 0.5\n", "Amount class 1 after: 0.5\n" ] } ], "source": [ "# \"over\" selects oversampling to balance the two classes (see output), 'global' fits one scaler on\n", "# design and results together (step 4 above); 0.1 is presumably the test/validation fraction - TODO confirm in preprocessing.py.\n", "# NOTE(review): the recorded ConvergenceWarning reports only 1 distinct k-means cluster for n_clusters=2 - check the reactive/non-reactive labelling.\n", "X_train, X_val, X_test, y_train, y_val, y_test, scaler_X, scaler_y = preprocessing_training(df_design, df_results, func_dict_in, func_dict_out, \"over\", 'global', 0.1)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Custom Loss function" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def custom_loss_H20(df_design_log, df_result_log, data_min_log, data_max_log, func_dict_out, postprocess):\n", "    \"\"\"Huber loss plus a penalty pulling the predicted H/O ratio towards 2 (H2O).\n", "\n", "    Keras calls the closure returned by `loss_wrapper` as loss(y_true, y_pred), so despite the\n", "    parameter names `df_design_log` receives the scaled targets (y_true) and `df_result_log`\n", "    the scaled model prediction (y_pred).\n", "\n", "    Parameters\n", "    ----------\n", "    df_design_log : scaled/log-transformed targets (y_true).\n", "    df_result_log : scaled/log-transformed prediction (y_pred).\n", "    data_min_log, data_max_log : scaling bounds, forwarded to `postprocess`.\n", "    func_dict_out : per-column inverse transforms, forwarded to `postprocess`.\n", "    postprocess : callable mapping the scaled prediction back to physical values; its result\n", "        must support column access by name ('H', 'O').\n", "\n", "    Returns\n", "    -------\n", "    Huber(y_true, y_pred) + sum((H / O - 2) ** 2) over the back-transformed prediction.\n", "\n", "    NOTE(review): to be usable in `model.compile`, `postprocess` must work on backend tensors\n", "    (not only pandas DataFrames) - confirm against preprocessing.py.\n", "    \"\"\"\n", "    df_result = postprocess(df_result_log, func_dict_out, data_min_log, data_max_log)\n", "    # keras.losses.Huber is a class: instantiate it, then call it on (y_true, y_pred).\n", "    huber = keras.losses.Huber()(df_design_log, df_result_log)\n", "    # keras.ops instead of numpy so the penalty also works on backend tensors.\n", "    ratio_penalty = keras.ops.sum(keras.ops.square(df_result['H'] / df_result['O'] - 2.0))\n", "    return huber + ratio_penalty\n", "\n", "\n", "def loss_wrapper(data_min_log, data_max_log, func_dict_out, postprocess):\n", "    \"\"\"Bind the postprocessing arguments and return a Keras-compatible loss(y_true, y_pred).\"\"\"\n", "    def loss(df_design_log, df_result_log):\n", "        return custom_loss_H20(df_design_log, df_result_log, data_min_log, data_max_log, func_dict_out, postprocess)\n", "    return loss" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Train the model" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/20\n", "\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m 
\u001b[1m14s\u001b[0m 2ms/step - loss: 0.0021 - val_loss: 3.4232e-05\n", "Epoch 2/20\n", "\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 2ms/step - loss: 3.5182e-05 - val_loss: 3.3009e-05\n", "Epoch 3/20\n", "\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 2ms/step - loss: 3.3553e-05 - val_loss: 3.1858e-05\n", "Epoch 4/20\n", "\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 2ms/step - loss: 3.2530e-05 - val_loss: 3.1686e-05\n", "Epoch 5/20\n", "\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 2ms/step - loss: 3.1540e-05 - val_loss: 3.1268e-05\n", "Epoch 6/20\n", "\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 2ms/step - loss: 3.1264e-05 - val_loss: 3.1947e-05\n", "Epoch 7/20\n", "\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 2ms/step - loss: 3.1342e-05 - val_loss: 3.1175e-05\n", "Epoch 8/20\n", "\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 2ms/step - loss: 3.1101e-05 - val_loss: 3.1003e-05\n", "Epoch 9/20\n", "\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 2ms/step - loss: 3.1035e-05 - val_loss: 3.0818e-05\n", "Epoch 10/20\n", "\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 2ms/step - loss: 3.0654e-05 - val_loss: 3.0667e-05\n", "Epoch 11/20\n", "\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 2ms/step - loss: 3.0462e-05 - val_loss: 3.0639e-05\n", "Epoch 12/20\n", "\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m 
\u001b[1m14s\u001b[0m 2ms/step - loss: 3.0563e-05 - val_loss: 3.0643e-05\n", "Epoch 13/20\n", "\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m14s\u001b[0m 2ms/step - loss: 3.1567e-05 - val_loss: 3.0610e-05\n", "Epoch 14/20\n", "\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 2ms/step - loss: 3.0625e-05 - val_loss: 3.0598e-05\n", "Epoch 15/20\n", "\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 2ms/step - loss: 3.0962e-05 - val_loss: 3.0589e-05\n", "Epoch 16/20\n", "\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 2ms/step - loss: 3.1131e-05 - val_loss: 3.0598e-05\n", "Epoch 17/20\n", "\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 2ms/step - loss: 3.0939e-05 - val_loss: 3.0580e-05\n", "Epoch 18/20\n", "\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 2ms/step - loss: 3.1100e-05 - val_loss: 3.0580e-05\n", "Epoch 19/20\n", "\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 2ms/step - loss: 3.0586e-05 - val_loss: 3.0579e-05\n", "Epoch 20/20\n", "\u001b[1m7823/7823\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 2ms/step - loss: 3.0749e-05 - val_loss: 3.0576e-05\n", "Training took 295.5790858268738 seconds\n" ] } ], "source": [ "# Train the small model with the batch size and epoch budget from the parameter cell.\n", "# Early stopping watches the validation loss (the training loss cannot reveal overfitting)\n", "# and restores the weights of the best epoch.\n", "callback = keras.callbacks.EarlyStopping(monitor=\"val_loss\", patience=3, restore_best_weights=True)\n", "\n", "# The \"Class\" column (presumably the reactive / non-reactive label from preprocessing) is dropped\n", "# from inputs and targets; errors=\"ignore\" keeps this a no-op if the column is absent.\n", "start = time.perf_counter()\n", "history = model_simple.fit(\n", "    X_train.drop(columns=\"Class\", errors=\"ignore\"),\n", "    y_train.drop(columns=\"Class\", errors=\"ignore\"),\n", "    batch_size=batch_size,\n", "    epochs=epochs,\n", "    validation_data=(\n", "        X_val.drop(columns=\"Class\", errors=\"ignore\"),\n", "        y_val.drop(columns=\"Class\", errors=\"ignore\"),\n", "    ),\n", "    callbacks=[callback],\n", ")\n", "end = time.perf_counter()\n", "\n", "print(\"Training took {} seconds\".format(end - start))" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 595us/step\n" ] }, { "data": { "image/png": "", "text/plain": [ "
| \n", " | H | \n", "O | \n", "Charge | \n", "H_0_ | \n", "O_0_ | \n", "Ba | \n", "Cl | \n", "S_2_ | \n", "S_6_ | \n", "Sr | \n", "Barite | \n", "Celestite | \n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", "111.012434 | \n", "55.508700 | \n", "-1.216415e-09 | \n", "0.0 | \n", "2.217711e-13 | \n", "4.495355e-07 | \n", "1.532249e-12 | \n", "0.0 | \n", "0.000621 | \n", "0.000620 | \n", "0.001000 | \n", "1.000000 | \n", "
| 1 | \n", "111.012434 | \n", "55.508700 | \n", "-1.222504e-09 | \n", "0.0 | \n", "1.312902e-12 | \n", "4.500359e-07 | \n", "1.044603e-08 | \n", "0.0 | \n", "0.000621 | \n", "0.000620 | \n", "0.001000 | \n", "1.000000 | \n", "
| 2 | \n", "111.012434 | \n", "55.508699 | \n", "-1.220407e-09 | \n", "0.0 | \n", "1.614098e-12 | \n", "4.500563e-07 | \n", "4.907802e-07 | \n", "0.0 | \n", "0.000621 | \n", "0.000620 | \n", "0.001000 | \n", "1.000000 | \n", "
| 3 | \n", "111.012434 | \n", "55.508695 | \n", "-1.216831e-09 | \n", "0.0 | \n", "2.293739e-12 | \n", "4.504482e-07 | \n", "4.772370e-06 | \n", "0.0 | \n", "0.000620 | \n", "0.000622 | \n", "0.001000 | \n", "1.000000 | \n", "
| 4 | \n", "111.012434 | \n", "55.508679 | \n", "-1.216842e-09 | \n", "0.0 | \n", "2.641545e-12 | \n", "4.534070e-07 | \n", "2.200220e-05 | \n", "0.0 | \n", "0.000616 | \n", "0.000626 | \n", "0.001000 | \n", "1.000000 | \n", "
| ... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
| 995 | \n", "111.012434 | \n", "55.506410 | \n", "2.170844e-08 | \n", "0.0 | \n", "4.777415e-11 | \n", "1.550089e-04 | \n", "9.784038e-02 | \n", "0.0 | \n", "0.000048 | \n", "0.048814 | \n", "0.149476 | \n", "0.846327 | \n", "
| 996 | \n", "111.012434 | \n", "55.506410 | \n", "2.166504e-08 | \n", "0.0 | \n", "4.808612e-11 | \n", "1.559012e-04 | \n", "9.785750e-02 | \n", "0.0 | \n", "0.000048 | \n", "0.048821 | \n", "0.151708 | \n", "0.844093 | \n", "
| 997 | \n", "111.012434 | \n", "55.506409 | \n", "2.162167e-08 | \n", "0.0 | \n", "4.811205e-11 | \n", "1.567226e-04 | \n", "9.787459e-02 | \n", "0.0 | \n", "0.000048 | \n", "0.048829 | \n", "0.153945 | \n", "0.841856 | \n", "
| 998 | \n", "111.012434 | \n", "55.506409 | \n", "2.157995e-08 | \n", "0.0 | \n", "4.815004e-11 | \n", "1.574812e-04 | \n", "9.789167e-02 | \n", "0.0 | \n", "0.000048 | \n", "0.048836 | \n", "0.156185 | \n", "0.839614 | \n", "
| 999 | \n", "111.012434 | \n", "55.506409 | \n", "2.153938e-08 | \n", "0.0 | \n", "4.815067e-11 | \n", "1.581835e-04 | \n", "9.790872e-02 | \n", "0.0 | \n", "0.000048 | \n", "0.048844 | \n", "0.158428 | \n", "0.837370 | \n", "
1000 rows × 12 columns
\n", "| \n", " | H | \n", "O | \n", "Charge | \n", "H_0_ | \n", "O_0_ | \n", "Ba | \n", "Cl | \n", "S_2_ | \n", "S_6_ | \n", "Sr | \n", "Barite | \n", "Celestite | \n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", "111.012428 | \n", "55.508682 | \n", "-1.193141e-09 | \n", "3.397941e-15 | \n", "2.128961e-13 | \n", "-0.000012 | \n", "0.000021 | \n", "-1.191799e-17 | \n", "0.000620 | \n", "0.000630 | \n", "0.000985 | \n", "1.000231 | \n", "
| 1 | \n", "111.012428 | \n", "55.508682 | \n", "-1.202185e-09 | \n", "2.918474e-15 | \n", "1.019832e-12 | \n", "-0.000016 | \n", "0.000008 | \n", "-7.920033e-18 | \n", "0.000620 | \n", "0.000630 | \n", "0.000946 | \n", "1.000121 | \n", "
| 2 | \n", "111.012428 | \n", "55.508682 | \n", "-1.203471e-09 | \n", "1.785468e-15 | \n", "2.398433e-12 | \n", "-0.000015 | \n", "-0.000013 | \n", "-9.158671e-18 | \n", "0.000620 | \n", "0.000627 | \n", "0.000913 | \n", "0.999977 | \n", "
| 3 | \n", "111.012428 | \n", "55.508682 | \n", "-1.199235e-09 | \n", "1.746077e-15 | \n", "2.357316e-12 | \n", "-0.000016 | \n", "-0.000011 | \n", "-9.246642e-18 | \n", "0.000619 | \n", "0.000629 | \n", "0.000916 | \n", "0.999936 | \n", "
| 4 | \n", "111.012428 | \n", "55.508682 | \n", "-1.197043e-09 | \n", "1.533956e-15 | \n", "2.670053e-12 | \n", "-0.000019 | \n", "0.000003 | \n", "-9.144569e-18 | \n", "0.000615 | \n", "0.000635 | \n", "0.000943 | \n", "0.999769 | \n", "
| ... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
| 995 | \n", "111.012428 | \n", "55.506416 | \n", "2.189995e-08 | \n", "-3.792428e-15 | \n", "4.785985e-11 | \n", "0.000279 | \n", "0.097642 | \n", "1.182626e-16 | \n", "0.000051 | \n", "0.048738 | \n", "0.149016 | \n", "0.844585 | \n", "
| 996 | \n", "111.012428 | \n", "55.506416 | \n", "2.188781e-08 | \n", "-3.910681e-15 | \n", "4.822730e-11 | \n", "0.000279 | \n", "0.097665 | \n", "1.169111e-16 | \n", "0.000051 | \n", "0.048751 | \n", "0.151241 | \n", "0.842272 | \n", "
| 997 | \n", "111.012428 | \n", "55.506416 | \n", "2.184875e-08 | \n", "-3.749360e-15 | \n", "4.824932e-11 | \n", "0.000279 | \n", "0.097687 | \n", "1.146696e-16 | \n", "0.000051 | \n", "0.048761 | \n", "0.153477 | \n", "0.840015 | \n", "
| 998 | \n", "111.012428 | \n", "55.506416 | \n", "2.180415e-08 | \n", "-3.500642e-15 | \n", "4.816323e-11 | \n", "0.000279 | \n", "0.097710 | \n", "1.119656e-16 | \n", "0.000051 | \n", "0.048770 | \n", "0.155720 | \n", "0.837773 | \n", "
| 999 | \n", "111.012428 | \n", "55.506416 | \n", "2.177183e-08 | \n", "-3.375572e-15 | \n", "4.822685e-11 | \n", "0.000279 | \n", "0.097732 | \n", "1.095921e-16 | \n", "0.000050 | \n", "0.048780 | \n", "0.157962 | \n", "0.835504 | \n", "
1000 rows × 12 columns
\n", "