mirror of
https://git.gfz-potsdam.de/naaice/model-training.git
synced 2025-12-16 08:08:22 +01:00
update experiments
This commit is contained in:
parent
d678c0bfde
commit
69355a1e4e
@ -4,6 +4,7 @@ import matplotlib.pyplot as plt
|
|||||||
import matplotlib.colors as mcolors
|
import matplotlib.colors as mcolors
|
||||||
import matplotlib.patches as mpatches
|
import matplotlib.patches as mpatches
|
||||||
import pickle
|
import pickle
|
||||||
|
import csv
|
||||||
|
|
||||||
|
|
||||||
###### Experimental parameters
|
###### Experimental parameters
|
||||||
@ -27,6 +28,8 @@ df_results = pd.DataFrame(
|
|||||||
)
|
)
|
||||||
|
|
||||||
data_file.close()
|
data_file.close()
|
||||||
|
|
||||||
|
# remove charge as species
|
||||||
df_design.drop("Charge", axis=1, inplace=True, errors="ignore")
|
df_design.drop("Charge", axis=1, inplace=True, errors="ignore")
|
||||||
df_results.drop("Charge", axis=1, inplace=True, errors="ignore")
|
df_results.drop("Charge", axis=1, inplace=True, errors="ignore")
|
||||||
|
|
||||||
@ -122,36 +125,11 @@ model.compile(
|
|||||||
|
|
||||||
###### train model
|
###### train model
|
||||||
|
|
||||||
epochs = 200
|
epochs = 3
|
||||||
|
|
||||||
history = model_training(model, epochs=epochs)
|
history = model_training(model, epochs=epochs)
|
||||||
|
|
||||||
|
|
||||||
###### evaluate model
|
|
||||||
|
|
||||||
|
|
||||||
results = mass_balance_evaluation(model, X_test, preprocess)
|
|
||||||
mass_balance_ratio(results, threshold=1e-5)
|
|
||||||
|
|
||||||
def test_model(model, X_test, y_test):
|
|
||||||
|
|
||||||
X_test.reset_index(inplace=True, drop=True)
|
|
||||||
y_test.reset_index(inplace=True, drop=True)
|
|
||||||
all = model.evaluate(X_test.loc[:, X_test.columns != "Class"], y_test.loc[:, y_test.columns != "Class"])
|
|
||||||
class_0 = model.evaluate(X_test[X_test["Class"] == 0].iloc[:, X_test.columns != "Class"], y_test[X_test["Class"] == 0].iloc[:, y_test.columns != "Class"])
|
|
||||||
class_1 = model.evaluate(
|
|
||||||
X_test[X_test["Class"] == 1].iloc[:, :-1], y_test[X_test["Class"] == 1].iloc[:, :-1])
|
|
||||||
|
|
||||||
print("metric all data: ", all)
|
|
||||||
print("metric class 0: ", class_0)
|
|
||||||
print("metric class 1: ", class_1)
|
|
||||||
|
|
||||||
|
|
||||||
test_model(model, X_test, y_test)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
###### save model and history
|
###### save model and history
|
||||||
|
|
||||||
delimiter = "_"
|
delimiter = "_"
|
||||||
@ -160,4 +138,38 @@ file_name = "history_" + idx_string
|
|||||||
with open('../results/'+file_name, 'wb') as file_pi:
|
with open('../results/'+file_name, 'wb') as file_pi:
|
||||||
pickle.dump(history.history, file_pi)
|
pickle.dump(history.history, file_pi)
|
||||||
|
|
||||||
model.save_weights("../results/models/model_"+idx_string + ".weights.h5")
|
model.save_weights("../results/models/model_"+idx_string + ".weights.h5")
|
||||||
|
|
||||||
|
|
||||||
|
###### evaluate model
|
||||||
|
|
||||||
|
results = mass_balance_evaluation(model, X_test, preprocess)
|
||||||
|
proportion = mass_balance_ratio(results, threshold=1e-5)
|
||||||
|
|
||||||
|
X_test.reset_index(inplace=True, drop=True)
|
||||||
|
y_test.reset_index(inplace=True, drop=True)
|
||||||
|
all_classes = model.evaluate(X_test.loc[:, X_test.columns != "Class"], y_test.loc[:, y_test.columns != "Class"])
|
||||||
|
class_0 = model.evaluate(X_test[X_test["Class"] == 0].iloc[:, X_test.columns != "Class"], y_test[X_test["Class"] == 0].iloc[:, y_test.columns != "Class"])
|
||||||
|
class_1 = model.evaluate(X_test[X_test["Class"] == 1].iloc[:, :-1], y_test[X_test["Class"] == 1].iloc[:, :-1])
|
||||||
|
|
||||||
|
print("metric all data: ", all_classes)
|
||||||
|
print("metric class 0: ", class_0)
|
||||||
|
print("metric class 1: ", class_1)
|
||||||
|
|
||||||
|
|
||||||
|
# Save evaluation results to a file
|
||||||
|
results_file_name = "../results/evaluation_" + idx_string + ".csv"
|
||||||
|
with open(results_file_name, mode="w", newline="") as results_file:
|
||||||
|
writer = csv.writer(results_file)
|
||||||
|
writer.writerow(["Metric", "Value"])
|
||||||
|
writer.writerow(["Mass balance fulfilled (all classes)", proportion["overall"]])
|
||||||
|
writer.writerow(["Mass balance fulfilled (class 0)", proportion["class_0"]])
|
||||||
|
writer.writerow(["Mass balance fulfilled (class 1)", proportion["class_1"]])
|
||||||
|
writer.writerow(["Metrics (all classes)", all_classes])
|
||||||
|
writer.writerow(["Metrics (class 0)", class_0])
|
||||||
|
writer.writerow(["Metrics (class 1)", class_1])
|
||||||
|
|
||||||
|
results_file.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user