mirror of
https://git.gfz-potsdam.de/naaice/model-training.git
synced 2025-12-13 10:28:22 +01:00
update experiments
This commit is contained in:
parent
d678c0bfde
commit
69355a1e4e
@ -4,6 +4,7 @@ import matplotlib.pyplot as plt
|
||||
import matplotlib.colors as mcolors
|
||||
import matplotlib.patches as mpatches
|
||||
import pickle
|
||||
import csv
|
||||
|
||||
|
||||
###### Experimental parameters
|
||||
@ -27,6 +28,8 @@ df_results = pd.DataFrame(
|
||||
)
|
||||
|
||||
data_file.close()
|
||||
|
||||
# remove charge as species
|
||||
df_design.drop("Charge", axis=1, inplace=True, errors="ignore")
|
||||
df_results.drop("Charge", axis=1, inplace=True, errors="ignore")
|
||||
|
||||
@ -122,36 +125,11 @@ model.compile(
|
||||
|
||||
###### train model
|
||||
|
||||
epochs = 200
|
||||
epochs = 3
|
||||
|
||||
history = model_training(model, epochs=epochs)
|
||||
|
||||
|
||||
###### evaluate model
|
||||
|
||||
|
||||
results = mass_balance_evaluation(model, X_test, preprocess)
|
||||
mass_balance_ratio(results, threshold=1e-5)
|
||||
|
||||
def test_model(model, X_test, y_test):
|
||||
|
||||
X_test.reset_index(inplace=True, drop=True)
|
||||
y_test.reset_index(inplace=True, drop=True)
|
||||
all = model.evaluate(X_test.loc[:, X_test.columns != "Class"], y_test.loc[:, y_test.columns != "Class"])
|
||||
class_0 = model.evaluate(X_test[X_test["Class"] == 0].iloc[:, X_test.columns != "Class"], y_test[X_test["Class"] == 0].iloc[:, y_test.columns != "Class"])
|
||||
class_1 = model.evaluate(
|
||||
X_test[X_test["Class"] == 1].iloc[:, :-1], y_test[X_test["Class"] == 1].iloc[:, :-1])
|
||||
|
||||
print("metric all data: ", all)
|
||||
print("metric class 0: ", class_0)
|
||||
print("metric class 1: ", class_1)
|
||||
|
||||
|
||||
test_model(model, X_test, y_test)
|
||||
|
||||
|
||||
|
||||
|
||||
###### save model and history
|
||||
|
||||
delimiter = "_"
|
||||
@ -160,4 +138,38 @@ file_name = "history_" + idx_string
|
||||
with open('../results/'+file_name, 'wb') as file_pi:
|
||||
pickle.dump(history.history, file_pi)
|
||||
|
||||
model.save_weights("../results/models/model_"+idx_string + ".weights.h5")
|
||||
model.save_weights("../results/models/model_"+idx_string + ".weights.h5")
|
||||
|
||||
|
||||
###### evaluate model
|
||||
|
||||
results = mass_balance_evaluation(model, X_test, preprocess)
|
||||
proportion = mass_balance_ratio(results, threshold=1e-5)
|
||||
|
||||
X_test.reset_index(inplace=True, drop=True)
|
||||
y_test.reset_index(inplace=True, drop=True)
|
||||
all_classes = model.evaluate(X_test.loc[:, X_test.columns != "Class"], y_test.loc[:, y_test.columns != "Class"])
|
||||
class_0 = model.evaluate(X_test[X_test["Class"] == 0].iloc[:, X_test.columns != "Class"], y_test[X_test["Class"] == 0].iloc[:, y_test.columns != "Class"])
|
||||
class_1 = model.evaluate(X_test[X_test["Class"] == 1].iloc[:, :-1], y_test[X_test["Class"] == 1].iloc[:, :-1])
|
||||
|
||||
print("metric all data: ", all_classes)
|
||||
print("metric class 0: ", class_0)
|
||||
print("metric class 1: ", class_1)
|
||||
|
||||
|
||||
# Save evaluation results to a file
|
||||
results_file_name = "../results/evaluation_" + idx_string + ".csv"
|
||||
with open(results_file_name, mode="w", newline="") as results_file:
|
||||
writer = csv.writer(results_file)
|
||||
writer.writerow(["Metric", "Value"])
|
||||
writer.writerow(["Mass balance fulfilled (all classes)", proportion["overall"]])
|
||||
writer.writerow(["Mass balance fulfilled (class 0)", proportion["class_0"]])
|
||||
writer.writerow(["Mass balance fulfilled (class 1)", proportion["class_1"]])
|
||||
writer.writerow(["Metrics (all classes)", all_classes])
|
||||
writer.writerow(["Metrics (class 0)", class_0])
|
||||
writer.writerow(["Metrics (class 1)", class_1])
|
||||
|
||||
results_file.close()
|
||||
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user