From 110d5c810b89f06bc76af1fa53cf9414b21ca9a3 Mon Sep 17 00:00:00 2001 From: straile Date: Mon, 4 Nov 2024 15:40:41 +0100 Subject: [PATCH] fix: roll back to functioning state --- .../AI_Python_functions/keras_AI_surrogate.py | 8 +++++++- src/Chemistry/SurrogateModels/AI_functions.cpp | 3 ++- src/Chemistry/WorkerFunctions.cpp | 18 +++++++++--------- src/poet.cpp | 12 +++++++++++- src/poet.hpp.in | 1 + 5 files changed, 30 insertions(+), 12 deletions(-) diff --git a/src/Chemistry/SurrogateModels/AI_Python_functions/keras_AI_surrogate.py b/src/Chemistry/SurrogateModels/AI_Python_functions/keras_AI_surrogate.py index 91a2de3bd..f994a7943 100644 --- a/src/Chemistry/SurrogateModels/AI_Python_functions/keras_AI_surrogate.py +++ b/src/Chemistry/SurrogateModels/AI_Python_functions/keras_AI_surrogate.py @@ -16,7 +16,13 @@ def initiate_model(model_file_path): model = tf.keras.models.load_model(model_file_path) return model -def prediction_step(model, model_reactive, x, cluster_labels, batch_size): +def prediction_step(model, model_reactive, x, cluster_labels, batch_size) + # Catch input size mismatches + model_input_shape = model.input_shape[1:] + if model_input_shape != x.shape[1:]: + print(f"Input data size {x.shape[1:]} does not match model input size {model_input_shape}", + flush=True) + # Predict separately if clustering is used if cluster_labels: cluster_labels = np.asarray(cluster_labels, dtype=bool) diff --git a/src/Chemistry/SurrogateModels/AI_functions.cpp b/src/Chemistry/SurrogateModels/AI_functions.cpp index 1ba4c2d80..95bbbf795 100644 --- a/src/Chemistry/SurrogateModels/AI_functions.cpp +++ b/src/Chemistry/SurrogateModels/AI_functions.cpp @@ -344,7 +344,8 @@ std::vector Eigen_predict_clustered(const EigenModel& model, const Eigen if (num_features != model.weight_matrices[0].cols() || num_features != model_reactive.weight_matrices[0].cols()) { throw std::runtime_error("Input data size " + std::to_string(num_features) + - " does not match model input layer sizes"); + " does not match model input layer sizes" + std::to_string(model.weight_matrices[0].cols()) + + " / " + std::to_string(model_reactive.weight_matrices[0].cols())); } // Convert input data to Eigen matrix diff --git a/src/Chemistry/WorkerFunctions.cpp b/src/Chemistry/WorkerFunctions.cpp index 335132d0a..9d178021a 100644 --- a/src/Chemistry/WorkerFunctions.cpp +++ b/src/Chemistry/WorkerFunctions.cpp @@ -161,6 +161,15 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, mpi_buffer.begin() + this->prop_count * (wp_i + 1)); } + if (this->ai_surrogate_enabled) { + // Map valid predictions from the ai surrogate in the workpackage + for (int i = 0; i < s_curr_wp.size; i++) { + if (this->ai_surrogate_validity_vector[wp_start_index + i] == 1) { + s_curr_wp.mapping[i] = CHEM_AISURR; + } + } + } + // std::cout << this->comm_rank << ":" << counter++ << std::endl; if (dht_enabled || interp_enabled) { dht->prepareKeys(s_curr_wp.input, dt); @@ -178,15 +187,6 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, interp->tryInterpolation(s_curr_wp); } - if (this->ai_surrogate_enabled) { - // Map valid predictions from the ai surrogate in the workpackage - for (int i = 0; i < s_curr_wp.size; i++) { - if (this->ai_surrogate_validity_vector[wp_start_index + i] == 1) { - s_curr_wp.mapping[i] = CHEM_AISURR; - } - } - } - phreeqc_time_start = MPI_Wtime(); diff --git a/src/poet.cpp b/src/poet.cpp index fd60d78a1..9bafe7b43 100644 --- a/src/poet.cpp +++ b/src/poet.cpp @@ -457,9 +457,14 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms, if (params.use_ai_surrogate && !params.disable_training) { // Add values for which the predictions were invalid // to training data buffer - MSG("AI: Add invalid predictions to training data buffer"); + MSG("AI: Add to training data buffer"); std::vector> invalid_x = R.parseEval("get_invalid_values(predictors_scaled, validity_vector)"); + + if (!params.train_only_invalid) { + // Use all values if not specified otherwise + R.parseEval("validity_vector[] <- 0"); + } R.parseEval("target_scaled <- preprocess(state_C[ai_surrogate_species])"); std::vector> invalid_y = @@ -675,6 +680,8 @@ int main(int argc, char *argv[]) { /* Use dht species for model input and output */ R["ai_surrogate_species"] = init_list.getChemistryInit().dht_species.getNames(); + +// TODO REMOVE!! R.parseEval("ai_surrogate_species <- ai_surrogate_species[ai_surrogate_species != \"Charge\"]"); const std::string ai_surrogate_input_script = init_list.getChemistryInit().ai_surrogate_input_script; @@ -709,6 +716,9 @@ R.parseEval("ai_surrogate_species <- ai_surrogate_species[ai_surrogate_species ! if (Rcpp::as(R.parseEval("exists(\"disable_training\")"))) { run_params.disable_training = R["disable_training"]; } + if (Rcpp::as(R.parseEval("exists(\"train_only_invalid\")"))) { + run_params.train_only_invalid = R["train_only_invalid"]; + } if (Rcpp::as(R.parseEval("exists(\"save_model_path\")"))) { run_params.save_model_path = Rcpp::as(R["save_model_path"]); MSG("AI: Model will be saved as \"" + run_params.save_model_path + "\""); diff --git a/src/poet.hpp.in b/src/poet.hpp.in index b3878d67f..82ef4d26b 100644 --- a/src/poet.hpp.in +++ b/src/poet.hpp.in @@ -74,6 +74,7 @@ struct RuntimeParameters { bool disable_training; // Can be set in the R input script bool use_k_means_clustering; // Can be set in the R input script bool use_Keras_predictions; // Can be set in the R input script + bool train_only_invalid; // Can be set in the R input script int batch_size; // Can be set in the R input script int training_epochs; // Can be set in the R input script int training_data_size; // Can be set in the R input script