{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## General Information" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook is used to train a simple neural network model to predict the chemistry in the barite benchmark (50x50 grid). The training data is stored in the repository using **git large file storage** and can be downloaded after the installation of git lfs using the `git lfs pull` command.\n", "\n", "It is then recommended to create a Python environment using miniconda. The necessary dependencies are contained in `environment.yml` and can be installed using `conda env create -f environment.yml`.\n", "\n", "The data set is divided into a design and result part and consists of the iterations of a reference simulation. The design part of the data set contains the chemical concentrations at time $t$ and the result part at time $t+1$, which are to be learned by the model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup Libraries" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-01-15 14:03:45.613137: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "E0000 00:00:1736946225.710416 5544 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "E0000 00:00:1736946225.737063 5544 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", "2025-01-15 14:03:45.952881: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical 
operations.\n", "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Running Keras in version 3.7.0\n" ] } ], "source": [
  "import keras\n",
  "print(\"Running Keras in version {}\".format(keras.__version__))\n",
  "\n",
  "import h5py\n",
  "import numpy as np\n",
  "import pandas as pd\n",
  "import time\n",
  "import sklearn.model_selection as sk\n",
  "import matplotlib.pyplot as plt\n",
  "from sklearn.cluster import KMeans\n",
  "from imblearn.over_sampling import SMOTE\n",
  "from imblearn.under_sampling import RandomUnderSampler\n",
  "from imblearn.over_sampling import RandomOverSampler\n",
  "from collections import Counter\n",
  "import os"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define parameters" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "I0000 00:00:1736946248.679800 5544 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13393 MB memory: -> device: 0, name: NVIDIA A2, pci bus id: 0000:82:00.0, compute capability: 8.6\n" ] } ], "source": [
  "# Seed the Python, NumPy and Keras-backend RNGs once, before any model is built,\n",
  "# so that a fresh Restart & Run All starts from the same random state.\n",
  "seed = 42\n",
  "keras.utils.set_random_seed(seed)\n",
  "\n",
  "# Numeric precision and hidden-layer activation shared by the models below.\n",
  "dtype = \"float32\"\n",
  "activation = \"relu\"\n",
  "\n",
  "# Training hyper-parameters.\n",
  "lr = 0.001\n",
  "batch_size = 512\n",
  "epochs = 50 # default 400 epochs\n",
  "\n",
  "# Step-wise decay: the learning rate is multiplied by 0.9 every 2000 optimizer steps.\n",
  "lr_schedule = keras.optimizers.schedules.ExponentialDecay(\n",
  "    initial_learning_rate=lr,\n",
  "    decay_steps=2000,\n",
  "    decay_rate=0.9,\n",
  "    staircase=True\n",
  ")\n",
  "\n",
  "optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)\n",
  "loss = keras.losses.Huber()\n",
  "\n",
  "sample_fraction = 0.8"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup the model" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Model: \"sequential\"\n",
"\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense (Dense) │ (None, 128) │ 1,664 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_1 (Dense) │ (None, 128) │ 16,512 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_2 (Dense) │ (None, 12) │ 1,548 │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m1,664\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m16,512\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_2 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m1,548\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Total params: 19,724 (77.05 KB)\n", "\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m19,724\u001b[0m (77.05 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Trainable params: 19,724 (77.05 KB)\n", "\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m19,724\u001b[0m (77.05 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Non-trainable params: 0 (0.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
  "# Small baseline network: 12 input features -> 128 -> 128 -> 12 outputs.\n",
  "# `dtype` and `activation` come from the parameter cell so both models stay in sync.\n",
  "model_simple = keras.Sequential(\n",
  "    [\n",
  "        keras.Input(shape=(12,), dtype=dtype),\n",
  "        keras.layers.Dense(units=128, activation=activation, dtype=dtype),\n",
  "        keras.layers.Dense(units=128, activation=activation, dtype=dtype),\n",
  "        keras.layers.Dense(units=12, dtype=dtype),\n",
  "    ]\n",
  ")\n",
  "\n",
  "# A Keras 3 optimizer is bound to the variables of the first model it trains,\n",
  "# so this model gets its own Adam instance instead of the shared `optimizer`.\n",
  "# The learning-rate schedule itself is stateless and can be reused.\n",
  "model_simple.compile(\n",
  "    optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),\n",
  "    loss=loss,\n",
  ")\n",
  "model_simple.summary()"
 ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Model: \"sequential_1\"\n",
"\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential_1\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_3 (Dense) │ (None, 512) │ 6,656 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_4 (Dense) │ (None, 1024) │ 525,312 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_5 (Dense) │ (None, 512) │ 524,800 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_6 (Dense) │ (None, 12) │ 6,156 │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ dense_3 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m6,656\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_4 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1024\u001b[0m) │ \u001b[38;5;34m525,312\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_5 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m524,800\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ dense_6 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m6,156\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Total params: 1,062,924 (4.05 MB)\n", "\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,062,924\u001b[0m (4.05 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Trainable params: 1,062,924 (4.05 MB)\n", "\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,062,924\u001b[0m (4.05 MB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Non-trainable params: 0 (0.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
  "# Larger network: 12 input features -> 512 -> 1024 -> 512 -> 12 outputs.\n",
  "model_large = keras.Sequential(\n",
  "    [\n",
  "        keras.layers.Input(shape=(12,), dtype=dtype),\n",
  "        keras.layers.Dense(512, activation=activation, dtype=dtype),\n",
  "        keras.layers.Dense(1024, activation=activation, dtype=dtype),\n",
  "        keras.layers.Dense(512, activation=activation, dtype=dtype),\n",
  "        keras.layers.Dense(12, dtype=dtype),\n",
  "    ]\n",
  ")\n",
  "\n",
  "# A Keras 3 optimizer is bound to the variables of the first model it trains,\n",
  "# so this model gets its own Adam instance instead of the shared `optimizer`.\n",
  "# The learning-rate schedule itself is stateless and can be reused.\n",
  "model_large.compile(\n",
  "    optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),\n",
  "    loss=loss,\n",
  ")\n",
  "model_large.summary()\n"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define some functions and helper classes" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [
  "def Safelog(val):\n",
  "    \"\"\"Signed base-10 logarithm.\n",
  "\n",
  "    Returns ``log10(val)`` for ``val > 0``, ``-log10(-val)`` for ``val < 0``\n",
  "    and ``0`` otherwise. Scalars are handled exactly as before; array-likes\n",
  "    are transformed element-wise and returned as a float ``numpy.ndarray``.\n",
  "\n",
  "    The mapping can only be undone by ``Safeexp`` for ``|val| < 1``\n",
  "    (e.g. ``0.01`` and ``-100`` both map to ``-2``).\n",
  "    \"\"\"\n",
  "    if np.ndim(val) == 0:\n",
  "        if val > 0:\n",
  "            return np.log10(val)\n",
  "        elif val < 0:\n",
  "            return -np.log10(-val)\n",
  "        else:\n",
  "            return 0\n",
  "\n",
  "    # Array path: boolean masks avoid the ambiguous truth value of `arr > 0`.\n",
  "    arr = np.asarray(val, dtype=float)\n",
  "    out = np.zeros_like(arr)\n",
  "    pos = arr > 0\n",
  "    neg = arr < 0\n",
  "    out[pos] = np.log10(arr[pos])\n",
  "    out[neg] = -np.log10(-arr[neg])\n",
  "    return out\n",
  "\n",
  "def Safeexp(val):\n",
  "    \"\"\"Inverse of ``Safelog`` for original values with magnitude below one.\n",
  "\n",
  "    Returns ``-(10 ** -val)`` for ``val > 0``, ``10 ** val`` for ``val < 0``\n",
  "    and ``0`` otherwise. Scalars are handled exactly as before; array-likes\n",
  "    are transformed element-wise and returned as a float ``numpy.ndarray``.\n",
  "    \"\"\"\n",
  "    if np.ndim(val) == 0:\n",
  "        if val > 0:\n",
  "            return -10 ** -val\n",
  "        elif val < 0:\n",
  "            return 10 ** val\n",
  "        else:\n",
  "            return 0\n",
  "\n",
  "    arr = np.asarray(val, dtype=float)\n",
  "    out = np.zeros_like(arr)\n",
  "    pos = arr > 0\n",
  "    neg = arr < 0\n",
  "    out[pos] = -(10.0 ** -arr[pos])\n",
  "    out[neg] = 10.0 ** arr[neg]\n",
  "    return out\n"
 ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# ? 
Why does the charge use a different logarithm than the other species?\n",
  "# NOTE(review): presumably because Charge is signed (see Safelog) - confirm.\n",
  "\n",
  "# Forward transform applied to each column before training.\n",
  "func_dict_in = {\n",
  "    \"H\" : np.log1p,\n",
  "    \"O\" : np.log1p,\n",
  "    \"Charge\" : Safelog,\n",
  "    \"H_0_\" : np.log1p,\n",
  "    \"O_0_\" : np.log1p,\n",
  "    \"Ba\" : np.log1p,\n",
  "    \"Cl\" : np.log1p,\n",
  "    \"S_2_\" : np.log1p,\n",
  "    \"S_6_\" : np.log1p,\n",
  "    \"Sr\" : np.log1p,\n",
  "    \"Barite\" : np.log1p,\n",
  "    \"Celestite\" : np.log1p,\n",
  "}\n",
  "\n",
  "# Matching inverse transform for each column (expm1 undoes log1p, Safeexp undoes Safelog).\n",
  "func_dict_out = {\n",
  "    \"H\" : np.expm1,\n",
  "    \"O\" : np.expm1,\n",
  "    \"Charge\" : Safeexp,\n",
  "    \"H_0_\" : np.expm1,\n",
  "    \"O_0_\" : np.expm1,\n",
  "    \"Ba\" : np.expm1,\n",
  "    \"Cl\" : np.expm1,\n",
  "    \"S_2_\" : np.expm1,\n",
  "    \"S_6_\" : np.expm1,\n",
  "    \"Sr\" : np.expm1,\n",
  "    \"Barite\" : np.expm1,\n",
  "    \"Celestite\" : np.expm1,\n",
  "}\n"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Read data from `.h5` file and convert it to a `pandas.DataFrame`" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [
  "# NOTE(review): hard-coded absolute working directory; kept because later cells\n",
  "# may rely on relative paths - consider a configurable data directory instead.\n",
  "os.chdir('/mnt/beegfs/home/signer/projects/model-training')\n",
  "\n",
  "# The context manager closes the file even if building a DataFrame raises.\n",
  "with h5py.File(\"Barite_50_Data_training.h5\", \"r\") as data_file:\n",
  "    design = data_file[\"design\"]\n",
  "    results = data_file[\"result\"]\n",
  "\n",
  "    # The stored data is transposed so that each name becomes one column.\n",
  "    df_design = pd.DataFrame(np.array(design[\"data\"]).transpose(), columns = design[\"names\"].asstr())\n",
  "    df_results = pd.DataFrame(np.array(results[\"data\"]).transpose(), columns = results[\"names\"].asstr())"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Classify each cell with kmeans" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/mnt/beegfs/home/signer/.conda/envs/ai/lib/python3.11/site-packages/sklearn/base.py:1473: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (2). 
Possibly due to duplicate points in X.\n", " return fit_method(estimator, *args, **kwargs)\n" ] } ], "source": [
  "# Label every grid cell of every time step with one of two k-means clusters,\n",
  "# based on its Barite value, separately for the design and the result data.\n",
  "# NOTE(review): k-means label ids are arbitrary per fit, so label 1 in one time\n",
  "# step is not guaranteed to mean the same cluster in another - confirm.\n",
  "GRID_SIZE = 50\n",
  "CELLS_PER_STEP = GRID_SIZE * GRID_SIZE\n",
  "\n",
  "assert len(df_design) % CELLS_PER_STEP == 0, \"row count is not a whole number of grids\"\n",
  "n_steps = len(df_design) // CELLS_PER_STEP\n",
  "\n",
  "# Collect per-step labels in lists and concatenate once; growing an array with\n",
  "# np.append inside the loop copies all previous labels on every iteration.\n",
  "labels_design = []\n",
  "labels_result = []\n",
  "\n",
  "for i in range(n_steps):\n",
  "    rows = slice(i * CELLS_PER_STEP, (i + 1) * CELLS_PER_STEP)\n",
  "    field_design = np.array(df_design['Barite'].iloc[rows]).reshape(GRID_SIZE, GRID_SIZE)\n",
  "    field_result = np.array(df_results['Barite'].iloc[rows]).reshape(GRID_SIZE, GRID_SIZE)\n",
  "\n",
  "    # The recorded output shows a ConvergenceWarning here when a field holds\n",
  "    # only one distinct value (fewer clusters found than requested).\n",
  "    kmeans_design = KMeans(n_clusters=2, random_state=0).fit(field_design.reshape(-1, 1))\n",
  "    kmeans_result = KMeans(n_clusters=2, random_state=0).fit(field_result.reshape(-1, 1))\n",
  "\n",
  "    labels_design.append(kmeans_design.labels_)\n",
  "    labels_result.append(kmeans_result.labels_)\n",
  "\n",
  "class_label_design = pd.DataFrame(np.concatenate(labels_design).astype(int), columns = [\"Class\"])\n",
  "class_label_result = pd.DataFrame(np.concatenate(labels_result).astype(int), columns = [\"Class\"])\n"
 ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [
  "# Guard so that re-running this cell does not append a second Class column.\n",
  "if(\"Class\" in df_design.columns and \"Class\" in df_results.columns):\n",
  "    print(\"Class column already exists\")\n",
  "else:\n",
  "    df_design = pd.concat([df_design, class_label_design], axis=1)\n",
  "    # NOTE(review): df_results also receives class_label_design, so class_label_result\n",
  "    # is unused here - confirm this is intended and not a copy-paste slip.\n",
  "    df_results = pd.concat([df_results, class_label_design], axis=1)"
 ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Amount class 0: 0.9879380619380619\n", "Amount class 1: 0.012061938061938062\n" ] } ], "source": [
  "# Fraction of rows per class, read from the last column (the Class label).\n",
  "counter = Counter(df_design.iloc[:,-1])\n",
  "print(\"Amount class 0:\", counter[0] / (counter[0] + counter[1]) )\n",
  "print(\"Amount class 1:\", counter[1] / (counter[0] + counter[1]) )\n"
 ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "
| \n", " | H | \n", "O | \n", "Charge | \n", "H_0_ | \n", "O_0_ | \n", "Ba | \n", "Cl | \n", "S_2_ | \n", "S_6_ | \n", "Sr | \n", "Barite | \n", "Celestite | \n", "Class | \n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", "111.012434 | \n", "55.508192 | \n", "-7.779554e-09 | \n", "2.697041e-26 | \n", "2.210590e-15 | \n", "2.041069e-02 | \n", "4.082138e-02 | \n", "0.000000e+00 | \n", "0.000494 | \n", "0.000494 | \n", "0.001 | \n", "1.000000 | \n", "0 | \n", "
| 1 | \n", "111.012434 | \n", "55.508427 | \n", "-4.736083e-09 | \n", "1.446346e-26 | \n", "2.473481e-15 | \n", "1.094567e-02 | \n", "2.189133e-02 | \n", "0.000000e+00 | \n", "0.000553 | \n", "0.000553 | \n", "0.001 | \n", "1.000000 | \n", "0 | \n", "
| 2 | \n", "111.012434 | \n", "55.508691 | \n", "-1.311169e-09 | \n", "3.889826e-28 | \n", "2.769320e-15 | \n", "2.943745e-04 | \n", "5.887491e-04 | \n", "0.000000e+00 | \n", "0.000619 | \n", "0.000619 | \n", "0.001 | \n", "1.000000 | \n", "0 | \n", "
| 3 | \n", "111.012434 | \n", "55.508698 | \n", "-1.220023e-09 | \n", "1.442658e-29 | \n", "2.777193e-15 | \n", "1.091776e-05 | \n", "2.183551e-05 | \n", "0.000000e+00 | \n", "0.000620 | \n", "0.000620 | \n", "0.001 | \n", "1.000000 | \n", "0 | \n", "
| 4 | \n", "111.012434 | \n", "55.508699 | \n", "-1.216643e-09 | \n", "5.350528e-31 | \n", "2.777485e-15 | \n", "4.049176e-07 | \n", "8.098352e-07 | \n", "0.000000e+00 | \n", "0.000620 | \n", "0.000620 | \n", "0.001 | \n", "1.000000 | \n", "0 | \n", "
| ... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
| 2502495 | \n", "111.012434 | \n", "55.507488 | \n", "3.573728e-09 | \n", "5.424062e-145 | \n", "1.375204e-10 | \n", "9.953520e-07 | \n", "2.266555e-03 | \n", "5.509534e-149 | \n", "0.000318 | \n", "0.001450 | \n", "0.001 | \n", "1.000014 | \n", "0 | \n", "
| 2502496 | \n", "111.012434 | \n", "55.507501 | \n", "3.494007e-09 | \n", "2.011675e-146 | \n", "1.377139e-10 | \n", "9.817216e-07 | \n", "2.217997e-03 | \n", "2.043375e-150 | \n", "0.000321 | \n", "0.001429 | \n", "0.001 | \n", "1.000010 | \n", "0 | \n", "
| 2502497 | \n", "111.012434 | \n", "55.507512 | \n", "3.429764e-09 | \n", "7.460897e-148 | \n", "1.377819e-10 | \n", "9.706451e-07 | \n", "2.179066e-03 | \n", "7.578467e-152 | \n", "0.000324 | \n", "0.001412 | \n", "0.001 | \n", "1.000006 | \n", "0 | \n", "
| 2502498 | \n", "111.012434 | \n", "55.507520 | \n", "3.381745e-09 | \n", "2.767237e-149 | \n", "1.371144e-10 | \n", "9.621074e-07 | \n", "2.149820e-03 | \n", "2.810844e-153 | \n", "0.000326 | \n", "0.001400 | \n", "0.001 | \n", "1.000004 | \n", "0 | \n", "
| 2502499 | \n", "111.012434 | \n", "55.507525 | \n", "3.348864e-09 | \n", "5.321610e-151 | \n", "1.376026e-10 | \n", "9.564401e-07 | \n", "2.129912e-03 | \n", "5.405468e-155 | \n", "0.000327 | \n", "0.001391 | \n", "0.001 | \n", "1.000001 | \n", "0 | \n", "
2502500 rows × 13 columns
\n", "