mirror of
https://git.gfz-potsdam.de/naaice/model-training.git
synced 2025-12-15 19:58:22 +01:00
add script for experiments
This commit is contained in:
parent
04f5c40b29
commit
4f954cbc84
@ -52,9 +52,21 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "FileNotFoundError",
|
||||
"evalue": "[Errno 2] No such file or directory: './projects/model-training/src/'",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mos\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m os\u001b[38;5;241m.\u001b[39mchdir(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m./projects/model-training/src/\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
||||
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: './projects/model-training/src/'"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"os.chdir(\"./projects/model-training/src/\")"
|
||||
@ -69,12 +81,9 @@
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2025-03-28 14:24:13.743271: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
|
||||
"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
|
||||
"E0000 00:00:1743168253.897216 16215 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
|
||||
"E0000 00:00:1743168253.941693 16215 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
|
||||
"2025-03-28 14:24:14.439356: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
|
||||
"To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
|
||||
"2025-03-28 14:41:07.244286: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
|
||||
"2025-03-28 14:41:07.369762: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
|
||||
"To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -271,7 +280,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@ -281,7 +290,7 @@
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[6], line 12\u001b[0m\n\u001b[1;32m 8\u001b[0m ax3 \u001b[38;5;241m=\u001b[39m axes[\u001b[38;5;241m2\u001b[39m, i]\n\u001b[1;32m 9\u001b[0m ax4 \u001b[38;5;241m=\u001b[39m axes[\u001b[38;5;241m3\u001b[39m, i]\n\u001b[1;32m 11\u001b[0m im1 \u001b[38;5;241m=\u001b[39m ax1\u001b[38;5;241m.\u001b[39mimshow(\n\u001b[0;32m---> 12\u001b[0m np\u001b[38;5;241m.\u001b[39marray((X_manual[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCl\u001b[39m\u001b[38;5;124m\"\u001b[39m])[(timestep \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m2500\u001b[39m) : (timestep \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m2500\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m2500\u001b[39m)])\u001b[38;5;241m.\u001b[39mreshape(\n\u001b[1;32m 13\u001b[0m \u001b[38;5;241m50\u001b[39m, \u001b[38;5;241m50\u001b[39m\n\u001b[1;32m 14\u001b[0m ),\n\u001b[1;32m 15\u001b[0m origin\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlower\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 16\u001b[0m cmap\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mviridis\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 17\u001b[0m )\n\u001b[1;32m 19\u001b[0m ax1\u001b[38;5;241m.\u001b[39mset_title(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIteration \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtimestep\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 20\u001b[0m ax1\u001b[38;5;241m.\u001b[39maxis(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moff\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
||||
"Cell \u001b[0;32mIn[8], line 12\u001b[0m\n\u001b[1;32m 8\u001b[0m ax3 \u001b[38;5;241m=\u001b[39m axes[\u001b[38;5;241m2\u001b[39m, i]\n\u001b[1;32m 9\u001b[0m ax4 \u001b[38;5;241m=\u001b[39m axes[\u001b[38;5;241m3\u001b[39m, i]\n\u001b[1;32m 11\u001b[0m im1 \u001b[38;5;241m=\u001b[39m ax1\u001b[38;5;241m.\u001b[39mimshow(\n\u001b[0;32m---> 12\u001b[0m np\u001b[38;5;241m.\u001b[39marray((X_manual[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCl\u001b[39m\u001b[38;5;124m\"\u001b[39m])[(timestep \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m2500\u001b[39m) : (timestep \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m2500\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m2500\u001b[39m)])\u001b[38;5;241m.\u001b[39mreshape(\n\u001b[1;32m 13\u001b[0m \u001b[38;5;241m50\u001b[39m, \u001b[38;5;241m50\u001b[39m\n\u001b[1;32m 14\u001b[0m ),\n\u001b[1;32m 15\u001b[0m origin\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlower\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 16\u001b[0m cmap\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mviridis\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 17\u001b[0m )\n\u001b[1;32m 19\u001b[0m ax1\u001b[38;5;241m.\u001b[39mset_title(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIteration \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtimestep\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 20\u001b[0m ax1\u001b[38;5;241m.\u001b[39maxis(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moff\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
||||
"\u001b[0;31mNameError\u001b[0m: name 'X_manual' is not defined"
|
||||
]
|
||||
},
|
||||
@ -371,7 +380,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -389,7 +398,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -417,17 +426,9 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"I0000 00:00:1743168379.083212 16215 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13393 MB memory: -> device: 0, name: NVIDIA A2, pci bus id: 0000:82:00.0, compute capability: 8.6\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# select model architecture\n",
|
||||
"model = model_definition(\"large\", len(df_design.columns), len(df_results.columns)) \n",
|
||||
@ -441,6 +442,7 @@
|
||||
"h1 = 0.16726490480995826\n",
|
||||
"h2 = 0.5283208497548787\n",
|
||||
"h3 = 0.5099528144902471\n",
|
||||
"h4 = h3\n",
|
||||
"\n",
|
||||
"delta = 1.7642791340966357\n",
|
||||
"\n",
|
||||
@ -462,6 +464,645 @@
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_scaled = preprocess.scaler_input.transform(X.loc[:, X.columns != \"Class\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data_range = tf.convert_to_tensor(\n",
|
||||
" preprocess.scaler_input.data_range_, dtype=tf.float32)\n",
|
||||
"min_values = tf.convert_to_tensor(\n",
|
||||
" preprocess.scaler_input.data_min_, dtype=tf.float32)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>H</th>\n",
|
||||
" <th>O</th>\n",
|
||||
" <th>Ba</th>\n",
|
||||
" <th>Cl</th>\n",
|
||||
" <th>S</th>\n",
|
||||
" <th>Sr</th>\n",
|
||||
" <th>Barite</th>\n",
|
||||
" <th>Celestite</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>111.012436</td>\n",
|
||||
" <td>55.508193</td>\n",
|
||||
" <td>2.041069e-02</td>\n",
|
||||
" <td>4.082138e-02</td>\n",
|
||||
" <td>4.938299e-04</td>\n",
|
||||
" <td>0.000494</td>\n",
|
||||
" <td>0.000000</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>111.012436</td>\n",
|
||||
" <td>55.508428</td>\n",
|
||||
" <td>1.094567e-02</td>\n",
|
||||
" <td>2.189133e-02</td>\n",
|
||||
" <td>5.525578e-04</td>\n",
|
||||
" <td>0.000553</td>\n",
|
||||
" <td>0.000000</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>111.012436</td>\n",
|
||||
" <td>55.508692</td>\n",
|
||||
" <td>2.943745e-04</td>\n",
|
||||
" <td>5.887491e-04</td>\n",
|
||||
" <td>6.186462e-04</td>\n",
|
||||
" <td>0.000619</td>\n",
|
||||
" <td>0.000000</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>111.012436</td>\n",
|
||||
" <td>55.508699</td>\n",
|
||||
" <td>1.091776e-05</td>\n",
|
||||
" <td>2.183552e-05</td>\n",
|
||||
" <td>6.204049e-04</td>\n",
|
||||
" <td>0.000620</td>\n",
|
||||
" <td>0.000000</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>111.012436</td>\n",
|
||||
" <td>55.508699</td>\n",
|
||||
" <td>4.049176e-07</td>\n",
|
||||
" <td>8.098352e-07</td>\n",
|
||||
" <td>6.204702e-04</td>\n",
|
||||
" <td>0.000620</td>\n",
|
||||
" <td>0.000000</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>...</th>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1509995</th>\n",
|
||||
" <td>111.012436</td>\n",
|
||||
" <td>55.506218</td>\n",
|
||||
" <td>3.829904e-02</td>\n",
|
||||
" <td>1.505240e-01</td>\n",
|
||||
" <td>1.480414e-07</td>\n",
|
||||
" <td>0.036963</td>\n",
|
||||
" <td>1.003059</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1509996</th>\n",
|
||||
" <td>111.012436</td>\n",
|
||||
" <td>55.506218</td>\n",
|
||||
" <td>5.143522e-02</td>\n",
|
||||
" <td>1.609874e-01</td>\n",
|
||||
" <td>1.241651e-07</td>\n",
|
||||
" <td>0.029059</td>\n",
|
||||
" <td>1.005189</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1509997</th>\n",
|
||||
" <td>111.012436</td>\n",
|
||||
" <td>55.506218</td>\n",
|
||||
" <td>6.735992e-02</td>\n",
|
||||
" <td>1.737510e-01</td>\n",
|
||||
" <td>1.075194e-07</td>\n",
|
||||
" <td>0.019516</td>\n",
|
||||
" <td>1.002099</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1509998</th>\n",
|
||||
" <td>111.012436</td>\n",
|
||||
" <td>55.506218</td>\n",
|
||||
" <td>8.775700e-02</td>\n",
|
||||
" <td>1.901489e-01</td>\n",
|
||||
" <td>8.510718e-08</td>\n",
|
||||
" <td>0.007318</td>\n",
|
||||
" <td>1.000628</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1509999</th>\n",
|
||||
" <td>111.012436</td>\n",
|
||||
" <td>55.506218</td>\n",
|
||||
" <td>9.695043e-02</td>\n",
|
||||
" <td>1.975455e-01</td>\n",
|
||||
" <td>7.271897e-08</td>\n",
|
||||
" <td>0.001822</td>\n",
|
||||
" <td>1.000276</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"<p>1510000 rows × 8 columns</p>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" H O Ba Cl S \\\n",
|
||||
"0 111.012436 55.508193 2.041069e-02 4.082138e-02 4.938299e-04 \n",
|
||||
"1 111.012436 55.508428 1.094567e-02 2.189133e-02 5.525578e-04 \n",
|
||||
"2 111.012436 55.508692 2.943745e-04 5.887491e-04 6.186462e-04 \n",
|
||||
"3 111.012436 55.508699 1.091776e-05 2.183552e-05 6.204049e-04 \n",
|
||||
"4 111.012436 55.508699 4.049176e-07 8.098352e-07 6.204702e-04 \n",
|
||||
"... ... ... ... ... ... \n",
|
||||
"1509995 111.012436 55.506218 3.829904e-02 1.505240e-01 1.480414e-07 \n",
|
||||
"1509996 111.012436 55.506218 5.143522e-02 1.609874e-01 1.241651e-07 \n",
|
||||
"1509997 111.012436 55.506218 6.735992e-02 1.737510e-01 1.075194e-07 \n",
|
||||
"1509998 111.012436 55.506218 8.775700e-02 1.901489e-01 8.510718e-08 \n",
|
||||
"1509999 111.012436 55.506218 9.695043e-02 1.975455e-01 7.271897e-08 \n",
|
||||
"\n",
|
||||
" Sr Barite Celestite \n",
|
||||
"0 0.000494 0.000000 1.0 \n",
|
||||
"1 0.000553 0.000000 1.0 \n",
|
||||
"2 0.000619 0.000000 1.0 \n",
|
||||
"3 0.000620 0.000000 1.0 \n",
|
||||
"4 0.000620 0.000000 1.0 \n",
|
||||
"... ... ... ... \n",
|
||||
"1509995 0.036963 1.003059 0.0 \n",
|
||||
"1509996 0.029059 1.005189 0.0 \n",
|
||||
"1509997 0.019516 1.002099 0.0 \n",
|
||||
"1509998 0.007318 1.000628 0.0 \n",
|
||||
"1509999 0.001822 1.000276 0.0 \n",
|
||||
"\n",
|
||||
"[1510000 rows x 8 columns]"
|
||||
]
|
||||
},
|
||||
"execution_count": 32,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"X_scaled * data_range + min_values"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>H</th>\n",
|
||||
" <th>O</th>\n",
|
||||
" <th>Ba</th>\n",
|
||||
" <th>Cl</th>\n",
|
||||
" <th>S</th>\n",
|
||||
" <th>Sr</th>\n",
|
||||
" <th>Barite</th>\n",
|
||||
" <th>Celestite</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>111.012434</td>\n",
|
||||
" <td>55.508192</td>\n",
|
||||
" <td>2.041069e-02</td>\n",
|
||||
" <td>4.082138e-02</td>\n",
|
||||
" <td>4.938300e-04</td>\n",
|
||||
" <td>0.000494</td>\n",
|
||||
" <td>0.000000</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>111.012434</td>\n",
|
||||
" <td>55.508427</td>\n",
|
||||
" <td>1.094567e-02</td>\n",
|
||||
" <td>2.189133e-02</td>\n",
|
||||
" <td>5.525578e-04</td>\n",
|
||||
" <td>0.000553</td>\n",
|
||||
" <td>0.000000</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>111.012434</td>\n",
|
||||
" <td>55.508691</td>\n",
|
||||
" <td>2.943745e-04</td>\n",
|
||||
" <td>5.887491e-04</td>\n",
|
||||
" <td>6.186462e-04</td>\n",
|
||||
" <td>0.000619</td>\n",
|
||||
" <td>0.000000</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>111.012434</td>\n",
|
||||
" <td>55.508698</td>\n",
|
||||
" <td>1.091776e-05</td>\n",
|
||||
" <td>2.183551e-05</td>\n",
|
||||
" <td>6.204050e-04</td>\n",
|
||||
" <td>0.000620</td>\n",
|
||||
" <td>0.000000</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>111.012434</td>\n",
|
||||
" <td>55.508699</td>\n",
|
||||
" <td>4.049176e-07</td>\n",
|
||||
" <td>8.098352e-07</td>\n",
|
||||
" <td>6.204702e-04</td>\n",
|
||||
" <td>0.000620</td>\n",
|
||||
" <td>0.000000</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>...</th>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1509995</th>\n",
|
||||
" <td>111.012434</td>\n",
|
||||
" <td>55.506217</td>\n",
|
||||
" <td>3.829904e-02</td>\n",
|
||||
" <td>1.505240e-01</td>\n",
|
||||
" <td>1.480414e-07</td>\n",
|
||||
" <td>0.036963</td>\n",
|
||||
" <td>1.003059</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1509996</th>\n",
|
||||
" <td>111.012434</td>\n",
|
||||
" <td>55.506217</td>\n",
|
||||
" <td>5.143522e-02</td>\n",
|
||||
" <td>1.609874e-01</td>\n",
|
||||
" <td>1.241651e-07</td>\n",
|
||||
" <td>0.029059</td>\n",
|
||||
" <td>1.005189</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1509997</th>\n",
|
||||
" <td>111.012434</td>\n",
|
||||
" <td>55.506217</td>\n",
|
||||
" <td>6.735993e-02</td>\n",
|
||||
" <td>1.737510e-01</td>\n",
|
||||
" <td>1.075194e-07</td>\n",
|
||||
" <td>0.019516</td>\n",
|
||||
" <td>1.002099</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1509998</th>\n",
|
||||
" <td>111.012434</td>\n",
|
||||
" <td>55.506217</td>\n",
|
||||
" <td>8.775701e-02</td>\n",
|
||||
" <td>1.901489e-01</td>\n",
|
||||
" <td>8.510718e-08</td>\n",
|
||||
" <td>0.007318</td>\n",
|
||||
" <td>1.000628</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1509999</th>\n",
|
||||
" <td>111.012434</td>\n",
|
||||
" <td>55.506217</td>\n",
|
||||
" <td>9.695043e-02</td>\n",
|
||||
" <td>1.975455e-01</td>\n",
|
||||
" <td>7.271898e-08</td>\n",
|
||||
" <td>0.001822</td>\n",
|
||||
" <td>1.000276</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"<p>1510000 rows × 8 columns</p>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" H O Ba Cl S \\\n",
|
||||
"0 111.012434 55.508192 2.041069e-02 4.082138e-02 4.938300e-04 \n",
|
||||
"1 111.012434 55.508427 1.094567e-02 2.189133e-02 5.525578e-04 \n",
|
||||
"2 111.012434 55.508691 2.943745e-04 5.887491e-04 6.186462e-04 \n",
|
||||
"3 111.012434 55.508698 1.091776e-05 2.183551e-05 6.204050e-04 \n",
|
||||
"4 111.012434 55.508699 4.049176e-07 8.098352e-07 6.204702e-04 \n",
|
||||
"... ... ... ... ... ... \n",
|
||||
"1509995 111.012434 55.506217 3.829904e-02 1.505240e-01 1.480414e-07 \n",
|
||||
"1509996 111.012434 55.506217 5.143522e-02 1.609874e-01 1.241651e-07 \n",
|
||||
"1509997 111.012434 55.506217 6.735993e-02 1.737510e-01 1.075194e-07 \n",
|
||||
"1509998 111.012434 55.506217 8.775701e-02 1.901489e-01 8.510718e-08 \n",
|
||||
"1509999 111.012434 55.506217 9.695043e-02 1.975455e-01 7.271898e-08 \n",
|
||||
"\n",
|
||||
" Sr Barite Celestite \n",
|
||||
"0 0.000494 0.000000 1.0 \n",
|
||||
"1 0.000553 0.000000 1.0 \n",
|
||||
"2 0.000619 0.000000 1.0 \n",
|
||||
"3 0.000620 0.000000 1.0 \n",
|
||||
"4 0.000620 0.000000 1.0 \n",
|
||||
"... ... ... ... \n",
|
||||
"1509995 0.036963 1.003059 0.0 \n",
|
||||
"1509996 0.029059 1.005189 0.0 \n",
|
||||
"1509997 0.019516 1.002099 0.0 \n",
|
||||
"1509998 0.007318 1.000628 0.0 \n",
|
||||
"1509999 0.001822 1.000276 0.0 \n",
|
||||
"\n",
|
||||
"[1510000 rows x 8 columns]"
|
||||
]
|
||||
},
|
||||
"execution_count": 40,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pd.DataFrame(preprocess.scaler_input.inverse_transform(X_scaled), columns=X_scaled.columns)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>H</th>\n",
|
||||
" <th>O</th>\n",
|
||||
" <th>Ba</th>\n",
|
||||
" <th>Cl</th>\n",
|
||||
" <th>S</th>\n",
|
||||
" <th>Sr</th>\n",
|
||||
" <th>Barite</th>\n",
|
||||
" <th>Celestite</th>\n",
|
||||
" <th>Class</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>111.012434</td>\n",
|
||||
" <td>55.506561</td>\n",
|
||||
" <td>0.000025</td>\n",
|
||||
" <td>0.070397</td>\n",
|
||||
" <td>0.000086</td>\n",
|
||||
" <td>0.035259</td>\n",
|
||||
" <td>2.732797e-06</td>\n",
|
||||
" <td>1.000767</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>111.012434</td>\n",
|
||||
" <td>55.506591</td>\n",
|
||||
" <td>0.000017</td>\n",
|
||||
" <td>0.048181</td>\n",
|
||||
" <td>0.000093</td>\n",
|
||||
" <td>0.024167</td>\n",
|
||||
" <td>4.615323e-07</td>\n",
|
||||
" <td>1.000621</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>111.012434</td>\n",
|
||||
" <td>55.506539</td>\n",
|
||||
" <td>0.000036</td>\n",
|
||||
" <td>0.097734</td>\n",
|
||||
" <td>0.000080</td>\n",
|
||||
" <td>0.048912</td>\n",
|
||||
" <td>1.387101e-04</td>\n",
|
||||
" <td>1.000649</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>111.012434</td>\n",
|
||||
" <td>55.506551</td>\n",
|
||||
" <td>0.000029</td>\n",
|
||||
" <td>0.080614</td>\n",
|
||||
" <td>0.000084</td>\n",
|
||||
" <td>0.040362</td>\n",
|
||||
" <td>4.606226e-07</td>\n",
|
||||
" <td>1.000513</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>111.012434</td>\n",
|
||||
" <td>55.506650</td>\n",
|
||||
" <td>0.000009</td>\n",
|
||||
" <td>0.027535</td>\n",
|
||||
" <td>0.000108</td>\n",
|
||||
" <td>0.013866</td>\n",
|
||||
" <td>2.198323e-08</td>\n",
|
||||
" <td>1.000203</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>...</th>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1071013</th>\n",
|
||||
" <td>111.012434</td>\n",
|
||||
" <td>55.506228</td>\n",
|
||||
" <td>0.011569</td>\n",
|
||||
" <td>0.122579</td>\n",
|
||||
" <td>0.000003</td>\n",
|
||||
" <td>0.049723</td>\n",
|
||||
" <td>1.003593e+00</td>\n",
|
||||
" <td>0.000000</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1071014</th>\n",
|
||||
" <td>111.012434</td>\n",
|
||||
" <td>55.506625</td>\n",
|
||||
" <td>0.000012</td>\n",
|
||||
" <td>0.033830</td>\n",
|
||||
" <td>0.000102</td>\n",
|
||||
" <td>0.017005</td>\n",
|
||||
" <td>3.780744e-07</td>\n",
|
||||
" <td>1.000677</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1071015</th>\n",
|
||||
" <td>111.012434</td>\n",
|
||||
" <td>55.506563</td>\n",
|
||||
" <td>0.000024</td>\n",
|
||||
" <td>0.068341</td>\n",
|
||||
" <td>0.000086</td>\n",
|
||||
" <td>0.034233</td>\n",
|
||||
" <td>4.692203e-07</td>\n",
|
||||
" <td>1.000561</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1071016</th>\n",
|
||||
" <td>111.012434</td>\n",
|
||||
" <td>55.506569</td>\n",
|
||||
" <td>0.000022</td>\n",
|
||||
" <td>0.062430</td>\n",
|
||||
" <td>0.000088</td>\n",
|
||||
" <td>0.031281</td>\n",
|
||||
" <td>4.191733e-08</td>\n",
|
||||
" <td>1.000090</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1071017</th>\n",
|
||||
" <td>111.012434</td>\n",
|
||||
" <td>55.506734</td>\n",
|
||||
" <td>0.000005</td>\n",
|
||||
" <td>0.015954</td>\n",
|
||||
" <td>0.000129</td>\n",
|
||||
" <td>0.008101</td>\n",
|
||||
" <td>0.000000e+00</td>\n",
|
||||
" <td>1.000410</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"<p>1071018 rows × 9 columns</p>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" H O Ba Cl S Sr \\\n",
|
||||
"0 111.012434 55.506561 0.000025 0.070397 0.000086 0.035259 \n",
|
||||
"1 111.012434 55.506591 0.000017 0.048181 0.000093 0.024167 \n",
|
||||
"2 111.012434 55.506539 0.000036 0.097734 0.000080 0.048912 \n",
|
||||
"3 111.012434 55.506551 0.000029 0.080614 0.000084 0.040362 \n",
|
||||
"4 111.012434 55.506650 0.000009 0.027535 0.000108 0.013866 \n",
|
||||
"... ... ... ... ... ... ... \n",
|
||||
"1071013 111.012434 55.506228 0.011569 0.122579 0.000003 0.049723 \n",
|
||||
"1071014 111.012434 55.506625 0.000012 0.033830 0.000102 0.017005 \n",
|
||||
"1071015 111.012434 55.506563 0.000024 0.068341 0.000086 0.034233 \n",
|
||||
"1071016 111.012434 55.506569 0.000022 0.062430 0.000088 0.031281 \n",
|
||||
"1071017 111.012434 55.506734 0.000005 0.015954 0.000129 0.008101 \n",
|
||||
"\n",
|
||||
" Barite Celestite Class \n",
|
||||
"0 2.732797e-06 1.000767 1.0 \n",
|
||||
"1 4.615323e-07 1.000621 1.0 \n",
|
||||
"2 1.387101e-04 1.000649 1.0 \n",
|
||||
"3 4.606226e-07 1.000513 1.0 \n",
|
||||
"4 2.198323e-08 1.000203 1.0 \n",
|
||||
"... ... ... ... \n",
|
||||
"1071013 1.003593e+00 0.000000 1.0 \n",
|
||||
"1071014 3.780744e-07 1.000677 1.0 \n",
|
||||
"1071015 4.692203e-07 1.000561 1.0 \n",
|
||||
"1071016 4.191733e-08 1.000090 1.0 \n",
|
||||
"1071017 0.000000e+00 1.000410 1.0 \n",
|
||||
"\n",
|
||||
"[1071018 rows x 9 columns]"
|
||||
]
|
||||
},
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"preprocess.scale_inverse(X_train)[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@ -507,7 +1148,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"history = model_training(model, epochs=500)"
|
||||
"history = model_training(model, epochs=50)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -957,7 +1598,7 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": "training",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
||||
@ -1,9 +1,163 @@
|
||||
from preprocessing import *
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.colors as mcolors
|
||||
import matplotlib.patches as mpatches
|
||||
import pickle
|
||||
|
||||
|
||||
scaler_experiments = ["none", "minmax", "standard"]
|
||||
###### Experimental parameters
|
||||
scaler_type = "minmax"
|
||||
feature_engineering = False
|
||||
optimizer_type = "adam"
|
||||
loss_variant = "huber_mass_balance"
|
||||
|
||||
|
||||
for i in scaler_experiments:
|
||||
|
||||
###### load dataset
|
||||
data_file = h5py.File("../datasets/Barite_4c_mdl.h5")
|
||||
|
||||
design = data_file["design"]
|
||||
results = data_file["result"]
|
||||
|
||||
df_design = pd.DataFrame(
|
||||
np.array(design["data"]).transpose(), columns=np.array(design["names"].asstr())
|
||||
)
|
||||
df_results = pd.DataFrame(
|
||||
np.array(results["data"]).transpose(), columns=np.array(results["names"].asstr())
|
||||
)
|
||||
|
||||
data_file.close()
|
||||
df_design.drop("Charge", axis=1, inplace=True, errors="ignore")
|
||||
df_results.drop("Charge", axis=1, inplace=True, errors="ignore")
|
||||
|
||||
|
||||
###### preprocessing
|
||||
|
||||
if feature_engineering == True:
|
||||
df_design["Ba\Sr"] = df_design["Ba"] / df_design["Sr"]
|
||||
df_design["BaxS"] = df_design["Ba"] * df_design["S"]
|
||||
|
||||
preprocess = preprocessing()
|
||||
X, y = preprocess.cluster_manual(df_design[df_design.columns], df_design[df_results.columns], "Cl")
|
||||
|
||||
X_train, X_test, y_train, y_test = preprocess.split(X, y, ratio=0.2)
|
||||
X_train, y_train = preprocess.balancer(X_train, y_train, strategy="off")
|
||||
|
||||
# train only on reactive cells
|
||||
X_train, y_train = preprocess.class_selection(X_train, y_train, class_label=1.0)
|
||||
|
||||
preprocess.scale_fit(X_train, y_train, type=scaler_type)
|
||||
X_train, X_test, y_train, y_test = preprocess.scale_transform(
|
||||
X_train, X_test, y_train, y_test
|
||||
)
|
||||
X_train, X_val, y_train, y_val = preprocess.split(X_train, y_train, ratio=0.1)
|
||||
|
||||
|
||||
###### create and compile model
|
||||
|
||||
|
||||
def model_training(model, batch_size=512, epochs=100):
|
||||
start = time.time()
|
||||
callback = keras.callbacks.EarlyStopping(monitor="loss", patience=30)
|
||||
history = model.fit(
|
||||
X_train.loc[:, X_train.columns != "Class"],
|
||||
y_train.loc[:, y_train.columns != "Class"],
|
||||
batch_size=batch_size,
|
||||
epochs=epochs,
|
||||
validation_data=(
|
||||
X_val.loc[:, X_val.columns != "Class"],
|
||||
y_val.loc[:, y_val.columns != "Class"],
|
||||
),
|
||||
callbacks=[callback],
|
||||
)
|
||||
|
||||
end = time.time()
|
||||
|
||||
print("Training took {} seconds".format(end - start))
|
||||
|
||||
return history
|
||||
|
||||
|
||||
|
||||
# mapping of column names to column index
|
||||
column_dict = {}
|
||||
for i in df_results.columns:
|
||||
column_dict[i] = y.columns.get_loc(i)
|
||||
|
||||
|
||||
# select model architecture
|
||||
model = model_definition("large", len(df_design.columns), len(df_results.columns))
|
||||
|
||||
# define learning rate adaptation
|
||||
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
|
||||
initial_learning_rate=0.01, decay_steps=2000, decay_rate=0.9, staircase=True
|
||||
)
|
||||
|
||||
# hyperparameters that are determined by hyperparameter optimization
|
||||
h1 = 0.16726490480995826
|
||||
h2 = 0.5283208497548787
|
||||
h3 = 0.5099528144902471
|
||||
h4 = h3
|
||||
|
||||
delta = 1.7642791340966357
|
||||
|
||||
match optimizer_type:
|
||||
case "adam":
|
||||
optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)
|
||||
case "sgd":
|
||||
optimizer = keras.optimizers.SGD(learning_rate=lr_schedule)
|
||||
case "rmsprop":
|
||||
optimizer = keras.optimizers.RMSprop(learning_rate=lr_schedule)
|
||||
|
||||
model.compile(
|
||||
optimizer=optimizer,
|
||||
loss=custom_loss(preprocess, column_dict, h1, h2, h3, h4, scaler_type, loss_variant, 1),
|
||||
metrics=[
|
||||
huber_metric(delta),
|
||||
mass_balance_metric(preprocess, column_dict, scaler_type, loss_variant),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
|
||||
###### train model
|
||||
|
||||
epochs = 200
|
||||
|
||||
history = model_training(model, epochs=epochs)
|
||||
|
||||
|
||||
###### evaluate model
|
||||
|
||||
|
||||
results = mass_balance_evaluation(model, X_test, preprocess)
|
||||
mass_balance_ratio(results, threshold=1e-5)
|
||||
|
||||
def test_model(model, X_test, y_test):
|
||||
|
||||
X_test.reset_index(inplace=True, drop=True)
|
||||
y_test.reset_index(inplace=True, drop=True)
|
||||
all = model.evaluate(X_test.loc[:, X_test.columns != "Class"], y_test.loc[:, y_test.columns != "Class"])
|
||||
class_0 = model.evaluate(X_test[X_test["Class"] == 0].iloc[:, X_test.columns != "Class"], y_test[X_test["Class"] == 0].iloc[:, y_test.columns != "Class"])
|
||||
class_1 = model.evaluate(
|
||||
X_test[X_test["Class"] == 1].iloc[:, :-1], y_test[X_test["Class"] == 1].iloc[:, :-1])
|
||||
|
||||
print("metric all data: ", all)
|
||||
print("metric class 0: ", class_0)
|
||||
print("metric class 1: ", class_1)
|
||||
|
||||
|
||||
test_model(model, X_test, y_test)
|
||||
|
||||
|
||||
|
||||
|
||||
###### save model and history
|
||||
|
||||
delimiter = "_"
|
||||
idx_string = scaler_type + delimiter + "feature_engineering_" + feature_engineering + delimiter + optimizer_type + delimiter + loss_variant
|
||||
file_name = "history_" + idx_string
|
||||
with open('../results/'+file_name, 'wb') as file_pi:
|
||||
pickle.dump(history.history, file_pi)
|
||||
|
||||
model.save_weights("../results/models/model_"+idx_string + ".weights.h5")
|
||||
@ -126,6 +126,7 @@ def custom_loss(
|
||||
h1,
|
||||
h2,
|
||||
h3,
|
||||
h4,
|
||||
scaler_type="minmax",
|
||||
loss_variant="huber",
|
||||
delta=1.0,
|
||||
@ -175,6 +176,8 @@ def custom_loss(
|
||||
|
||||
def loss(results, predicted):
|
||||
# inverse min/max scaling
|
||||
preprocess.scaler_input(results)
|
||||
|
||||
if scaler_type == "minmax":
|
||||
predicted_inverse = predicted * data_range + min_values
|
||||
results_inverse = results * data_range + min_values
|
||||
@ -240,7 +243,7 @@ def custom_loss(
|
||||
elif loss_variant == "huber_mass_balance":
|
||||
total_loss = h1 * huber_loss + h2 * dBa + h3 * dSr
|
||||
elif "huber_mass_balance_extended":
|
||||
total_loss = h1 * huber_loss + h2 * dBa + h3 * dSr + h3 * dS
|
||||
total_loss = h1 * huber_loss + h2 * dBa + h3 * dSr + h4 * dS
|
||||
else:
|
||||
raise Exception(
|
||||
"No valid loss variant found. Choose between 'huber' and 'huber_mass_balance'."
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user