From 4d254250e15f2faf69072e21e086d148a23b93a6 Mon Sep 17 00:00:00 2001 From: straile Date: Sun, 20 Oct 2024 12:38:58 +0200 Subject: [PATCH] fix: set training wait predicate with buffer threshold check --- .../AI_Python_functions/keras_AI_surrogate.py | 1 + src/Chemistry/SurrogateModels/AI_functions.cpp | 9 +++------ 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/Chemistry/SurrogateModels/AI_Python_functions/keras_AI_surrogate.py b/src/Chemistry/SurrogateModels/AI_Python_functions/keras_AI_surrogate.py index b311582c5..1091bc766 100644 --- a/src/Chemistry/SurrogateModels/AI_Python_functions/keras_AI_surrogate.py +++ b/src/Chemistry/SurrogateModels/AI_Python_functions/keras_AI_surrogate.py @@ -5,6 +5,7 @@ import os def initiate_model(model_file_path, cuda_dir): os.environ["TF_XLA_FLAGS"] = "--tf_xla_cpu_global_jit" os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=" + cuda_dir + print("AI: Model loaded from: " + model_file_path, flush=True) model = tf.keras.models.load_model(model_file_path) return model diff --git a/src/Chemistry/SurrogateModels/AI_functions.cpp b/src/Chemistry/SurrogateModels/AI_functions.cpp index 2f28592bf..8874e037c 100644 --- a/src/Chemistry/SurrogateModels/AI_functions.cpp +++ b/src/Chemistry/SurrogateModels/AI_functions.cpp @@ -310,16 +310,11 @@ void parallel_training(EigenModel* Eigen_model, // wait for a signal on training_data_buffer_full but starts the next round immediately. std::unique_lock lock(*training_data_buffer_mutex); training_data_buffer_full->wait(lock, [start_training] { return *start_training;}); - - //hier nochmal training_data_buffer_mutex lock/lock test? - // Return if program is about to end if (*end_training) { return; } - // Reset the waiting predicate - *start_training = false; // Get the necessary training data std::cout << "AI: Training thread: Getting training data" << std::endl; // Initialize training data input and targets @@ -342,6 +337,8 @@ void parallel_training(EigenModel* Eigen_model, training_data_buffer->y[col].erase(training_data_buffer->y[col].begin(), training_data_buffer->y[col].begin() + params.training_data_size); } + // Set the waiting predicate to false if buffer is below threshold + *start_training = training_data_buffer->y[0].size() >= params.training_data_size; //update number of training runs training_data_buffer->n_training_runs += 1; // Unlock the training_data_buffer_mutex @@ -366,7 +363,7 @@ void parallel_training(EigenModel* Eigen_model, // Release the Python GIL PyGILState_Release(gstate); std::cout << "AI: Training thread: Finished training, waiting for new data" << std::endl; - } + } } std::thread python_train_thread;