diff --git a/POET_Training.ipynb b/POET_Training.ipynb index e278521..ffb8abb 100644 --- a/POET_Training.ipynb +++ b/POET_Training.ipynb @@ -30,6 +30,15 @@ "execution_count": 1, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-02-17 10:30:47.780794: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2025-02-17 10:30:47.804086: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" + ] + }, { "name": "stdout", "output_type": "stream", @@ -140,11 +149,11 @@ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
        "┃ Layer (type)                     Output Shape                  Param # ┃\n",
        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-       "│ dense (Dense)                   │ (None, 128)            │         1,280 │\n",
+       "│ dense (Dense)                   │ (None, 128)            │         1,152 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
        "│ dense_1 (Dense)                 │ (None, 128)            │        16,512 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_2 (Dense)                 │ (None, 9)              │         1,161 │\n",
+       "│ dense_2 (Dense)                 │ (None, 8)              │         1,032 │\n",
        "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
        "
\n" ], @@ -152,11 +161,11 @@ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", - "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m1,280\u001b[0m │\n", + "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m1,152\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m16,512\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_2 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m1,161\u001b[0m │\n", + "│ dense_2 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m8\u001b[0m) │ \u001b[38;5;34m1,032\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, @@ -166,11 +175,11 @@ { "data": { "text/html": [ - "
 Total params: 18,953 (74.04 KB)\n",
