From c7d1fc152c3b9a66b7a24b0444ab068765307a72 Mon Sep 17 00:00:00 2001 From: Hannes Signer Date: Wed, 10 Dec 2025 19:54:32 +0100 Subject: [PATCH] run retraining only once --- src/poet.cpp | 20 +++++++++++++------- src/poet.hpp.in | 3 ++- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/poet.cpp b/src/poet.cpp index 9d89f18a7..4520b3d5d 100644 --- a/src/poet.cpp +++ b/src/poet.cpp @@ -310,6 +310,8 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms, R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps()); std::unique_ptr ai_ctx = nullptr; + size_t retrain_counter = 0; + size_t field_size = 0; if (params.ai) { @@ -329,6 +331,9 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms, std::vector mean = R["mean"]; std::vector scale = R["scale"]; + field_size = chem.getField().GetRequestedVecSize(); + std::cout << field_size << std::endl; + ai_ctx->scaler.set_scaler(mean, scale); // initialzie training backens only if retraining is desired @@ -338,7 +343,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms, } else if (params.ai_backend == NAA_BACKEND) { MSG("AI Surrogate with NAA backend enabled.") ai_ctx->training_backend = - std::make_unique>(4 * params.batch_size); + std::make_unique>(4 * field_size); } if (!params.disable_retraining) { @@ -535,15 +540,15 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms, MSG("AI: add invalid data to buffer"); - ai_ctx->data_semaphore_write.acquire(); - std::cout << "size of predictors " << predictors_retraining[0].size() << std::endl; std::cout << "size of targets " << targets_retraining[0].size() << std::endl; + ai_ctx->data_semaphore_write.acquire(); + if (predictors_retraining[0].size() > 0 && - targets_retraining[0].size() > 0) { + targets_retraining[0].size() > 0 && retrain_counter == 0) { ai_ctx->design_buffer.addData(predictors_retraining); ai_ctx->results_buffer.addData(targets_retraining); } @@ -559,10 +564,11 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms, std::cout << "results_buffer_size: " << elements_results_buffer << std::endl; - if (elements_design_buffer >= 4 * params.batch_size && - elements_results_buffer >= 4 * params.batch_size && - ai_ctx->training_is_running == false) { + if (elements_design_buffer >= 4 * field_size && + elements_results_buffer >= 4 * field_size && + ai_ctx->training_is_running == false && retrain_counter == 0) { ai_ctx->data_semaphore_read.release(); + retrain_counter++; } else { ai_ctx->data_semaphore_write.release(); } diff --git a/src/poet.hpp.in b/src/poet.hpp.in index 4b1c3c71f..24335a663 100644 --- a/src/poet.hpp.in +++ b/src/poet.hpp.in @@ -23,6 +23,7 @@ #pragma once #include +#include #include #include #include @@ -86,7 +87,7 @@ struct RuntimeParameters { static constexpr std::uint8_t AI_BACKEND_DEFAULT = 1; std::uint8_t ai_backend = AI_BACKEND_DEFAULT; // 1 - python, 2 - naa bool train_only_invalid = true; - int batch_size = 200 * 200; + int batch_size = 2500; static constexpr std::uint8_t DEFAULT_FUNCTION_CODE = 0; std::uint8_t function_code = DEFAULT_FUNCTION_CODE;