From c4da86ef98eff06cbe33a9278cee808bb4793384 Mon Sep 17 00:00:00 2001 From: straile Date: Sun, 3 Nov 2024 18:02:32 +0100 Subject: [PATCH] fix: dynamic training buffer column number --- src/Chemistry/SurrogateModels/AI_functions.cpp | 7 ++++--- src/poet.cpp | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/Chemistry/SurrogateModels/AI_functions.cpp b/src/Chemistry/SurrogateModels/AI_functions.cpp index 2adfc269a..1ba4c2d80 100644 --- a/src/Chemistry/SurrogateModels/AI_functions.cpp +++ b/src/Chemistry/SurrogateModels/AI_functions.cpp @@ -612,8 +612,10 @@ void parallel_training(EigenModel* Eigen_model, EigenModel* Eigen_model_reactive // Get the necessary training data std::cout << "AI: Training thread: Getting training data" << std::endl; // Initialize training data input and targets - std::vector> inputs(9, std::vector(params.training_data_size)); - std::vector> targets(9, std::vector(params.training_data_size)); + std::vector> inputs(training_data_buffer->x.size(), + std::vector(params.training_data_size)); + std::vector> targets(training_data_buffer->x.size(), + std::vector(params.training_data_size)); int buffer_size = training_data_buffer->x[0].size(); // If clustering is used, check the current cluster @@ -676,7 +678,6 @@ void parallel_training(EigenModel* Eigen_model, EigenModel* Eigen_model_reactive Python_Keras_train(inputs, targets, train_cluster, model_name, params); if (!params.use_Keras_predictions) { - // TODO UPDATE EIGEN MODEL CLUSTER SPECIFIC std::cout << "AI: Training thread: Update shared model weights" << std::endl; std::vector>> cpp_weights = Python_Keras_get_weights(model_name); Eigen_model_mutex->lock(); diff --git a/src/poet.cpp b/src/poet.cpp index d8c22cb71..fd60d78a1 100644 --- a/src/poet.cpp +++ b/src/poet.cpp @@ -675,7 +675,7 @@ int main(int argc, char *argv[]) { /* Use dht species for model input and output */ R["ai_surrogate_species"] = init_list.getChemistryInit().dht_species.getNames(); - +R.parseEval("ai_surrogate_species <- ai_surrogate_species[ai_surrogate_species != \"Charge\"]"); const std::string ai_surrogate_input_script = init_list.getChemistryInit().ai_surrogate_input_script; @@ -718,7 +718,7 @@ int main(int argc, char *argv[]) { MSG("K-Means clustering will be used for the AI surrogate") } if (!Rcpp::as(R.parseEval("exists(\"model_reactive_file_path\")"))) { - R.parseEval("model_reactive_file_path <- model_reactive_file_path"); + R.parseEval("model_reactive_file_path <- model_file_path"); } MSG("AI: Initialize Python for the AI surrogate functions"); std::string python_keras_file = std::string(SRC_DIR) +