+       "
 Total params: 18,696 (73.03 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m18,953\u001b[0m (74.04 KB)\n" + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m18,696\u001b[0m (73.03 KB)\n" ] }, "metadata": {}, @@ -179,11 +188,11 @@ { "data": { "text/html": [ - "
 Trainable params: 18,953 (74.04 KB)\n",
+       "
 Trainable params: 18,696 (73.03 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m18,953\u001b[0m (74.04 KB)\n" + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m18,696\u001b[0m (73.03 KB)\n" ] }, "metadata": {}, @@ -207,11 +216,11 @@ "# small model\n", "model_simple = keras.Sequential(\n", " [\n", - " keras.Input(shape = (9,), dtype = \"float32\"),\n", + " keras.Input(shape = (8,), dtype = \"float32\"),\n", " keras.layers.Dense(units = 128, activation = \"linear\", dtype = \"float32\"),\n", " # Dropout(0.2),\n", " keras.layers.Dense(units = 128, activation = \"elu\", dtype = \"float32\"),\n", - " keras.layers.Dense(units = 9, dtype = \"float32\")\n", + " keras.layers.Dense(units = 8, dtype = \"float32\")\n", " ]\n", ")\n", "\n", @@ -221,17 +230,17 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
Model: \"sequential_1\"\n",
+       "
Model: \"sequential_5\"\n",
        "
\n" ], "text/plain": [ - "\u001b[1mModel: \"sequential_1\"\u001b[0m\n" + "\u001b[1mModel: \"sequential_5\"\u001b[0m\n" ] }, "metadata": {}, @@ -243,13 +252,13 @@ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
        "┃ Layer (type)                     Output Shape                  Param # ┃\n",
        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-       "│ dense_3 (Dense)                 │ (None, 512)            │         5,120 │\n",
+       "│ dense_21 (Dense)                │ (None, 512)            │         4,608 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_4 (Dense)                 │ (None, 1024)           │       525,312 │\n",
+       "│ dense_22 (Dense)                │ (None, 1024)           │       525,312 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_5 (Dense)                 │ (None, 512)            │       524,800 │\n",
+       "│ dense_23 (Dense)                │ (None, 512)            │       524,800 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_6 (Dense)                 │ (None, 9)              │         4,617 │\n",
+       "│ dense_24 (Dense)                │ (None, 8)              │         4,104 │\n",
        "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
        "
\n" ], @@ -257,13 +266,13 @@ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", - "│ dense_3 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m5,120\u001b[0m │\n", + "│ dense_21 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m4,608\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_4 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1024\u001b[0m) │ \u001b[38;5;34m525,312\u001b[0m │\n", + "│ dense_22 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1024\u001b[0m) │ \u001b[38;5;34m525,312\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_5 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m524,800\u001b[0m │\n", + "│ dense_23 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m524,800\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_6 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m4,617\u001b[0m │\n", + "│ dense_24 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m8\u001b[0m) │ \u001b[38;5;34m4,104\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, @@ -273,11 +282,11 @@ { "data": { "text/html": [ - "
 Total params: 1,059,849 (4.04 MB)\n",
+       "
 Total params: 1,058,824 (4.04 MB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,059,849\u001b[0m (4.04 MB)\n" + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m1,058,824\u001b[0m (4.04 MB)\n" ] }, "metadata": {}, @@ -286,11 +295,11 @@ { "data": { "text/html": [ - "
 Trainable params: 1,059,849 (4.04 MB)\n",
+       "
 Trainable params: 1,058,824 (4.04 MB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,059,849\u001b[0m (4.04 MB)\n" + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m1,058,824\u001b[0m (4.04 MB)\n" ] }, "metadata": {}, @@ -313,11 +322,11 @@ "source": [ "# large model\n", "model_large = keras.Sequential(\n", - " [keras.layers.Input(shape=(9,), dtype=dtype),\n", + " [keras.layers.Input(shape=(8,), dtype=dtype),\n", " keras.layers.Dense(512, activation='relu', dtype=dtype),\n", " keras.layers.Dense(1024, activation='relu', dtype=dtype),\n", " keras.layers.Dense(512, activation='relu', dtype=dtype),\n", - " keras.layers.Dense(9, dtype=dtype)\n", + " keras.layers.Dense(8, dtype=dtype)\n", " ])\n", "\n", "model_large.compile(optimizer=optimizer_large, loss = loss)\n", @@ -326,17 +335,17 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
Model: \"sequential_2\"\n",
+       "
Model: \"sequential_4\"\n",
        "
\n" ], "text/plain": [ - "\u001b[1mModel: \"sequential_2\"\u001b[0m\n" + "\u001b[1mModel: \"sequential_4\"\u001b[0m\n" ] }, "metadata": {}, @@ -348,15 +357,15 @@ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
        "┃ Layer (type)                     Output Shape                  Param # ┃\n",
        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-       "│ dense_7 (Dense)                 │ (None, 128)            │         1,664 │\n",
+       "│ dense_16 (Dense)                │ (None, 128)            │         1,152 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_8 (Dense)                 │ (None, 256)            │        33,024 │\n",
+       "│ dense_17 (Dense)                │ (None, 256)            │        33,024 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_9 (Dense)                 │ (None, 512)            │       131,584 │\n",
+       "│ dense_18 (Dense)                │ (None, 512)            │       131,584 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_10 (Dense)                │ (None, 256)            │       131,328 │\n",
+       "│ dense_19 (Dense)                │ (None, 256)            │       131,328 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_11 (Dense)                │ (None, 12)             │         3,084 │\n",
+       "│ dense_20 (Dense)                │ (None, 8)              │         2,056 │\n",
        "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
        "
\n" ], @@ -364,15 +373,15 @@ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", - "│ dense_7 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m1,664\u001b[0m │\n", + "│ dense_16 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m1,152\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_8 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m33,024\u001b[0m │\n", + "│ dense_17 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m33,024\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_9 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m131,584\u001b[0m │\n", + "│ dense_18 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m) │ \u001b[38;5;34m131,584\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_10 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m131,328\u001b[0m │\n", + "│ dense_19 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m131,328\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_11 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, 
\u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m3,084\u001b[0m │\n", + "│ dense_20 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m8\u001b[0m) │ \u001b[38;5;34m2,056\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, @@ -382,11 +391,11 @@ { "data": { "text/html": [ - "
 Total params: 300,684 (1.15 MB)\n",
+       "
 Total params: 299,144 (1.14 MB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m300,684\u001b[0m (1.15 MB)\n" + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m299,144\u001b[0m (1.14 MB)\n" ] }, "metadata": {}, @@ -395,11 +404,11 @@ { "data": { "text/html": [ - "
 Trainable params: 300,684 (1.15 MB)\n",
+       "
 Trainable params: 299,144 (1.14 MB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m300,684\u001b[0m (1.15 MB)\n" + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m299,144\u001b[0m (1.14 MB)\n" ] }, "metadata": {}, @@ -423,12 +432,12 @@ "# model from paper\n", "# (see https://doi.org/10.1007/s11242-022-01779-3 model for the complex chemistry)\n", "model_paper = keras.Sequential(\n", - " [keras.layers.Input(shape=(12,), dtype=dtype),\n", + " [keras.layers.Input(shape=(8,), dtype=dtype),\n", " keras.layers.Dense(128, activation='relu', dtype=dtype),\n", " keras.layers.Dense(256, activation='relu', dtype=dtype),\n", " keras.layers.Dense(512, activation='relu', dtype=dtype),\n", " keras.layers.Dense(256, activation='relu', dtype=dtype),\n", - " keras.layers.Dense(12, dtype=dtype)\n", + " keras.layers.Dense(8, dtype=dtype)\n", " ])\n", "\n", "model_paper.compile(optimizer=optimizer_paper, loss = loss)\n", @@ -444,7 +453,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -514,7 +523,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -530,81 +539,6 @@ "data_file.close()" ] }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "species_columns = ['H', 'O', 'Charge', 'Ba', 'Cl', 'S', 'Sr', 'Barite', 'Celestite']" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/hannessigner/miniforge3/envs/ai/lib/python3.12/site-packages/sklearn/base.py:1474: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (2). 
Possibly due to duplicate points in X.\n", - " return fit_method(estimator, *args, **kwargs)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Amount class 0 before: 0.9521309523809524\n", - "Amount class 1 before: 0.04786904761904762\n" - ] - } - ], - "source": [ - "preprocess = preprocessing(func_dict_in=func_dict_in, func_dict_out=func_dict_out)\n", - "X, y = preprocess.cluster(df_design[species_columns], df_results[species_columns])\n", - "# X, y = preprocess.funcTranform(X, y)\n", - "\n", - "X_train, X_test, y_train, y_test = preprocess.split(X, y, ratio = 0.2)\n", - "X_train, y_train = preprocess.balancer(X_train, y_train, strategy = \"off\")\n", - "preprocess.scale_fit(X_train, y_train, scaling = \"individual\")\n", - "X_train, X_test, y_train, y_test = preprocess.scale_transform(X_train, X_test, y_train, y_test)\n", - "X_train, X_val, y_train, y_val = preprocess.split(X_train, y_train, ratio = 0.1)" - ] - }, - { - "cell_type": "code", - "execution_count": 52, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 52, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "timestep=250\n", - "plt.imshow(np.array(X[\"Barite\"][(timestep*2500):(timestep*2500+2500)]).reshape(50,50), origin='lower')\n", - "plt.contour(np.array(X[\"Class\"][(timestep*2500):(timestep*2500+2500)]).reshape(50,50), origin='lower', colors='red')\n" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -623,78 +557,80 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "species_columns = ['H', 'O', 'Ba', 'Cl', 'S', 'Sr', 'Barite', 'Celestite']" + ] + }, + { + "cell_type": "code", + "execution_count": 37, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/Users/hannessigner/miniforge3/envs/ai/lib/python3.12/site-packages/sklearn/base.py:1474: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (2). Possibly due to duplicate points in X.\n", + "/mnt/scratch/miniconda3/envs/model-training/lib/python3.12/site-packages/sklearn/base.py:1389: ConvergenceWarning: Number of distinct clusters (1) found smaller than n_clusters (2). 
Possibly due to duplicate points in X.\n", " return fit_method(estimator, *args, **kwargs)\n" ] }, { - "ename": "KeyError", - "evalue": "'S'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[53], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m X_train, X_val, X_test, y_train, y_val, y_test, scaler_X, scaler_y \u001b[38;5;241m=\u001b[39m \u001b[43mpreprocessing_training\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdf_design\u001b[49m\u001b[43m[\u001b[49m\u001b[43mspecies_columns\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdf_results\u001b[49m\u001b[43m[\u001b[49m\u001b[43mspecies_columns\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfunc_dict_in\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfunc_dict_out\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43moff\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mglobal\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.1\u001b[39;49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Documents/Work/model-training/preprocessing.py:161\u001b[0m, in \u001b[0;36mpreprocessing_training\u001b[0;34m(df_design, df_targets, func_dict_in, func_dict_out, sampling, scaling, test_size)\u001b[0m\n\u001b[1;32m 158\u001b[0m df_design \u001b[38;5;241m=\u001b[39m clustering(df_design)\n\u001b[1;32m 159\u001b[0m df_targets \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mconcat([df_targets, df_design[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mClass\u001b[39m\u001b[38;5;124m'\u001b[39m]], axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m--> 
161\u001b[0m df_design_log \u001b[38;5;241m=\u001b[39m \u001b[43mFuncTransform\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc_dict_in\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfunc_dict_out\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_transform\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdf_design\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 162\u001b[0m df_results_log \u001b[38;5;241m=\u001b[39m FuncTransform(func_dict_in, func_dict_out)\u001b[38;5;241m.\u001b[39mfit_transform(df_targets)\n\u001b[1;32m 164\u001b[0m X_train, X_test, y_train, y_test \u001b[38;5;241m=\u001b[39m sk\u001b[38;5;241m.\u001b[39mtrain_test_split(df_design_log, df_results_log, test_size \u001b[38;5;241m=\u001b[39m test_size, random_state\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m42\u001b[39m)\n", - "File \u001b[0;32m~/Documents/Work/model-training/preprocessing.py:63\u001b[0m, in \u001b[0;36mFuncTransform.fit_transform\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mfit_transform\u001b[39m(\u001b[38;5;28mself\u001b[39m, X, y\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m 62\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfit(X)\n\u001b[0;32m---> 63\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtransform\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Documents/Work/model-training/preprocessing.py:58\u001b[0m, in \u001b[0;36mFuncTransform.transform\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m X\u001b[38;5;241m.\u001b[39mkeys(): \n\u001b[1;32m 57\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m 
\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mClass\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m key:\n\u001b[0;32m---> 58\u001b[0m X[key] \u001b[38;5;241m=\u001b[39m X[key]\u001b[38;5;241m.\u001b[39mapply(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfunc_transform\u001b[49m\u001b[43m[\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m]\u001b[49m)\n\u001b[1;32m 59\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m X\n", - "\u001b[0;31mKeyError\u001b[0m: 'S'" + "name": "stdout", + "output_type": "stream", + "text": [ + "Amount class 0 before: 0.9521309523809524\n", + "Amount class 1 before: 0.04786904761904762\n", + "Using Oversampling\n", + "Amount class 0 after: 0.5\n", + "Amount class 1 after: 0.5\n" ] } ], "source": [ - "X_train, X_val, X_test, y_train, y_val, y_test, scaler_X, scaler_y = preprocessing_training(df_design[species_columns], df_results[species_columns], func_dict_in, func_dict_out, \"off\", 'global', 0.1)" + "preprocess = preprocessing(func_dict_in=func_dict_in, func_dict_out=func_dict_out)\n", + "X, y = preprocess.cluster(df_design[species_columns], df_results[species_columns])\n", + "# X, y = preprocess.funcTranform(X, y)\n", + "\n", + "X_train, X_test, y_train, y_test = preprocess.split(X, y, ratio = 0.2)\n", + "X_train, y_train = preprocess.balancer(X_train, y_train, strategy = \"over\")\n", + "preprocess.scale_fit(X_train, y_train, scaling = \"individual\")\n", + "X_train, X_test, y_train, y_test = preprocess.scale_transform(X_train, X_test, y_train, y_test)\n", + "X_train, X_val, y_train, y_val = preprocess.split(X_train, y_train, ratio = 0.1)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "array([5.88371754e-02, 2.38285692e-01, 1.25266821e-01, 4.02648011e-05,\n", - " 5.71730222e-02, 2.38302374e-01, 9.25432038e-02, 3.77910581e-07,\n", - " 9.99694424e-01])" + "" ] }, - 
"execution_count": 37, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" - } - ], - "source": [ - "X_train.iloc[12, :-1].values" - ] - }, - { - "cell_type": "code", - "execution_count": 54, - "metadata": {}, - "outputs": [ + }, { "data": { + "image/png": "", "text/plain": [ - "array([[1.11012434e+02, 5.55068087e+01, 3.55966726e-08, 3.89751302e-06,\n", - " 1.12795836e-02, 1.47982437e-04, 5.78389634e-03, 9.99927111e-04,\n", - " 1.00047941e+00]])" + "
" ] }, - "execution_count": 54, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "preprocess.scaler_X.inverse_transform(tf.keras.backend.constant(X_train.iloc[12, :-1].values.reshape(1, -1)))" + "timestep=250\n", + "plt.imshow(np.array(X[\"Barite\"][(timestep*2500):(timestep*2500+2500)]).reshape(50,50), origin='lower')\n", + "plt.contour(np.array(X[\"Class\"][(timestep*2500):(timestep*2500+2500)]).reshape(50,50), origin='lower', colors='red')" ] }, { @@ -706,7 +642,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -715,7 +651,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 71, "metadata": {}, "outputs": [], "source": [ @@ -727,6 +663,7 @@ " min_y = tf.convert_to_tensor(preprocess.scaler_y.min_, dtype=tf.float32)\n", "\n", " def loss(results, predicted):\n", + " \n", " # inverse min/max scaling\n", " predicted_inverse = predicted * scale_X + min_X\n", " results_inverse = results * scale_y + min_y\n", @@ -750,8 +687,8 @@ " huber_loss = tf.keras.losses.Huber()(results, predicted)\n", " \n", " # total loss\n", - " total_loss = h1 * huber_loss + h2 * dBa**2 + h3 * dSr**2 + h4 * h2o_ratio**2\n", - "\n", + " #total_loss = h1 * huber_loss + h2 * dBa**2 + h3 * dSr**2 + h4 * h2o_ratio**2\n", + " total_loss = huber_loss\n", " return total_loss\n", "\n", " return loss" @@ -759,12 +696,12 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 72, "metadata": {}, "outputs": [], "source": [ "model_simple.compile(optimizer=optimizer_simple, loss=custom_loss(preprocess, column_dict, 1, 1, 1, 1))#custom_loss(preprocess, column_dict))\n", - "# model_large.compile(optimizer=optimizer_large, loss=custom_loss(preprocess, column_dict))#custom_loss(preprocess, column_dict))" + "model_large.compile(optimizer=optimizer_large, loss=custom_loss(preprocess, column_dict, 1, 1, 1, 1))#custom_loss(preprocess, 
column_dict))" ] }, { @@ -776,7 +713,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 73, "metadata": {}, "outputs": [], "source": [ @@ -799,117 +736,125 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 74, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 1/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.6440 - val_loss: 0.5460\n", + "Epoch 1/50\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 10ms/step - loss: 0.0957 - val_loss: 0.0721\n", "Epoch 2/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.2363 - val_loss: 0.1153\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 0.0656 - val_loss: 0.0310\n", "Epoch 3/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.1663 - val_loss: 0.1348\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 0.0195 - val_loss: 8.7926e-04\n", "Epoch 4/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.1211 - val_loss: 0.1081\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 4.5016e-04 - val_loss: 5.9917e-05\n", "Epoch 5/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0994 - val_loss: 0.0932\n", + "\u001b[1m1688/1688\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 10ms/step - loss: 4.7585e-05 - val_loss: 2.6424e-05\n", "Epoch 6/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0824 - val_loss: 0.0344\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 2.3970e-05 - val_loss: 1.4780e-05\n", "Epoch 7/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0787 - val_loss: 0.0577\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 10ms/step - loss: 1.3346e-05 - val_loss: 7.0396e-06\n", "Epoch 8/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0704 - val_loss: 0.0882\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 10ms/step - loss: 6.5093e-06 - val_loss: 3.9049e-06\n", "Epoch 9/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0672 - val_loss: 0.0456\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 10ms/step - loss: 4.3649e-06 - val_loss: 2.6283e-06\n", "Epoch 10/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0543 - val_loss: 0.0487\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 3.2264e-06 - val_loss: 1.9507e-06\n", "Epoch 11/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 2ms/step - loss: 0.0548 - val_loss: 
0.0402\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 2.5372e-06 - val_loss: 1.5390e-06\n", "Epoch 12/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0484 - val_loss: 0.0576\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 2.5973e-06 - val_loss: 1.1950e-06\n", "Epoch 13/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0487 - val_loss: 0.0234\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 1.9292e-06 - val_loss: 9.5878e-07\n", "Epoch 14/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0499 - val_loss: 0.0357\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 10ms/step - loss: 2.3976e-06 - val_loss: 8.0331e-07\n", "Epoch 15/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0406 - val_loss: 0.0262\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 2.0785e-06 - val_loss: 7.0058e-07\n", "Epoch 16/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0388 - val_loss: 0.0258\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 10ms/step - loss: 1.7859e-06 - val_loss: 6.2885e-07\n", "Epoch 17/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m 
\u001b[1m2s\u001b[0m 2ms/step - loss: 0.0312 - val_loss: 0.0369\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 10ms/step - loss: 1.3826e-06 - val_loss: 5.7928e-07\n", "Epoch 18/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0340 - val_loss: 0.0571\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 10ms/step - loss: 1.4848e-06 - val_loss: 5.3876e-07\n", "Epoch 19/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 2ms/step - loss: 0.0349 - val_loss: 0.0306\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 1.1739e-06 - val_loss: 5.0183e-07\n", "Epoch 20/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0276 - val_loss: 0.0579\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 1.4790e-06 - val_loss: 4.7452e-07\n", "Epoch 21/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0255 - val_loss: 0.0264\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 1.4143e-06 - val_loss: 4.5366e-07\n", "Epoch 22/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 2ms/step - loss: 0.0233 - val_loss: 0.0334\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 1.7813e-06 - val_loss: 4.3588e-07\n", "Epoch 23/50\n", - "\u001b[1m886/886\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0232 - val_loss: 0.0155\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 10ms/step - loss: 1.2874e-06 - val_loss: 4.2102e-07\n", "Epoch 24/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0230 - val_loss: 0.0194\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 9.3762e-07 - val_loss: 4.0813e-07\n", "Epoch 25/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0232 - val_loss: 0.0099\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 1.3506e-06 - val_loss: 3.9906e-07\n", "Epoch 26/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0195 - val_loss: 0.0136\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 10ms/step - loss: 1.0733e-06 - val_loss: 3.8785e-07\n", "Epoch 27/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0199 - val_loss: 0.0166\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 1.3375e-06 - val_loss: 3.8002e-07\n", "Epoch 28/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0185 - val_loss: 0.0133\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 1.3877e-06 - val_loss: 3.7374e-07\n", 
"Epoch 29/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0174 - val_loss: 0.0094\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 10ms/step - loss: 1.1366e-06 - val_loss: 3.6767e-07\n", "Epoch 30/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0170 - val_loss: 0.0148\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 9.4944e-07 - val_loss: 3.6651e-07\n", "Epoch 31/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0141 - val_loss: 0.0132\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 1.2333e-06 - val_loss: 3.5983e-07\n", "Epoch 32/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0161 - val_loss: 0.0078\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 9.5447e-07 - val_loss: 3.5675e-07\n", "Epoch 33/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0158 - val_loss: 0.0279\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 1.2573e-06 - val_loss: 3.4970e-07\n", "Epoch 34/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0152 - val_loss: 0.0106\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step 
- loss: 1.3743e-06 - val_loss: 3.4671e-07\n", "Epoch 35/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0132 - val_loss: 0.0190\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 1.2499e-06 - val_loss: 3.4395e-07\n", "Epoch 36/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0140 - val_loss: 0.0184\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 1.3762e-06 - val_loss: 3.4097e-07\n", "Epoch 37/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0116 - val_loss: 0.0092\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 1.5823e-06 - val_loss: 3.3895e-07\n", "Epoch 38/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0106 - val_loss: 0.0178\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 1.0824e-06 - val_loss: 3.3658e-07\n", "Epoch 39/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0113 - val_loss: 0.0113\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 1.4234e-06 - val_loss: 3.3699e-07\n", "Epoch 40/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0094 - val_loss: 0.0090\n", + "\u001b[1m1688/1688\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 1.7752e-06 - val_loss: 3.3318e-07\n", "Epoch 41/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0093 - val_loss: 0.0097\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 10ms/step - loss: 1.4420e-06 - val_loss: 3.3185e-07\n", "Epoch 42/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0092 - val_loss: 0.0081\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 10ms/step - loss: 1.6145e-06 - val_loss: 3.3050e-07\n", "Epoch 43/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0085 - val_loss: 0.0093\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 10ms/step - loss: 1.0874e-06 - val_loss: 3.2977e-07\n", "Epoch 44/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0083 - val_loss: 0.0082\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 1.5824e-06 - val_loss: 3.2917e-07\n", "Epoch 45/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0091 - val_loss: 0.0135\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 10ms/step - loss: 1.0676e-06 - val_loss: 3.2717e-07\n", "Epoch 46/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0081 - 
val_loss: 0.0065\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 10ms/step - loss: 1.4321e-06 - val_loss: 3.2660e-07\n", "Epoch 47/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0076 - val_loss: 0.0100\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 1.4732e-06 - val_loss: 3.2621e-07\n", "Epoch 48/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0091 - val_loss: 0.0064\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 1.1202e-06 - val_loss: 3.2509e-07\n", "Epoch 49/50\n", - "\u001b[1m886/886\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 2ms/step - loss: 0.0079 - val_loss: 0.0070\n", - "Training took 79.01837086677551 seconds\n" + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 9ms/step - loss: 1.2688e-06 - val_loss: 3.2409e-07\n", + "Epoch 50/50\n", + "\u001b[1m1688/1688\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 10ms/step - loss: 1.2193e-06 - val_loss: 3.2360e-07\n", + "Training took 821.0099844932556 seconds\n" ] } ], "source": [ - "model_training(model_simple)" + "model_training(model_large)" ] }, { @@ -921,7 +866,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 55, "metadata": {}, "outputs": [], "source": [ @@ -947,14 +892,67 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 65, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[1m3938/3938\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m 
\u001b[1m1s\u001b[0m 311us/step\n" + "\u001b[1m 1/3938\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m1:58\u001b[0m 30ms/step" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m3938/3938\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m10s\u001b[0m 2ms/step\n" + ] + }, + { + "data": { + "text/plain": [ + "array([[-1.06938938e-02, 4.13956866e-02, 2.79493211e-03, ...,\n", + " 3.49314094e-01, -2.77905501e-02, -1.75736845e-03],\n", + " [-2.60199383e-02, 1.21183231e-01, 1.03316315e-01, ...,\n", + " 9.01593208e-01, -1.07098348e-01, -3.44314128e-02],\n", + " [-6.70573022e-03, 3.49318655e-03, -2.75319908e-04, ...,\n", + " 8.38631839e-02, 2.93988874e-03, 5.30265123e-02],\n", + " ...,\n", + " [-1.60861928e-02, 6.71150684e-02, 5.86529588e-03, ...,\n", + " 5.75501978e-01, -5.87367304e-02, -2.30707154e-02],\n", + " [-9.72283073e-03, 8.33255332e-03, -5.69229014e-05, ...,\n", + " 1.05915517e-01, 8.14233907e-04, 4.98460531e-02],\n", + " [-1.26683740e-02, 5.05815037e-02, 3.94532410e-03, ...,\n", + " 4.28284466e-01, -3.94552983e-02, -9.58428532e-03]], dtype=float32)" + ] + }, + "execution_count": 65, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_large.predict(X_test.loc[:, X_test.columns != \"Class\"]) " + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m 1/3938\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m1:42\u001b[0m 26ms/step" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m3938/3938\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 286us/step\n" ] } ], @@ -964,7 +962,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 63, "metadata": {}, "outputs": [ { @@ -973,18 +971,18 @@ "0.0" ] }, - "execution_count": 21, + "execution_count": 63, 
"metadata": {}, "output_type": "execute_result" } ], "source": [ - "len(mass_balance_results[mass_balance_results < 1e-2]) / len(mass_balance_results)" + "len(mass_balance_results[mass_balance_results < 1e-5]) / len(mass_balance_results)" ] }, { "cell_type": "code", - "execution_count": 99, + "execution_count": 58, "metadata": {}, "outputs": [ { @@ -993,7 +991,7 @@ "Series([], dtype: float64)" ] }, - "execution_count": 99, + "execution_count": 58, "metadata": {}, "output_type": "execute_result" } @@ -1514,19 +1512,19 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 70, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[1m8/8\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step \n" + "\u001b[1m8/8\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 3ms/step \n" ] }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -1536,7 +1534,7 @@ }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -1548,9 +1546,9 @@ "source": [ "import matplotlib.pyplot as plt\n", "\n", - "species = \"Ba\"\n", + "species = \"Barite\"\n", "iterations = 250\n", - "cell_offset = 11\n", + "cell_offset = 9\n", "y_design = []\n", "y_results = []\n", "y_differences = []\n", @@ -1569,7 +1567,7 @@ "y_design = pd.DataFrame(y_design)\n", "y_results = pd.DataFrame(y_results)\n", "\n", - "prediction = model_simple.predict(y_design.iloc[:, y_design.columns != \"Class\"])\n", + "prediction = model_large.predict(y_design.iloc[:, y_design.columns != \"Class\"])\n", "prediction = pd.DataFrame(prediction, columns = y_results.columns)\n", "\n", "# y_results_back, prediction = preprocess.funcInverse(y_results, prediction)\n", @@ -1875,23 +1873,23 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[1m15641/15641\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 351us/step - loss: 5.1847e-07\n" + "\u001b[1m3938/3938\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 827us/step - loss: 70.7642\n" ] }, { "data": { "text/plain": [ - "3.571243496480747e-07" + "70.69287872314453" ] }, - "execution_count": 48, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } @@ -1903,23 +1901,23 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[1m15452/15452\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 385us/step - loss: 5.2313e-07\n" + "\u001b[1m3747/3747\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 733us/step - loss: 67.2305\n" ] }, { "data": { "text/plain": [ - "3.601293485644419e-07" + "67.27115631103516" ] }, - "execution_count": 49, + "execution_count": 34, "metadata": {}, 
"output_type": "execute_result" } @@ -1931,40 +1929,30 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[1m94/94\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 4ms/step - loss: 4.0710e-05\n" + "\u001b[1m 1/192\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m2s\u001b[0m 12ms/step - loss: 148.6424" ] - } - ], - "source": [ - "mass_balance = mass_balance(model_simple, X_test, scaler_X, func_dict_in, func_dict_out)" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "metadata": {}, - "outputs": [ + }, { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[1m189/189\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 393us/step - loss: 1.2226e-07\n" + "\u001b[1m192/192\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 749us/step - loss: 139.3093\n" ] }, { "data": { "text/plain": [ - "1.1114495634956256e-07" + "137.7884521484375" ] }, - "execution_count": 50, + "execution_count": 36, "metadata": {}, "output_type": "execute_result" } @@ -1994,7 +1982,7 @@ ], "metadata": { "kernelspec": { - "display_name": "ai", + "display_name": "model-training", "language": "python", "name": "python3" }, @@ -2008,7 +1996,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.8" + "version": "3.12.2" } }, "nbformat": 4, diff --git a/optuna_runs.py b/optuna_runs.py new file mode 100644 index 0000000..0df7b3f --- /dev/null +++ b/optuna_runs.py @@ -0,0 +1,225 @@ +import keras +from keras.layers import Dense, Dropout, Input,BatchNormalization +import tensorflow as tf +import h5py +import numpy as np +import pandas as pd +import time +import sklearn.model_selection as sk +import matplotlib.pyplot as plt +from sklearn.cluster import KMeans +from sklearn.pipeline import Pipeline, 
make_pipeline
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler
from imblearn.over_sampling import RandomOverSampler
from collections import Counter
import os
from preprocessing import *
from sklearn import set_config
from importlib import reload

# scikit-learn transformers return pandas DataFrames from transform().
set_config(transform_output = "pandas")

# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
dtype = "float32"
activation = "relu"

lr = 0.001
batch_size = 512
epochs = 50 # default 400 epochs

# Staircase exponential decay: the learning rate is multiplied by 0.9 once
# every 2000 optimizer steps.
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=lr,
    decay_steps=2000,
    decay_rate=0.9,
    staircase=True
)

# One optimizer instance per model so that Adam state is not shared.
optimizer_simple = keras.optimizers.Adam(learning_rate=lr_schedule)
optimizer_large = keras.optimizers.Adam(learning_rate=lr_schedule)
optimizer_paper = keras.optimizers.Adam(learning_rate=lr_schedule)

sample_fraction = 0.8

# small model: 9 inputs -> 128 (linear) -> 128 (elu) -> 9 outputs
model_simple = keras.Sequential(
    [
        keras.Input(shape = (9,), dtype = "float32"),
        keras.layers.Dense(units = 128, activation = "linear", dtype = "float32"),
        # Dropout(0.2),
        keras.layers.Dense(units = 128, activation = "elu", dtype = "float32"),
        keras.layers.Dense(units = 9, dtype = "float32")
    ]
)


def _signed_elementwise(val, positive_fn, negative_fn):
    """Apply ``positive_fn`` where ``val > 0``, ``negative_fn`` where ``val < 0``, 0 elsewhere.

    Works element-wise on scalars, array-likes and pandas Series. Entries that
    are neither positive nor negative (zero, NaN) map to 0, which is what the
    scalar ``if / elif / else`` chain this helper replaces did.

    Returns a numpy float for scalar input, a Series (same index and name) for
    Series input, and an ndarray of the input's shape otherwise.
    """
    values = np.atleast_1d(np.asarray(val, dtype=float))
    out = np.zeros(values.shape, dtype=float)
    positive = values > 0
    negative = values < 0
    out[positive] = positive_fn(values[positive])
    out[negative] = negative_fn(values[negative])

    if isinstance(val, pd.Series):
        return pd.Series(out, index=val.index, name=val.name)
    if np.ndim(val) == 0:
        return out[0]
    return out


def Safelog(val):
    """Sign-aware base-10 logarithm.

    ``val > 0`` -> ``log10(val)``; ``val < 0`` -> ``-log10(-val)``; otherwise 0.
    Accepts scalars, arrays and pandas Series, so it can be used
    interchangeably with the vectorised ``np.log1p`` entries of ``func_dict_in``.

    Note: the sign of the input is only recoverable from the result when
    ``0 < |val| < 1`` (the logarithm of the magnitude is then negative, so the
    result has the opposite sign of the input). ``Safeexp`` relies on this.
    """
    return _signed_elementwise(val, np.log10, lambda v: -np.log10(-v))


def Safeexp(val):
    """Inverse of ``Safelog`` for inputs that originally satisfied ``0 < |x| < 1``.

    ``val > 0`` -> ``-(10 ** -val)``; ``val < 0`` -> ``10 ** val``; otherwise 0.
    Accepts scalars, arrays and pandas Series.
    """
    return _signed_elementwise(val, lambda v: -(10.0 ** -v), lambda v: 10.0 ** v)

# ?
# Why does "Charge" use a different transform than the other species?
# NOTE(review): presumably because charge can be negative, where np.log1p is
# not usable, while Safelog/Safeexp are sign-aware -- TODO confirm.

# Forward transforms, keyed by column name.
func_dict_in = {
    "H" : np.log1p,
    "O" : np.log1p,
    "Charge" : Safelog,
    "H_0_" : np.log1p,
    "O_0_" : np.log1p,
    "Ba" : np.log1p,
    "Cl" : np.log1p,
    "S_2_" : np.log1p,
    "S_6_" : np.log1p,
    "Sr" : np.log1p,
    "Barite" : np.log1p,
    "Celestite" : np.log1p,
}

# Matching back-transforms, keyed by column name.
func_dict_out = {
    "H" : np.expm1,
    "O" : np.expm1,
    "Charge" : Safeexp,
    "H_0_" : np.expm1,
    "O_0_" : np.expm1,
    "Ba" : np.expm1,
    "Cl" : np.expm1,
    "S_2_" : np.expm1,
    "S_6_" : np.expm1,
    "Sr" : np.expm1,
    "Barite" : np.expm1,
    "Celestite" : np.expm1,
}

# os.chdir('/mnt/beegfs/home/signer/projects/model-training')
# The context manager guarantees the HDF5 handle is released even if building
# the DataFrames fails.
with h5py.File("barite_50_4_corner.h5", "r") as data_file:
    design = data_file["design"]
    results = data_file["result"]

    # Data sets are transposed on read so that each named entry becomes a column.
    df_design = pd.DataFrame(np.array(design["data"]).transpose(), columns = np.array(design["names"].asstr()))
    df_results = pd.DataFrame(np.array(results["data"]).transpose(), columns = np.array(results["names"].asstr()))

species_columns = ['H', 'O', 'Charge', 'Ba', 'Cl', 'S', 'Sr', 'Barite', 'Celestite']

# Pipeline: cluster -> train/test split -> oversample train -> fit scalers on
# train -> scale everything -> carve a validation set out of train.
preprocess = preprocessing(func_dict_in=func_dict_in, func_dict_out=func_dict_out)
X, y = preprocess.cluster(df_design[species_columns], df_results[species_columns])
# X, y = preprocess.funcTranform(X, y)

X_train, X_test, y_train, y_test = preprocess.split(X, y, ratio = 0.2)
X_train, y_train = preprocess.balancer(X_train, y_train, strategy = "over")
preprocess.scale_fit(X_train, y_train, scaling = "individual")
X_train, X_test, y_train, y_test = preprocess.scale_transform(X_train, X_test, y_train, y_test)
X_train, X_val, y_train, y_val = preprocess.split(X_train, y_train, ratio = 0.1)

# Positional index of every column the physics-informed loss needs.
column_dict = {
    name: X.columns.get_loc(name)
    for name in ("Ba", "Barite", "Sr", "Celestite", "H", "O")
}


def custom_loss(preprocess, column_dict, h1, h2, h3, h4):
    """Build a Keras loss combining a Huber term with Ba/Sr mass-balance penalties.

    Args:
        preprocess: object exposing a fitted min/max ``scaler_y`` (attributes
            ``scale_`` and ``min_``).
        column_dict: mapping column name -> positional index in the output tensor.
        h1: weight of the Huber loss (computed on scaled values).
        h2: weight of the squared Ba + Barite mass-balance error.
        h3: weight of the squared Sr + Celestite mass-balance error.
        h4: reserved weight for an H/O == 2 constraint; currently unused, the
            parameter is kept so existing callers keep working.

    Returns:
        A ``loss(results, predicted)`` callable yielding one value per sample.
    """
    # Both tensors seen by the loss (targets and predictions) live in the
    # *output* space, so both are un-scaled with the y-scaler.
    scale_y = tf.convert_to_tensor(preprocess.scaler_y.scale_, dtype=tf.float32)
    min_y = tf.convert_to_tensor(preprocess.scaler_y.min_, dtype=tf.float32)

    ba, barite = column_dict["Ba"], column_dict["Barite"]
    sr, celestite = column_dict["Sr"], column_dict["Celestite"]

    def loss(results, predicted):
        # Min/max scaling is x_scaled = x * scale_ + min_, hence the inverse is
        # x = (x_scaled - min_) / scale_.
        predicted_inverse = (predicted - min_y) / scale_y
        results_inverse = (results - min_y) / scale_y

        # Mass balance: dissolved + mineral amount must match the reference.
        dBa = tf.abs(
            (predicted_inverse[:, ba] + predicted_inverse[:, barite]) -
            (results_inverse[:, ba] + results_inverse[:, barite])
        )
        dSr = tf.abs(
            (predicted_inverse[:, sr] + predicted_inverse[:, celestite]) -
            (results_inverse[:, sr] + results_inverse[:, celestite])
        )

        # Data-fit term on the scaled values.
        huber_loss = tf.keras.losses.Huber()(results, predicted)

        # NOTE: the H/O == 2 penalty (weight h4) is intentionally not part of
        # the total.
        return h1 * huber_loss + h2 * dBa**2 + h3 * dSr**2

    return loss


def mass_balance(model, X, preprocess):
    """Per-sample Ba and Sr mass-balance error between model input and prediction.

    Args:
        model: fitted Keras model.
        X: scaled input frame; an optional "Class" column is ignored.
        preprocess: object exposing fitted ``scaler_X`` and ``scaler_y``.

    Returns:
        pandas Series ``|d(Ba + Barite)| + |d(Sr + Celestite)|`` in unscaled units.
    """
    # predict the chemistry
    columns = X.iloc[:, X.columns != "Class"].columns
    features = X[columns]
    prediction = model.predict(features)

    # backtransform min/max; values are forced to plain arrays so both frames
    # share a fresh RangeIndex and the arithmetic below is positional.
    X = pd.DataFrame(np.asarray(preprocess.scaler_X.inverse_transform(features)), columns=columns)
    prediction = pd.DataFrame(np.asarray(preprocess.scaler_y.inverse_transform(prediction)), columns=columns)

    # calculate mass balance
    dBa = np.abs((prediction["Ba"] + prediction["Barite"]) - (X["Ba"] + X["Barite"]))
    dSr = np.abs((prediction["Sr"] + prediction["Celestite"]) - (X["Sr"] + X["Celestite"]))

    return dBa + dSr

import optuna

def create_model(model, preprocess, h1, h2, h3, h4):
    """Return a freshly initialised, compiled copy of ``model``.

    The architecture is cloned (new weights) and paired with a new Adam
    optimizer so that neither weights nor optimizer state leak from one
    Optuna trial into the next.
    """
    trial_model = keras.models.clone_model(model)
    trial_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
        loss=custom_loss(preprocess, column_dict, h1, h2, h3, h4),
    )

    return trial_model


def objective(trial, preprocess, X_train, y_train, X_val, y_val, X_test, y_test):
    """Optuna objective: train with sampled loss weights and score on the test set.

    Returns:
        ``(prediction_loss, mass_balance_ratio)`` where ``prediction_loss`` is
        the *unweighted* Huber loss on the scaled test data (to minimise) and
        ``mass_balance_ratio`` is the fraction of test samples whose
        mass-balance error is below 1e-5 (to maximise).
    """
    h1 = trial.suggest_float("h1", 0.1, 10)
    h2 = trial.suggest_float("h2", 0.1, 10)
    h3 = trial.suggest_float("h3", 0.1, 10)
    h4 = trial.suggest_float("h4", 0.1, 10)

    def without_class(frame):
        # Drop the cluster label; it is not a model input/output.
        return frame.loc[:, frame.columns != "Class"]

    model = create_model(model_simple, preprocess, h1, h2, h3, h4)

    callback = keras.callbacks.EarlyStopping(monitor='loss', patience=3)
    model.fit(without_class(X_train),
              without_class(y_train),
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(without_class(X_val), without_class(y_val)),
              callbacks=[callback])

    # Score with a metric that does not depend on the sampled weights;
    # otherwise Optuna could "improve" the loss simply by shrinking h1..h3.
    y_true = without_class(y_test).to_numpy(dtype="float32")
    y_pred = model.predict(without_class(X_test))
    prediction_loss = float(tf.keras.losses.Huber()(y_true, y_pred))

    mass_balance_results = mass_balance(model, X_test, preprocess)
    mass_balance_ratio = float((mass_balance_results < 1e-5).mean())

    return prediction_loss, mass_balance_ratio

if __name__ == "__main__":
    # load_if_exists lets the script resume the study stored in the SQLite file
    # instead of failing on the second run.
    study = optuna.create_study(
        storage="sqlite:///model_optimization.db",
        study_name="model_optimization",
        directions=["minimize", "maximize"],
        load_if_exists=True,
    )
    study.optimize(lambda trial: objective(trial, preprocess, X_train, y_train, X_val, y_val, X_test, y_test), n_trials=1000)

    print("Number of finished trials: ", len(study.trials))

    # A multi-objective study has no single best trial; report the Pareto front.
    print("Best trials (Pareto front):")
    for trial in study.best_trials:
        print("  Values: ", trial.values)

        print("  Params: ")
        for key, value in trial.params.items():
            print("    {}: {}".format(key, value))