From be5b42392a89c6b72298a4e2766e530a17a686a3 Mon Sep 17 00:00:00 2001 From: rastogi Date: Tue, 9 Dec 2025 12:48:48 +0100 Subject: [PATCH] feat: interval limit between rollbacks --- bin/dolo_fgcs_3_rt.R | 4 +- bin/run_poet.sh | 8 +-- src/Control/ControlModule.cpp | 80 +++++++++++++++--------- src/Control/ControlModule.hpp | 30 +++++---- src/poet.cpp | 114 ++++++++++++++-------------------- src/poet.hpp.in | 1 + 6 files changed, 117 insertions(+), 120 deletions(-) diff --git a/bin/dolo_fgcs_3_rt.R b/bin/dolo_fgcs_3_rt.R index b35c395d1..3359c5046 100644 --- a/bin/dolo_fgcs_3_rt.R +++ b/bin/dolo_fgcs_3_rt.R @@ -6,6 +6,7 @@ mape_threshold <- rep(0.0035, 13) mape_threshold[5] <- 1 #Charge zero_abs <- 1e-13 rb_limit <- 3 +rb_interval_limit <- 100 #ctrl_cell_ids <- seq(0, (400*400)/2 - 1, by = 401) #out_save <- seq(500, iterations, by = 500) @@ -20,5 +21,6 @@ list( ctrl_interval = ctrl_interval, mape_threshold = mape_threshold, zero_abs = zero_abs, - rb_limit = rb_limit + rb_limit = rb_limit, + rb_interval_limit = rb_interval_limit ) \ No newline at end of file diff --git a/bin/run_poet.sh b/bin/run_poet.sh index 5f256fbb2..95ca9b936 100644 --- a/bin/run_poet.sh +++ b/bin/run_poet.sh @@ -1,7 +1,7 @@ #!/bin/bash -#SBATCH --job-name=p1_eps0035_v2 -#SBATCH --output=p1_eps0035_v2_%j.out -#SBATCH --error=p1_eps0035_v2_%j.err +#SBATCH --job-name=p1_eps0035_r1 +#SBATCH --output=p1_eps0035_r2_%j.out +#SBATCH --error=p1_eps0035_r2_%j.err #SBATCH --partition=long #SBATCH --nodes=6 #SBATCH --ntasks-per-node=24 @@ -15,5 +15,5 @@ module purge module load cmake gcc openmpi #mpirun -n 144 ./poet dolo_fgcs_3.R dolo_fgcs_3.qs2 dolo_only_pqc -mpirun -n 144 ./poet --interp dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 p1_eps0035_v2 +mpirun -n 144 ./poet --interp dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 p1_eps0035_r2 #mpirun -n 144 ./poet --interp barite_fgcs_4_new/barite_fgcs_4_new_rt.R barite_fgcs_4_new/barite_fgcs_4_new.qs2 barite \ No newline at end of file diff --git a/src/Control/ControlModule.cpp b/src/Control/ControlModule.cpp index ebdae31ae..eb2dac888 100644 --- a/src/Control/ControlModule.cpp +++ b/src/Control/ControlModule.cpp @@ -33,7 +33,7 @@ void poet::ControlModule::beginIteration(const uint32_t &iter, const bool &dht_e void poet::ControlModule::updateSurrState(bool dht_enabled, bool interp_enabled) { bool in_warmup = (global_iter <= config.ctrl_interval); - bool rb_limit_reached = (rb_count >= config.rb_limit); + bool rb_limit_reached = rbLimitReached(); if (rb_enabled && stab_countdown > 0 && !rb_limit_reached) { --stab_countdown; @@ -54,9 +54,18 @@ void poet::ControlModule::updateSurrState(bool dht_enabled, bool interp_enabled) } else { std::cout << "In stabilization phase." << std::endl; } - return; } + + if (rb_count > 0 && !rb_enabled && !in_warmup) { + surr_active++; + if (surr_active > config.rb_interval_limit) { + surr_active = 0; + rb_count -= 1; + std::cout << "Rollback count reset to: " << rb_count << "." << std::endl; + } + } + /* enable user-requested surrogates */ chem->SetStabEnabled(false); chem->SetDhtEnabled(dht_enabled); @@ -80,7 +89,8 @@ void poet::ControlModule::readCheckpoint(uint32_t ¤t_iter, uint32_t rollba double r_check_a, r_check_b; r_check_a = MPI_Wtime(); Checkpoint_s checkpoint_read{.field = chem->getField()}; - read_checkpoint(out_dir, "checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read); + read_checkpoint(out_dir, "checkpoint" + std::to_string(rollback_iter) + ".hdf5", + checkpoint_read); current_iter = checkpoint_read.iteration; r_check_b = MPI_Wtime(); r_check_t += r_check_b - r_check_a; @@ -102,20 +112,22 @@ void poet::ControlModule::writeMetrics(const std::string &out_dir, uint32_t poet::ControlModule::calcRbIter() { - uint32_t last_iter = ((global_iter - 1) / config.chkpt_interval) * config.chkpt_interval; + uint32_t last_iter = + ((global_iter - 1) / config.chkpt_interval) * config.chkpt_interval; uint32_t rb_iter = (last_iter <= last_chkpt_written) ? last_iter : last_chkpt_written; return rb_iter; } -std::optional poet::ControlModule::findRbTarget(const std::vector &species) { +std::optional +poet::ControlModule::findRbTarget(const std::vector &species) { if (metrics_history.empty()) { std::cout << "No error history yet, skipping rollback check." << std::endl; flush_request = false; return std::nullopt; } - if (rb_count > config.rb_limit) { + if (rbLimitReached()) { std::cout << "Rollback limit reached, skipping control logic." << std::endl; flush_request = false; return std::nullopt; @@ -126,14 +138,19 @@ std::optional poet::ControlModule::findRbTarget(const std::vector config.mape_threshold[i]) { - std::cout << "Species " << species[i] << " MAPE=" << mape[i] - << " threshold=" << config.mape_threshold[i] << std::endl; + /* skip Charge */ + if (sp_idx == 4) { + continue; + } + + if (mape[sp_idx] > config.mape_threshold[sp_idx]) { + std::cout << "Species " << species[sp_idx] << " MAPE=" << mape[sp_idx] + << " threshold=" << config.mape_threshold[sp_idx] << std::endl; if (last_chkpt_written == 0) { std::cout << " Threshold exceeded but no checkpoint exists yet." << std::endl; @@ -141,9 +158,10 @@ std::optional poet::ControlModule::findRbTarget(const std::vector config.rb_limit) { - return false; - } - if (global_iter == 1 || global_iter % config.ctrl_interval == 1) { - return true; - } - return false; + return (config.rb_limit > 0) && !rbLimitReached(); +} + +inline bool poet::ControlModule::rbLimitReached() const { + /* rollback is completly disabled */ + if (config.rb_limit == 0) + return false; + return rb_count >= config.rb_limit; } diff --git a/src/Control/ControlModule.hpp b/src/Control/ControlModule.hpp index ea8222c7e..db1f148da 100644 --- a/src/Control/ControlModule.hpp +++ b/src/Control/ControlModule.hpp @@ -17,6 +17,7 @@ struct ControlConfig { uint32_t ctrl_interval = 0; uint32_t chkpt_interval = 0; uint32_t rb_limit = 0; + uint32_t rb_interval_limit = 0; double zero_abs = 0.0; std::vector mape_threshold; }; @@ -28,8 +29,7 @@ struct SpeciesMetrics { uint32_t rb_count = 0; SpeciesMetrics(uint32_t n_species, uint32_t iter, uint32_t count) - : mape(n_species, 0.0), rrmse(n_species, 0.0), iteration(iter), - rb_count(count) {} + : mape(n_species, 0.0), rrmse(n_species, 0.0), iteration(iter), rb_count(count) {} }; class ControlModule { @@ -41,27 +41,22 @@ public: void writeCheckpoint(uint32_t &iter, const std::string &out_dir); - void writeMetrics(const std::string &out_dir, - const std::vector &species); + void writeMetrics(const std::string &out_dir, const std::vector &species); std::optional findRbTarget(); void computeMetrics(const std::vector &reference_values, - const std::vector &surrogate_values, - const uint32_t size_per_prop, - const std::vector &species); + const std::vector &surrogate_values, + const uint32_t size_per_prop, + const std::vector &species); - void processCheckpoint(uint32_t ¤t_iter, - const std::string &out_dir, + void processCheckpoint(uint32_t ¤t_iter, const std::string &out_dir, const std::vector &species); - std::optional - findRbTarget(const std::vector &species); + std::optional findRbTarget(const std::vector &species); bool needsFlagBcast() const; - bool isCtrlIntervalActive() const { - return this->ctrl_active; - } + bool isCtrlIntervalActive() const { return this->ctrl_active; } bool getFlushRequest() const { return flush_request; } void clearFlushRequest() { flush_request = false; } @@ -76,17 +71,20 @@ public: private: void updateSurrState(bool dht_enabled, bool interp_enabled); - void readCheckpoint(uint32_t ¤t_iter, - uint32_t rollback_iter, const std::string &out_dir); + void readCheckpoint(uint32_t ¤t_iter, uint32_t rollback_iter, + const std::string &out_dir); uint32_t calcRbIter(); + inline bool rbLimitReached() const; + ControlConfig config; ChemistryModule *chem = nullptr; std::uint32_t global_iter = 0; std::uint32_t rb_count = 0; std::uint32_t stab_countdown = 0; + std::uint32_t surr_active = 0; std::uint32_t last_chkpt_written = 0; bool rb_enabled = false; diff --git a/src/poet.cpp b/src/poet.cpp index 3db547fc2..88694b9b6 100644 --- a/src/poet.cpp +++ b/src/poet.cpp @@ -99,9 +99,8 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) { "Print progress bar during chemical simulation"); /*Parse work package size*/ - app.add_option( - "-w,--work-package-size", params.work_package_size, - "Work package size to distribute to each worker for chemistry module") + app.add_option("-w,--work-package-size", params.work_package_size, + "Work package size to distribute to each worker for chemistry module") ->check(CLI::PositiveNumber) ->default_val(RuntimeParameters::WORK_PACKAGE_SIZE_DEFAULT); @@ -112,9 +111,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) { // cout << "CPP: DHT is " << ( dht_enabled ? "ON" : "OFF" ) << '\n'; - dht_group - ->add_option("--dht-size", params.dht_size, - "DHT size per process in Megabyte") + dht_group->add_option("--dht-size", params.dht_size, "DHT size per process in Megabyte") ->check(CLI::PositiveNumber) ->default_val(RuntimeParameters::DHT_SIZE_DEFAULT); // cout << "CPP: DHT size per process (Byte) = " << dht_size_per_process << @@ -140,9 +137,8 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) { ->check(CLI::PositiveNumber) ->default_val(RuntimeParameters::INTERP_MIN_ENTRIES_DEFAULT); interp_group - ->add_option( - "--interp-bucket-entries", params.interp_bucket_entries, - "Maximum number of entries in each bucket of the interpolation table") + ->add_option("--interp-bucket-entries", params.interp_bucket_entries, + "Maximum number of entries in each bucket of the interpolation table") ->check(CLI::PositiveNumber) ->default_val(RuntimeParameters::INTERP_BUCKET_ENTRIES_DEFAULT); @@ -152,25 +148,21 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) { app.add_flag("--rds", params.as_rds, "Save output as .rds file instead of default .qs2"); - app.add_flag("--qs", params.as_qs, - "Save output as .qs file instead of default .qs2"); + app.add_flag("--qs", params.as_qs, "Save output as .qs file instead of default .qs2"); std::string init_file; std::string runtime_file; - app.add_option("runtime_file", runtime_file, - "Runtime R script defining the simulation") + app.add_option("runtime_file", runtime_file, "Runtime R script defining the simulation") ->required() ->check(CLI::ExistingFile); - app.add_option( - "init_file", init_file, - "Initial R script defining the simulation, produced by poet_init") + app.add_option("init_file", init_file, + "Initial R script defining the simulation, produced by poet_init") ->required() ->check(CLI::ExistingFile); - app.add_option("out_dir", params.out_dir, - "Output directory of the simulation") + app.add_option("out_dir", params.out_dir, "Output directory of the simulation") ->required(); try { @@ -202,8 +194,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) { // << simparams.dht_significant_digits); // MSG("DHT logarithm before rounding: " // << (simparams.dht_log ? "ON" : "OFF")); - MSG("DHT size per process (Megabyte) = " + - std::to_string(params.dht_size)); + MSG("DHT size per process (Megabyte) = " + std::to_string(params.dht_size)); MSG("DHT save snapshots is " + BOOL_PRINT(params.dht_snaps)); // MSG("DHT load file is " + chem_params.dht_file); } @@ -212,8 +203,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) { MSG("PHT interpolation enabled: " + BOOL_PRINT(params.use_interp)); MSG("PHT interp-size = " + std::to_string(params.interp_size)); MSG("PHT interp-min = " + std::to_string(params.interp_min_entries)); - MSG("PHT interp-bucket-entries = " + - std::to_string(params.interp_bucket_entries)); + MSG("PHT interp-bucket-entries = " + std::to_string(params.interp_bucket_entries)); } } // chem_params.dht_outdir = out_dir; @@ -253,10 +243,11 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) { Rcpp::as(global_rt_setup->operator[]("ctrl_interval")); params.chkpt_interval = Rcpp::as(global_rt_setup->operator[]("chkpt_interval")); - params.rb_limit = - Rcpp::as(global_rt_setup->operator[]("rb_limit")); - params.mape_threshold = Rcpp::as>( - global_rt_setup->operator[]("mape_threshold")); + params.rb_limit = Rcpp::as(global_rt_setup->operator[]("rb_limit")); + params.rb_interval_limit = + Rcpp::as(global_rt_setup->operator[]("rb_interval_limit")); + params.mape_threshold = + Rcpp::as>(global_rt_setup->operator[]("mape_threshold")); params.zero_abs = Rcpp::as(global_rt_setup->operator[]("zero_abs")); } catch (const std::exception &e) { ERRMSG("Error while parsing R scripts: " + std::string(e.what())); @@ -278,16 +269,15 @@ void call_master_iter_end(RInside &R, const Field &trans, const Field &chem) { R["TMP"] = Rcpp::wrap(chem.AsVector()); R["TMP_PROPS"] = Rcpp::wrap(chem.GetProps()); R.parseEval(std::string("state_C <- setNames(data.frame(matrix(TMP, nrow=" + - std::to_string(chem.GetRequestedVecSize()) + - ")), TMP_PROPS)")); + std::to_string(chem.GetRequestedVecSize()) + ")), TMP_PROPS)")); R["setup"] = *global_rt_setup; R.parseEval("setup <- master_iteration_end(setup, state_T, state_C)"); *global_rt_setup = R["setup"]; } static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, - DiffusionModule &diffusion, - ChemistryModule &chem, ControlModule &control) { + DiffusionModule &diffusion, ChemistryModule &chem, + ControlModule &control) { /* Iteration Count is dynamic, retrieving value from R (is only needed by * master for the following loop) */ @@ -327,10 +317,9 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, double ai_start_t = MPI_Wtime(); // Save current values from the tug field as predictor for the ai step R["TMP"] = Rcpp::wrap(chem.getField().AsVector()); - R.parseEval( - std::string("predictors <- setNames(data.frame(matrix(TMP, nrow=" + - std::to_string(chem.getField().GetRequestedVecSize()) + - ")), TMP_PROPS)")); + R.parseEval(std::string("predictors <- setNames(data.frame(matrix(TMP, nrow=" + + std::to_string(chem.getField().GetRequestedVecSize()) + + ")), TMP_PROPS)")); R.parseEval("predictors <- predictors[ai_surrogate_species]"); // Apply preprocessing @@ -339,8 +328,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, // Predict MSG("AI Prediction"); - R.parseEval( - "aipreds_scaled <- prediction_step(model, predictors_scaled)"); + R.parseEval("aipreds_scaled <- prediction_step(model, predictors_scaled)"); // Apply postprocessing MSG("AI Postprocessing"); @@ -348,8 +336,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, // Validate prediction and write valid predictions to chem field MSG("AI Validation"); - R.parseEval( - "validity_vector <- validate_predictions(predictors, aipreds)"); + R.parseEval("validity_vector <- validate_predictions(predictors, aipreds)"); MSG("AI Marking accepted"); chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector")); @@ -361,9 +348,8 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, validity_vector)"); MSG("AI Set Field"); - Field predictions_field = - Field(R.parseEval("nrow(predictors)"), RTempField, - R.parseEval("colnames(predictors)")); + Field predictions_field = Field(R.parseEval("nrow(predictors)"), RTempField, + R.parseEval("colnames(predictors)")); MSG("AI Update"); chem.getField().update(predictions_field); @@ -378,10 +364,9 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, double ai_start_t = MPI_Wtime(); R["TMP"] = Rcpp::wrap(chem.getField().AsVector()); - R.parseEval( - std::string("targets <- setNames(data.frame(matrix(TMP, nrow=" + - std::to_string(chem.getField().GetRequestedVecSize()) + - ")), TMP_PROPS)")); + R.parseEval(std::string("targets <- setNames(data.frame(matrix(TMP, nrow=" + + std::to_string(chem.getField().GetRequestedVecSize()) + + ")), TMP_PROPS)")); R.parseEval("targets <- targets[ai_surrogate_species]"); // TODO: Check how to get the correct columns @@ -414,8 +399,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, std::to_string(maxiter)); if (control.isCtrlIntervalActive()) { - control.processCheckpoint(iter, params.out_dir, - chem.getField().GetProps()); + control.processCheckpoint(iter, params.out_dir, chem.getField().GetProps()); control.writeMetrics(params.out_dir, chem.getField().GetProps()); } // MSG(); @@ -452,16 +436,12 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, //} if (params.use_interp) { - chem_profiling["interp_w"] = - Rcpp::wrap(chem.GetWorkerInterpolationWriteTimings()); - chem_profiling["interp_r"] = - Rcpp::wrap(chem.GetWorkerInterpolationReadTimings()); - chem_profiling["interp_g"] = - Rcpp::wrap(chem.GetWorkerInterpolationGatherTimings()); + chem_profiling["interp_w"] = Rcpp::wrap(chem.GetWorkerInterpolationWriteTimings()); + chem_profiling["interp_r"] = Rcpp::wrap(chem.GetWorkerInterpolationReadTimings()); + chem_profiling["interp_g"] = Rcpp::wrap(chem.GetWorkerInterpolationGatherTimings()); chem_profiling["interp_fc"] = Rcpp::wrap(chem.GetWorkerInterpolationFunctionCallTimings()); - chem_profiling["interp_calls"] = - Rcpp::wrap(chem.GetWorkerInterpolationCalls()); + chem_profiling["interp_calls"] = Rcpp::wrap(chem.GetWorkerInterpolationCalls()); chem_profiling["interp_cached"] = Rcpp::wrap(chem.GetWorkerPHTCacheHits()); } @@ -476,8 +456,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, return profiling; } -std::vector getSpeciesNames(const Field &&field, int root, - MPI_Comm comm) { +std::vector getSpeciesNames(const Field &&field, int root, MPI_Comm comm) { std::uint32_t n_elements; std::uint32_t n_string_size; @@ -494,8 +473,8 @@ std::vector getSpeciesNames(const Field &&field, int root, for (std::uint32_t i = 0; i < n_elements; i++) { n_string_size = field.GetProps()[i].size(); MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD); - MPI_Bcast(const_cast(field.GetProps()[i].c_str()), n_string_size, - MPI_CHAR, root, MPI_COMM_WORLD); + MPI_Bcast(const_cast(field.GetProps()[i].c_str()), n_string_size, MPI_CHAR, + root, MPI_COMM_WORLD); } return field.GetProps(); @@ -609,8 +588,8 @@ int main(int argc, char *argv[]) { MPI_Barrier(MPI_COMM_WORLD); - ChemistryModule chemistry(run_params.work_package_size, - init_list.getChemistryInit(), MPI_COMM_WORLD); + ChemistryModule chemistry(run_params.work_package_size, init_list.getChemistryInit(), + MPI_COMM_WORLD); // ControlModule control; // chemistry.SetControlModule(&control); @@ -633,8 +612,8 @@ int main(int argc, char *argv[]) { chemistry.masterEnableSurrogates(surr_setup); ControlConfig config(run_params.ctrl_interval, run_params.chkpt_interval, - run_params.rb_limit, run_params.zero_abs, - run_params.mape_threshold); + run_params.rb_limit, run_params.rb_interval_limit, + run_params.zero_abs, run_params.mape_threshold); ControlModule control(config, &chemistry); @@ -660,8 +639,7 @@ int main(int argc, char *argv[]) { /* Incorporate ai surrogate from R */ R.parseEvalQ(ai_surrogate_r_library); /* Use dht species for model input and output */ - R["ai_surrogate_species"] = - init_list.getChemistryInit().dht_species.getNames(); + R["ai_surrogate_species"] = init_list.getChemistryInit().dht_species.getNames(); const std::string ai_surrogate_input_script = init_list.getChemistryInit().ai_surrogate_input_script; @@ -678,13 +656,11 @@ int main(int argc, char *argv[]) { // MPI_Barrier(MPI_COMM_WORLD); - DiffusionModule diffusion(init_list.getDiffusionInit(), - init_list.getInitialGrid()); + DiffusionModule diffusion(init_list.getDiffusionInit(), init_list.getInitialGrid()); chemistry.masterSetField(init_list.getInitialGrid()); - Rcpp::List profiling = - RunMasterLoop(R, run_params, diffusion, chemistry, control); + Rcpp::List profiling = RunMasterLoop(R, run_params, diffusion, chemistry, control); MSG("finished simulation loop"); diff --git a/src/poet.hpp.in b/src/poet.hpp.in index 678aaafbc..50d0ea77b 100644 --- a/src/poet.hpp.in +++ b/src/poet.hpp.in @@ -54,6 +54,7 @@ struct RuntimeParameters { std::uint32_t chkpt_interval = 0; std::uint32_t ctrl_interval = 0; std::uint32_t rb_limit = 0; + std::uint32_t rb_interval_limit = 0; std::vector mape_threshold; double zero_abs = 0.0;