add script for experiments

This commit is contained in:
Hannes Signer 2025-03-28 16:01:42 +01:00
parent 04f5c40b29
commit 4f954cbc84
3 changed files with 826 additions and 28 deletions

View File

@ -52,9 +52,21 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"ename": "FileNotFoundError",
"evalue": "[Errno 2] No such file or directory: './projects/model-training/src/'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mos\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m os\u001b[38;5;241m.\u001b[39mchdir(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m./projects/model-training/src/\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: './projects/model-training/src/'"
]
}
],
"source": [
"import os\n",
"os.chdir(\"./projects/model-training/src/\")"
@ -69,12 +81,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
"2025-03-28 14:24:13.743271: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
"E0000 00:00:1743168253.897216 16215 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"E0000 00:00:1743168253.941693 16215 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2025-03-28 14:24:14.439356: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
"2025-03-28 14:41:07.244286: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
"2025-03-28 14:41:07.369762: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
]
}
],
@ -271,7 +280,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 8,
"metadata": {},
"outputs": [
{
@ -281,7 +290,7 @@
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[6], line 12\u001b[0m\n\u001b[1;32m 8\u001b[0m ax3 \u001b[38;5;241m=\u001b[39m axes[\u001b[38;5;241m2\u001b[39m, i]\n\u001b[1;32m 9\u001b[0m ax4 \u001b[38;5;241m=\u001b[39m axes[\u001b[38;5;241m3\u001b[39m, i]\n\u001b[1;32m 11\u001b[0m im1 \u001b[38;5;241m=\u001b[39m ax1\u001b[38;5;241m.\u001b[39mimshow(\n\u001b[0;32m---> 12\u001b[0m np\u001b[38;5;241m.\u001b[39marray((X_manual[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCl\u001b[39m\u001b[38;5;124m\"\u001b[39m])[(timestep \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m2500\u001b[39m) : (timestep \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m2500\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m2500\u001b[39m)])\u001b[38;5;241m.\u001b[39mreshape(\n\u001b[1;32m 13\u001b[0m \u001b[38;5;241m50\u001b[39m, \u001b[38;5;241m50\u001b[39m\n\u001b[1;32m 14\u001b[0m ),\n\u001b[1;32m 15\u001b[0m origin\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlower\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 16\u001b[0m cmap\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mviridis\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 17\u001b[0m )\n\u001b[1;32m 19\u001b[0m ax1\u001b[38;5;241m.\u001b[39mset_title(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIteration \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtimestep\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 20\u001b[0m ax1\u001b[38;5;241m.\u001b[39maxis(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moff\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"Cell \u001b[0;32mIn[8], line 12\u001b[0m\n\u001b[1;32m 8\u001b[0m ax3 \u001b[38;5;241m=\u001b[39m axes[\u001b[38;5;241m2\u001b[39m, i]\n\u001b[1;32m 9\u001b[0m ax4 \u001b[38;5;241m=\u001b[39m axes[\u001b[38;5;241m3\u001b[39m, i]\n\u001b[1;32m 11\u001b[0m im1 \u001b[38;5;241m=\u001b[39m ax1\u001b[38;5;241m.\u001b[39mimshow(\n\u001b[0;32m---> 12\u001b[0m np\u001b[38;5;241m.\u001b[39marray((X_manual[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCl\u001b[39m\u001b[38;5;124m\"\u001b[39m])[(timestep \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m2500\u001b[39m) : (timestep \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m2500\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m2500\u001b[39m)])\u001b[38;5;241m.\u001b[39mreshape(\n\u001b[1;32m 13\u001b[0m \u001b[38;5;241m50\u001b[39m, \u001b[38;5;241m50\u001b[39m\n\u001b[1;32m 14\u001b[0m ),\n\u001b[1;32m 15\u001b[0m origin\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlower\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 16\u001b[0m cmap\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mviridis\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 17\u001b[0m )\n\u001b[1;32m 19\u001b[0m ax1\u001b[38;5;241m.\u001b[39mset_title(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIteration \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtimestep\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 20\u001b[0m ax1\u001b[38;5;241m.\u001b[39maxis(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moff\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"\u001b[0;31mNameError\u001b[0m: name 'X_manual' is not defined"
]
},
@ -371,7 +380,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@ -389,7 +398,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@ -417,17 +426,9 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"I0000 00:00:1743168379.083212 16215 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13393 MB memory: -> device: 0, name: NVIDIA A2, pci bus id: 0000:82:00.0, compute capability: 8.6\n"
]
}
],
"outputs": [],
"source": [
"# select model architecture\n",
"model = model_definition(\"large\", len(df_design.columns), len(df_results.columns)) \n",
@ -441,6 +442,7 @@
"h1 = 0.16726490480995826\n",
"h2 = 0.5283208497548787\n",
"h3 = 0.5099528144902471\n",
"h4 = h3\n",
"\n",
"delta = 1.7642791340966357\n",
"\n",
@ -462,6 +464,645 @@
")"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"X_scaled = preprocess.scaler_input.transform(X.loc[:, X.columns != \"Class\"])"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"data_range = tf.convert_to_tensor(\n",
" preprocess.scaler_input.data_range_, dtype=tf.float32)\n",
"min_values = tf.convert_to_tensor(\n",
" preprocess.scaler_input.data_min_, dtype=tf.float32)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>H</th>\n",
" <th>O</th>\n",
" <th>Ba</th>\n",
" <th>Cl</th>\n",
" <th>S</th>\n",
" <th>Sr</th>\n",
" <th>Barite</th>\n",
" <th>Celestite</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>111.012436</td>\n",
" <td>55.508193</td>\n",
" <td>2.041069e-02</td>\n",
" <td>4.082138e-02</td>\n",
" <td>4.938299e-04</td>\n",
" <td>0.000494</td>\n",
" <td>0.000000</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>111.012436</td>\n",
" <td>55.508428</td>\n",
" <td>1.094567e-02</td>\n",
" <td>2.189133e-02</td>\n",
" <td>5.525578e-04</td>\n",
" <td>0.000553</td>\n",
" <td>0.000000</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>111.012436</td>\n",
" <td>55.508692</td>\n",
" <td>2.943745e-04</td>\n",
" <td>5.887491e-04</td>\n",
" <td>6.186462e-04</td>\n",
" <td>0.000619</td>\n",
" <td>0.000000</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>111.012436</td>\n",
" <td>55.508699</td>\n",
" <td>1.091776e-05</td>\n",
" <td>2.183552e-05</td>\n",
" <td>6.204049e-04</td>\n",
" <td>0.000620</td>\n",
" <td>0.000000</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>111.012436</td>\n",
" <td>55.508699</td>\n",
" <td>4.049176e-07</td>\n",
" <td>8.098352e-07</td>\n",
" <td>6.204702e-04</td>\n",
" <td>0.000620</td>\n",
" <td>0.000000</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1509995</th>\n",
" <td>111.012436</td>\n",
" <td>55.506218</td>\n",
" <td>3.829904e-02</td>\n",
" <td>1.505240e-01</td>\n",
" <td>1.480414e-07</td>\n",
" <td>0.036963</td>\n",
" <td>1.003059</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1509996</th>\n",
" <td>111.012436</td>\n",
" <td>55.506218</td>\n",
" <td>5.143522e-02</td>\n",
" <td>1.609874e-01</td>\n",
" <td>1.241651e-07</td>\n",
" <td>0.029059</td>\n",
" <td>1.005189</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1509997</th>\n",
" <td>111.012436</td>\n",
" <td>55.506218</td>\n",
" <td>6.735992e-02</td>\n",
" <td>1.737510e-01</td>\n",
" <td>1.075194e-07</td>\n",
" <td>0.019516</td>\n",
" <td>1.002099</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1509998</th>\n",
" <td>111.012436</td>\n",
" <td>55.506218</td>\n",
" <td>8.775700e-02</td>\n",
" <td>1.901489e-01</td>\n",
" <td>8.510718e-08</td>\n",
" <td>0.007318</td>\n",
" <td>1.000628</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1509999</th>\n",
" <td>111.012436</td>\n",
" <td>55.506218</td>\n",
" <td>9.695043e-02</td>\n",
" <td>1.975455e-01</td>\n",
" <td>7.271897e-08</td>\n",
" <td>0.001822</td>\n",
" <td>1.000276</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1510000 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" H O Ba Cl S \\\n",
"0 111.012436 55.508193 2.041069e-02 4.082138e-02 4.938299e-04 \n",
"1 111.012436 55.508428 1.094567e-02 2.189133e-02 5.525578e-04 \n",
"2 111.012436 55.508692 2.943745e-04 5.887491e-04 6.186462e-04 \n",
"3 111.012436 55.508699 1.091776e-05 2.183552e-05 6.204049e-04 \n",
"4 111.012436 55.508699 4.049176e-07 8.098352e-07 6.204702e-04 \n",
"... ... ... ... ... ... \n",
"1509995 111.012436 55.506218 3.829904e-02 1.505240e-01 1.480414e-07 \n",
"1509996 111.012436 55.506218 5.143522e-02 1.609874e-01 1.241651e-07 \n",
"1509997 111.012436 55.506218 6.735992e-02 1.737510e-01 1.075194e-07 \n",
"1509998 111.012436 55.506218 8.775700e-02 1.901489e-01 8.510718e-08 \n",
"1509999 111.012436 55.506218 9.695043e-02 1.975455e-01 7.271897e-08 \n",
"\n",
" Sr Barite Celestite \n",
"0 0.000494 0.000000 1.0 \n",
"1 0.000553 0.000000 1.0 \n",
"2 0.000619 0.000000 1.0 \n",
"3 0.000620 0.000000 1.0 \n",
"4 0.000620 0.000000 1.0 \n",
"... ... ... ... \n",
"1509995 0.036963 1.003059 0.0 \n",
"1509996 0.029059 1.005189 0.0 \n",
"1509997 0.019516 1.002099 0.0 \n",
"1509998 0.007318 1.000628 0.0 \n",
"1509999 0.001822 1.000276 0.0 \n",
"\n",
"[1510000 rows x 8 columns]"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_scaled * data_range + min_values"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>H</th>\n",
" <th>O</th>\n",
" <th>Ba</th>\n",
" <th>Cl</th>\n",
" <th>S</th>\n",
" <th>Sr</th>\n",
" <th>Barite</th>\n",
" <th>Celestite</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>111.012434</td>\n",
" <td>55.508192</td>\n",
" <td>2.041069e-02</td>\n",
" <td>4.082138e-02</td>\n",
" <td>4.938300e-04</td>\n",
" <td>0.000494</td>\n",
" <td>0.000000</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>111.012434</td>\n",
" <td>55.508427</td>\n",
" <td>1.094567e-02</td>\n",
" <td>2.189133e-02</td>\n",
" <td>5.525578e-04</td>\n",
" <td>0.000553</td>\n",
" <td>0.000000</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>111.012434</td>\n",
" <td>55.508691</td>\n",
" <td>2.943745e-04</td>\n",
" <td>5.887491e-04</td>\n",
" <td>6.186462e-04</td>\n",
" <td>0.000619</td>\n",
" <td>0.000000</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>111.012434</td>\n",
" <td>55.508698</td>\n",
" <td>1.091776e-05</td>\n",
" <td>2.183551e-05</td>\n",
" <td>6.204050e-04</td>\n",
" <td>0.000620</td>\n",
" <td>0.000000</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>111.012434</td>\n",
" <td>55.508699</td>\n",
" <td>4.049176e-07</td>\n",
" <td>8.098352e-07</td>\n",
" <td>6.204702e-04</td>\n",
" <td>0.000620</td>\n",
" <td>0.000000</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1509995</th>\n",
" <td>111.012434</td>\n",
" <td>55.506217</td>\n",
" <td>3.829904e-02</td>\n",
" <td>1.505240e-01</td>\n",
" <td>1.480414e-07</td>\n",
" <td>0.036963</td>\n",
" <td>1.003059</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1509996</th>\n",
" <td>111.012434</td>\n",
" <td>55.506217</td>\n",
" <td>5.143522e-02</td>\n",
" <td>1.609874e-01</td>\n",
" <td>1.241651e-07</td>\n",
" <td>0.029059</td>\n",
" <td>1.005189</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1509997</th>\n",
" <td>111.012434</td>\n",
" <td>55.506217</td>\n",
" <td>6.735993e-02</td>\n",
" <td>1.737510e-01</td>\n",
" <td>1.075194e-07</td>\n",
" <td>0.019516</td>\n",
" <td>1.002099</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1509998</th>\n",
" <td>111.012434</td>\n",
" <td>55.506217</td>\n",
" <td>8.775701e-02</td>\n",
" <td>1.901489e-01</td>\n",
" <td>8.510718e-08</td>\n",
" <td>0.007318</td>\n",
" <td>1.000628</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1509999</th>\n",
" <td>111.012434</td>\n",
" <td>55.506217</td>\n",
" <td>9.695043e-02</td>\n",
" <td>1.975455e-01</td>\n",
" <td>7.271898e-08</td>\n",
" <td>0.001822</td>\n",
" <td>1.000276</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1510000 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" H O Ba Cl S \\\n",
"0 111.012434 55.508192 2.041069e-02 4.082138e-02 4.938300e-04 \n",
"1 111.012434 55.508427 1.094567e-02 2.189133e-02 5.525578e-04 \n",
"2 111.012434 55.508691 2.943745e-04 5.887491e-04 6.186462e-04 \n",
"3 111.012434 55.508698 1.091776e-05 2.183551e-05 6.204050e-04 \n",
"4 111.012434 55.508699 4.049176e-07 8.098352e-07 6.204702e-04 \n",
"... ... ... ... ... ... \n",
"1509995 111.012434 55.506217 3.829904e-02 1.505240e-01 1.480414e-07 \n",
"1509996 111.012434 55.506217 5.143522e-02 1.609874e-01 1.241651e-07 \n",
"1509997 111.012434 55.506217 6.735993e-02 1.737510e-01 1.075194e-07 \n",
"1509998 111.012434 55.506217 8.775701e-02 1.901489e-01 8.510718e-08 \n",
"1509999 111.012434 55.506217 9.695043e-02 1.975455e-01 7.271898e-08 \n",
"\n",
" Sr Barite Celestite \n",
"0 0.000494 0.000000 1.0 \n",
"1 0.000553 0.000000 1.0 \n",
"2 0.000619 0.000000 1.0 \n",
"3 0.000620 0.000000 1.0 \n",
"4 0.000620 0.000000 1.0 \n",
"... ... ... ... \n",
"1509995 0.036963 1.003059 0.0 \n",
"1509996 0.029059 1.005189 0.0 \n",
"1509997 0.019516 1.002099 0.0 \n",
"1509998 0.007318 1.000628 0.0 \n",
"1509999 0.001822 1.000276 0.0 \n",
"\n",
"[1510000 rows x 8 columns]"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.DataFrame(preprocess.scaler_input.inverse_transform(X_scaled), columns=X_scaled.columns)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>H</th>\n",
" <th>O</th>\n",
" <th>Ba</th>\n",
" <th>Cl</th>\n",
" <th>S</th>\n",
" <th>Sr</th>\n",
" <th>Barite</th>\n",
" <th>Celestite</th>\n",
" <th>Class</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>111.012434</td>\n",
" <td>55.506561</td>\n",
" <td>0.000025</td>\n",
" <td>0.070397</td>\n",
" <td>0.000086</td>\n",
" <td>0.035259</td>\n",
" <td>2.732797e-06</td>\n",
" <td>1.000767</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>111.012434</td>\n",
" <td>55.506591</td>\n",
" <td>0.000017</td>\n",
" <td>0.048181</td>\n",
" <td>0.000093</td>\n",
" <td>0.024167</td>\n",
" <td>4.615323e-07</td>\n",
" <td>1.000621</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>111.012434</td>\n",
" <td>55.506539</td>\n",
" <td>0.000036</td>\n",
" <td>0.097734</td>\n",
" <td>0.000080</td>\n",
" <td>0.048912</td>\n",
" <td>1.387101e-04</td>\n",
" <td>1.000649</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>111.012434</td>\n",
" <td>55.506551</td>\n",
" <td>0.000029</td>\n",
" <td>0.080614</td>\n",
" <td>0.000084</td>\n",
" <td>0.040362</td>\n",
" <td>4.606226e-07</td>\n",
" <td>1.000513</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>111.012434</td>\n",
" <td>55.506650</td>\n",
" <td>0.000009</td>\n",
" <td>0.027535</td>\n",
" <td>0.000108</td>\n",
" <td>0.013866</td>\n",
" <td>2.198323e-08</td>\n",
" <td>1.000203</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1071013</th>\n",
" <td>111.012434</td>\n",
" <td>55.506228</td>\n",
" <td>0.011569</td>\n",
" <td>0.122579</td>\n",
" <td>0.000003</td>\n",
" <td>0.049723</td>\n",
" <td>1.003593e+00</td>\n",
" <td>0.000000</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1071014</th>\n",
" <td>111.012434</td>\n",
" <td>55.506625</td>\n",
" <td>0.000012</td>\n",
" <td>0.033830</td>\n",
" <td>0.000102</td>\n",
" <td>0.017005</td>\n",
" <td>3.780744e-07</td>\n",
" <td>1.000677</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1071015</th>\n",
" <td>111.012434</td>\n",
" <td>55.506563</td>\n",
" <td>0.000024</td>\n",
" <td>0.068341</td>\n",
" <td>0.000086</td>\n",
" <td>0.034233</td>\n",
" <td>4.692203e-07</td>\n",
" <td>1.000561</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1071016</th>\n",
" <td>111.012434</td>\n",
" <td>55.506569</td>\n",
" <td>0.000022</td>\n",
" <td>0.062430</td>\n",
" <td>0.000088</td>\n",
" <td>0.031281</td>\n",
" <td>4.191733e-08</td>\n",
" <td>1.000090</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1071017</th>\n",
" <td>111.012434</td>\n",
" <td>55.506734</td>\n",
" <td>0.000005</td>\n",
" <td>0.015954</td>\n",
" <td>0.000129</td>\n",
" <td>0.008101</td>\n",
" <td>0.000000e+00</td>\n",
" <td>1.000410</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1071018 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" H O Ba Cl S Sr \\\n",
"0 111.012434 55.506561 0.000025 0.070397 0.000086 0.035259 \n",
"1 111.012434 55.506591 0.000017 0.048181 0.000093 0.024167 \n",
"2 111.012434 55.506539 0.000036 0.097734 0.000080 0.048912 \n",
"3 111.012434 55.506551 0.000029 0.080614 0.000084 0.040362 \n",
"4 111.012434 55.506650 0.000009 0.027535 0.000108 0.013866 \n",
"... ... ... ... ... ... ... \n",
"1071013 111.012434 55.506228 0.011569 0.122579 0.000003 0.049723 \n",
"1071014 111.012434 55.506625 0.000012 0.033830 0.000102 0.017005 \n",
"1071015 111.012434 55.506563 0.000024 0.068341 0.000086 0.034233 \n",
"1071016 111.012434 55.506569 0.000022 0.062430 0.000088 0.031281 \n",
"1071017 111.012434 55.506734 0.000005 0.015954 0.000129 0.008101 \n",
"\n",
" Barite Celestite Class \n",
"0 2.732797e-06 1.000767 1.0 \n",
"1 4.615323e-07 1.000621 1.0 \n",
"2 1.387101e-04 1.000649 1.0 \n",
"3 4.606226e-07 1.000513 1.0 \n",
"4 2.198323e-08 1.000203 1.0 \n",
"... ... ... ... \n",
"1071013 1.003593e+00 0.000000 1.0 \n",
"1071014 3.780744e-07 1.000677 1.0 \n",
"1071015 4.692203e-07 1.000561 1.0 \n",
"1071016 4.191733e-08 1.000090 1.0 \n",
"1071017 0.000000e+00 1.000410 1.0 \n",
"\n",
"[1071018 rows x 9 columns]"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocess.scale_inverse(X_train)[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
@ -507,7 +1148,7 @@
}
],
"source": [
"history = model_training(model, epochs=500)"
"history = model_training(model, epochs=50)"
]
},
{
@ -957,7 +1598,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "training",
"language": "python",
"name": "python3"
},

View File

@ -1,9 +1,163 @@
from preprocessing import *
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as mpatches
import pickle
scaler_experiments = ["none", "minmax", "standard"]
###### Experimental parameters
scaler_type = "minmax"
feature_engineering = False
optimizer_type = "adam"
loss_variant = "huber_mass_balance"
for i in scaler_experiments:
###### load dataset
data_file = h5py.File("../datasets/Barite_4c_mdl.h5")
design = data_file["design"]
results = data_file["result"]
df_design = pd.DataFrame(
np.array(design["data"]).transpose(), columns=np.array(design["names"].asstr())
)
df_results = pd.DataFrame(
np.array(results["data"]).transpose(), columns=np.array(results["names"].asstr())
)
data_file.close()
df_design.drop("Charge", axis=1, inplace=True, errors="ignore")
df_results.drop("Charge", axis=1, inplace=True, errors="ignore")
###### preprocessing
if feature_engineering == True:
df_design["Ba\Sr"] = df_design["Ba"] / df_design["Sr"]
df_design["BaxS"] = df_design["Ba"] * df_design["S"]
preprocess = preprocessing()
X, y = preprocess.cluster_manual(df_design[df_design.columns], df_design[df_results.columns], "Cl")
X_train, X_test, y_train, y_test = preprocess.split(X, y, ratio=0.2)
X_train, y_train = preprocess.balancer(X_train, y_train, strategy="off")
# train only on reactive cells
X_train, y_train = preprocess.class_selection(X_train, y_train, class_label=1.0)
preprocess.scale_fit(X_train, y_train, type=scaler_type)
X_train, X_test, y_train, y_test = preprocess.scale_transform(
X_train, X_test, y_train, y_test
)
X_train, X_val, y_train, y_val = preprocess.split(X_train, y_train, ratio=0.1)
###### create and compile model
def model_training(model, batch_size=512, epochs=100):
start = time.time()
callback = keras.callbacks.EarlyStopping(monitor="loss", patience=30)
history = model.fit(
X_train.loc[:, X_train.columns != "Class"],
y_train.loc[:, y_train.columns != "Class"],
batch_size=batch_size,
epochs=epochs,
validation_data=(
X_val.loc[:, X_val.columns != "Class"],
y_val.loc[:, y_val.columns != "Class"],
),
callbacks=[callback],
)
end = time.time()
print("Training took {} seconds".format(end - start))
return history
# mapping of column names to column index
column_dict = {}
for i in df_results.columns:
column_dict[i] = y.columns.get_loc(i)
# select model architecture
model = model_definition("large", len(df_design.columns), len(df_results.columns))
# define learning rate adaptation
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
initial_learning_rate=0.01, decay_steps=2000, decay_rate=0.9, staircase=True
)
# hyperparameters that are determined by hyperparameter optimization
h1 = 0.16726490480995826
h2 = 0.5283208497548787
h3 = 0.5099528144902471
h4 = h3
delta = 1.7642791340966357
match optimizer_type:
case "adam":
optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)
case "sgd":
optimizer = keras.optimizers.SGD(learning_rate=lr_schedule)
case "rmsprop":
optimizer = keras.optimizers.RMSprop(learning_rate=lr_schedule)
model.compile(
optimizer=optimizer,
loss=custom_loss(preprocess, column_dict, h1, h2, h3, h4, scaler_type, loss_variant, 1),
metrics=[
huber_metric(delta),
mass_balance_metric(preprocess, column_dict, scaler_type, loss_variant),
],
)
###### train model
epochs = 200
history = model_training(model, epochs=epochs)
###### evaluate model
results = mass_balance_evaluation(model, X_test, preprocess)
mass_balance_ratio(results, threshold=1e-5)
def test_model(model, X_test, y_test):
X_test.reset_index(inplace=True, drop=True)
y_test.reset_index(inplace=True, drop=True)
all = model.evaluate(X_test.loc[:, X_test.columns != "Class"], y_test.loc[:, y_test.columns != "Class"])
class_0 = model.evaluate(X_test[X_test["Class"] == 0].iloc[:, X_test.columns != "Class"], y_test[X_test["Class"] == 0].iloc[:, y_test.columns != "Class"])
class_1 = model.evaluate(
X_test[X_test["Class"] == 1].iloc[:, :-1], y_test[X_test["Class"] == 1].iloc[:, :-1])
print("metric all data: ", all)
print("metric class 0: ", class_0)
print("metric class 1: ", class_1)
test_model(model, X_test, y_test)
###### save model and history
delimiter = "_"
idx_string = scaler_type + delimiter + "feature_engineering_" + feature_engineering + delimiter + optimizer_type + delimiter + loss_variant
file_name = "history_" + idx_string
with open('../results/'+file_name, 'wb') as file_pi:
pickle.dump(history.history, file_pi)
model.save_weights("../results/models/model_"+idx_string + ".weights.h5")

View File

@ -126,6 +126,7 @@ def custom_loss(
h1,
h2,
h3,
h4,
scaler_type="minmax",
loss_variant="huber",
delta=1.0,
@ -175,6 +176,8 @@ def custom_loss(
def loss(results, predicted):
# inverse min/max scaling
preprocess.scaler_input(results)
if scaler_type == "minmax":
predicted_inverse = predicted * data_range + min_values
results_inverse = results * data_range + min_values
@ -240,7 +243,7 @@ def custom_loss(
elif loss_variant == "huber_mass_balance":
total_loss = h1 * huber_loss + h2 * dBa + h3 * dSr
elif "huber_mass_balance_extended":
total_loss = h1 * huber_loss + h2 * dBa + h3 * dSr + h3 * dS
total_loss = h1 * huber_loss + h2 * dBa + h3 * dSr + h4 * dS
else:
raise Exception(
"No valid loss variant found. Choose between 'huber' and 'huber_mass_balance'."