diff --git a/bin/dolo_fgcs_3.R b/bin/dolo_fgcs_3.R index 6c43174c3..004ea4328 100644 --- a/bin/dolo_fgcs_3.R +++ b/bin/dolo_fgcs_3.R @@ -115,10 +115,10 @@ setup <- list( Chemistry = chemistry_setup # Parameters related to the chemistry process ) -iterations <- 500 +iterations <- 100 dt <- 200 -checkpoint_interval <- 100 -control_interval <- 100 +checkpoint_interval <- 20 +control_interval <- 20 mape_threshold <- rep(3.5e-3, 13) #out_save <- seq(50, iterations, by = 50) diff --git a/bin/run_poet.sh b/bin/run_poet.sh index e60f82823..b96c8efa1 100644 --- a/bin/run_poet.sh +++ b/bin/run_poet.sh @@ -1,7 +1,7 @@ #!/bin/bash -#SBATCH --job-name=bar_fgcs_500_eps -#SBATCH --output=bar_fgcs_500_eps_%j.out -#SBATCH --error=bar_fgcs_500_eps_%j.err +#SBATCH --job-name=dolo_warmup_debug +#SBATCH --output=dolo_warmup_debug%j.out +#SBATCH --error=dolo_warmup_debug%j.err #SBATCH --partition=long #SBATCH --nodes=4 #SBATCH --ntasks=96 @@ -14,5 +14,5 @@ source /etc/profile.d/modules.sh module purge module load cmake gcc openmpi -#mpirun -n 96 ./poet --interp dolo_fgcs_3.R dolo_fgcs_3.qs2 poet_dolo_fgcs_500_eps -mpirun -n 96 ./poet --interp barite_fgcs_2.R barite_fgcs_2.qs2 bar_fgcs_500_eps \ No newline at end of file +mpirun -n 96 ./poet --interp dolo_fgcs_3.R dolo_fgcs_3.qs2 dolo_warmup_debug +#mpirun -n 96 ./poet --interp barite_fgcs_2.R barite_fgcs_2.qs2 bar_fgcs_500_eps \ No newline at end of file diff --git a/share/poet/barite/barite_het.qs2 b/share/poet/barite/barite_het.qs2 index 4c3f004db..e6e04c46d 100644 Binary files a/share/poet/barite/barite_het.qs2 and b/share/poet/barite/barite_het.qs2 differ diff --git a/share/poet/surfex/PoetEGU_surfex_500.qs2 b/share/poet/surfex/PoetEGU_surfex_500.qs2 index 03f099846..d0bf6eb27 100644 Binary files a/share/poet/surfex/PoetEGU_surfex_500.qs2 and b/share/poet/surfex/PoetEGU_surfex_500.qs2 differ diff --git a/src/Chemistry/ChemistryModule.hpp b/src/Chemistry/ChemistryModule.hpp index 8a96084c5..2ecd97fdc 100644 --- a/src/Chemistry/ChemistryModule.hpp +++ b/src/Chemistry/ChemistryModule.hpp @@ -3,10 +3,9 @@ #define CHEMISTRYMODULE_H_ #include "ChemistryDefs.hpp" +#include "Control/ControlModule.hpp" #include "DataStructures/Field.hpp" #include "DataStructures/NamedVector.hpp" -#include "ChemistryDefs.hpp" -#include "Control/ControlModule.hpp" #include "Init/InitialList.hpp" #include "NameDouble.h" #include "PhreeqcRunner.hpp" @@ -22,7 +21,7 @@ #include namespace poet { - class ControlModule; +class ControlModule; /** * \brief Wrapper around PhreeqcRM to provide POET specific parallelization with * easy access. @@ -252,7 +251,15 @@ public: std::vector ai_surrogate_validity_vector; - void setControlModule(poet::ControlModule *ctrl) { control_module = ctrl; } + void SetControlModule(poet::ControlModule *ctrl) { control_module = ctrl; } + + void SetDhtEnabled(bool enabled) { dht_enabled = enabled; } + bool GetDhtEnabled() const { return dht_enabled; } + + void SetInterpEnabled(bool enabled) { interp_enabled = enabled; } + bool GetInterpEnabled() const { return interp_enabled; } + + void SetWarmupEnabled(bool enabled) { warmup_enabled = enabled; } protected: void initializeDHT(uint32_t size_mb, @@ -267,13 +274,13 @@ protected: enum { CHEM_FIELD_INIT, - CHEM_DHT_ENABLE, + //CHEM_DHT_ENABLE, CHEM_DHT_SIGNIF_VEC, CHEM_DHT_SNAPS, CHEM_DHT_READ_FILE, - //CHEM_IP, // Control flag - CHEM_CTRL, // Control flag - CHEM_IP_ENABLE, + //CHEM_WARMUP_PHASE, // Control flag + //CHEM_CTRL_ENABLE, // Control flag + //CHEM_IP_ENABLE, CHEM_IP_MIN_ENTRIES, CHEM_IP_SIGNIF_VEC, CHEM_WORK_LOOP, @@ -387,7 +394,7 @@ protected: bool ai_surrogate_enabled{false}; - static constexpr uint32_t BUFFER_OFFSET = 5; + static constexpr uint32_t BUFFER_OFFSET = 6; inline void ChemBCast(void *buf, int count, MPI_Datatype datatype) const { MPI_Bcast(buf, count, datatype, 0, this->group_comm); @@ -396,6 +403,9 @@ protected: inline void PropagateFunctionType(int &type) const { ChemBCast(&type, 1, MPI_INT); } + + void PropagateControlLogic(int type, int flag); + double simtime = 0.; double idle_t = 0.; double seq_t = 0.; @@ -422,6 +432,7 @@ protected: poet::ControlModule *control_module = nullptr; bool control_enabled{false}; + bool warmup_enabled{false}; // std::vector sur_shuffled; }; diff --git a/src/Chemistry/MasterFunctions.cpp b/src/Chemistry/MasterFunctions.cpp index a68821a8b..6778736db 100644 --- a/src/Chemistry/MasterFunctions.cpp +++ b/src/Chemistry/MasterFunctions.cpp @@ -232,6 +232,37 @@ inline void printProgressbar(int count_pkgs, int n_wp, int barWidth = 70) { /* end visual progress */ } +void poet::ChemistryModule::PropagateControlLogic(int type, int flag) { + /* + PropagateFunctionType(type); + + static int master_bcast_seq = 0; + int tmp = flag ? 1 : 0; + std::cerr << "[MASTER BCAST " << master_bcast_seq << "] ftype=" << type + << " flag=" << tmp << std::endl + << std::flush; + master_bcast_seq++; + ChemBCast(&tmp, 1, MPI_INT); + + switch (type) { + case CHEM_CTRL_ENABLE: + this->control_enabled = (tmp == 1); + break; + case CHEM_WARMUP_PHASE: + this->warmup_enabled = (tmp == 1); + break; + case CHEM_DHT_ENABLE: + this->dht_enabled = (tmp == 1); + break; + case CHEM_IP_ENABLE: + this->interp_enabled = (tmp == 1); + break; + default: + break; + } + */ +} + inline void poet::ChemistryModule::MasterSendPkgs( worker_list_t &w_list, workpointer_t &work_pointer, workpointer_t &sur_pointer, int &pkg_to_send, int &count_pkgs, @@ -250,6 +281,10 @@ inline void poet::ChemistryModule::MasterSendPkgs( local_work_package_size = (int)wp_sizes_vector[count_pkgs]; count_pkgs++; + uint32_t wp_start_index = + std::accumulate(wp_sizes_vector.begin(), + std::next(wp_sizes_vector.begin(), count_pkgs), 0); + /* note current processed work package in workerlist */ w_list[p].send_addr = work_pointer.base(); w_list[p].surrogate_addr = sur_pointer.base(); @@ -272,10 +307,12 @@ inline void poet::ChemistryModule::MasterSendPkgs( // current time of simulation (age) in seconds send_buffer[end_of_wp + 3] = this->simtime; // current work package start location in field - uint32_t wp_start_index = - std::accumulate(wp_sizes_vector.begin(), - std::next(wp_sizes_vector.begin(), count_pkgs), 0); send_buffer[end_of_wp + 4] = wp_start_index; + // control flags (bitmask) + int flags = (this->interp_enabled ? 1 : 0) | (this->dht_enabled ? 2 : 0) | + (this->warmup_enabled ? 4 : 0) | + (this->control_enabled ? 8 : 0); + send_buffer[end_of_wp + 5] = static_cast(flags); /* ATTENTION Worker p has rank p+1 */ // MPI_Send(send_buffer, end_of_wp + BUFFER_OFFSET, MPI_DOUBLE, p + 1, @@ -427,18 +464,35 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { MPI_INT); } - uint32_t control_flag = control_module->GetControlIntervalEnabled(); - if (control_flag) { - ftype = CHEM_CTRL; - PropagateFunctionType(ftype); - ChemBCast(&control_flag, 1, MPI_INT); + // ftype = CHEM_IP_ENABLE; + // ftype = CHEM_WARMUP_PHASE; + /* + PropagateFunctionType(ftype); + int warmup_flag = this->warmup_enabled ? 1 : 0; + if (warmup_flag) { + this->interp_enabled = false; + int interp_flag = 0; + ChemBCast(&interp_flag, 1, MPI_INT); + + // PropagateControlLogic(CHEM_WARMUP_PHASE, 1); + // PropagateControlLogic(CHEM_DHT_ENABLE, 0); + // PropagateControlLogic(CHEM_IP_ENABLE, 0); + } else { + this->interp_enabled = true; + int interp_flag = 1; + ChemBCast(&interp_flag, 1, MPI_INT); + + // PropagateControlLogic(CHEM_WARMUP_PHASE, 0); + // PropagateControlLogic(CHEM_DHT_ENABLE, 1); + // PropagateControlLogic(CHEM_IP_ENABLE, 1); } - /* - ftype = CHEM_IP; - PropagateFunctionType(ftype); - ctrl_module->BCastControlFlags(); -*/ + int control_flag = this->control_module->GetControlIntervalEnabled() ? 1 : 0; + if (control_flag) { + PropagateControlLogic(CHEM_CTRL_ENABLE, control_flag); + } + */ + ftype = CHEM_WORK_LOOP; PropagateFunctionType(ftype); @@ -455,8 +509,8 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count, wp_sizes_vector.size()); - std::vector mpi_surr_buffer; - mpi_surr_buffer.resize(mpi_buffer.size()); + control_enabled = this->control_module->GetControlIntervalEnabled() ? 1 : 0; + std::vector mpi_surr_buffer{mpi_buffer}; /* setup local variables */ pkg_to_send = wp_sizes_vector.size(); @@ -511,15 +565,48 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { chem_field = out_vec; /* do master stuff */ - - if (control_flag) { + if (control_enabled) { std::cout << "[Master] Control logic enabled for this iteration." << std::endl; std::vector sur_unshuffled{mpi_surr_buffer}; unshuffleField(mpi_surr_buffer, this->n_cells, this->prop_count, wp_sizes_vector.size(), sur_unshuffled); - control_module->computeSpeciesErrors(out_vec, sur_unshuffled, + // Quick debug: compare out_vec vs sur_unshuffled + size_t N = out_vec.size(); + if (N != sur_unshuffled.size()) { + std::cerr << "[MASTER DBG] size mismatch out_vec=" << N + << " sur_unshuffled=" << sur_unshuffled.size() << std::endl; + } /*else { + double max_abs = 0.0; + double max_rel = 0.0; + size_t worst_i = 0; + for (size_t i = 0; i < N; i) { + double a = out_vec[i]; + double b = sur_unshuffled[i]; + double absd = std::fabs(a - b); + if (absd > max_abs) { + max_abs = absd; + worst_i = i; + } + double rel = (std::fabs(a) > 1e-12) ? absd / std::fabs(a) : (absd > 0 ? + 1e12 : 0.0); if (rel > max_rel) max_rel = rel; + } + std::cerr << "[MASTER DBG] control compare N=" << N + << " max_abs=" << max_abs << " max_rel=" << max_rel + << " worst_idx=" << worst_i + << " out_vec[worst]=" << out_vec[worst_i] + << " sur[worst]=" << sur_unshuffled[worst_i] << std::endl; + // optionally print first 8 entries + std::cerr << "[MASTER DBG] out[0..7]: "; + for (size_t i = 0; i < std::min(8, N); i) std::cerr << out_vec[i] + << " "; std::cerr << "\n[MASTER DBG] sur[0..7]: "; for (size_t i = 0; i < + std::min(8, N); +i) std::cerr << sur_unshuffled[i] << " "; std::cerr + << std::endl; + } + */ + + control_module->ComputeSpeciesErrorMetrics(out_vec, sur_unshuffled, this->n_cells); } diff --git a/src/Chemistry/WorkerFunctions.cpp b/src/Chemistry/WorkerFunctions.cpp index fb7b4d380..4d19ad1f6 100644 --- a/src/Chemistry/WorkerFunctions.cpp +++ b/src/Chemistry/WorkerFunctions.cpp @@ -59,12 +59,37 @@ void poet::ChemistryModule::WorkerLoop() { MPI_INT, 0, this->group_comm); break; } - case CHEM_CTRL: { - int control_flag ; - ChemBCast(&control_flag, 1, MPI_INT); - this->control_enabled = (control_flag == 1); + /* + case CHEM_WARMUP_PHASE: { + int warmup_flag = 0; + ChemBCast(&warmup_flag, 1, MPI_INT); + this->warmup_enabled = (warmup_flag == 1); + //std::cout << "Warmup phase is " << this->warmup_enabled << std::endl; break; } + case CHEM_DHT_ENABLE: { + int dht_flag = 0; + ChemBCast(&dht_flag, 1, MPI_INT); + this->dht_enabled = (dht_flag == 1); + //std::cout << "DHT_enabled is " << this->dht_enabled << std::endl; + break; + } + case CHEM_IP_ENABLE: { + int interp_flag = 0; + ChemBCast(&interp_flag, 1, MPI_INT); + this->interp_enabled = (interp_flag == 1); + ; + std::cout << "Interp_enabled is " << this->interp_enabled << std::endl; + break; + } + case CHEM_CTRL_ENABLE: { + int control_flag = 0; + ChemBCast(&control_flag, 1, MPI_INT); + this->control_enabled = (control_flag == 1); + std::cout << "Control_enabled is " << this->control_enabled << std::endl; + break; + } + */ case CHEM_WORK_LOOP: { WorkerProcessPkgs(timings, iteration); break; @@ -136,6 +161,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, double current_sim_time; uint32_t wp_start_index; int count = double_count; + int flags; std::vector mpi_buffer(count); /* receive */ @@ -162,6 +188,19 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, // current work package start location in field wp_start_index = mpi_buffer[count + 4]; + // read packed control flags + flags = static_cast(mpi_buffer[count + 5]); + this->interp_enabled = (flags & 1) != 0; + this->dht_enabled = (flags & 2) != 0; + this->warmup_enabled = (flags & 4) != 0; + this->control_enabled = (flags & 8) != 0; + + /*std::cout << "warmup_enabled is " << warmup_enabled << ", control_enabled is + " + << control_enabled << ", dht_enabled is " + << dht_enabled << ", interp_enabled is " << interp_enabled + << std::endl;*/ + for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) { s_curr_wp.input[wp_i] = std::vector(mpi_buffer.begin() + this->prop_count * wp_i, @@ -169,7 +208,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, } // std::cout << this->comm_rank << ":" << counter++ << std::endl; - if (dht_enabled || interp_enabled) { + if (dht_enabled || interp_enabled || warmup_enabled) { dht->prepareKeys(s_curr_wp.input, dt); } @@ -203,7 +242,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, for (std::size_t wp_i = 0; wp_i < s_curr_wp_control.size; wp_i++) { s_curr_wp_control.output[wp_i] = std::vector(this->prop_count, 0.0); - s_curr_wp_control.mapping[wp_i] = 0; + s_curr_wp_control.mapping[wp_i] = CHEM_PQC; } } @@ -216,7 +255,6 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, if (control_enabled) { - std::size_t sur_wp_offset = s_curr_wp.size * this->prop_count; mpi_buffer.resize(count + sur_wp_offset); @@ -231,9 +269,8 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, // copy surrogate output after the the pqc output, mpi_buffer[pqc][interp] for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) { - if (s_curr_wp.mapping[wp_i] != - CHEM_PQC) // only copy if surrogate was used - { + // only copy if surrogate was used + if (s_curr_wp.mapping[wp_i] != CHEM_PQC) { std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(), mpi_buffer.begin() + sur_wp_offset + this->prop_count * wp_i); } else { @@ -259,14 +296,24 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, MPI_Isend(mpi_buffer.data(), count, MPI_DOUBLE, 0, mpi_tag, MPI_COMM_WORLD, &send_req); - if (dht_enabled || interp_enabled) { + if (dht_enabled || interp_enabled || warmup_enabled) { /* write results to DHT */ dht_fill_start = MPI_Wtime(); dht->fillDHT(control_enabled ? s_curr_wp_control : s_curr_wp); dht_fill_end = MPI_Wtime(); - if (interp_enabled) { + int filled_count = std::count(dht->getDHTResults().filledDHT.begin(), + dht->getDHTResults().filledDHT.end(), true); + + std::cout << "[Worker " << std::to_string(this->comm_rank) + << "] DHT filled entries=" << std::to_string(filled_count) + << std::endl; + + if (interp_enabled || warmup_enabled) { interp->writePairs(); + std::cout << "[Worker " << std::to_string(this->comm_rank) << "] " + << "Writing pairs to PHT after iteration " + << std::to_string(iteration) << std::endl; } timings.dht_fill += dht_fill_end - dht_fill_start; } diff --git a/src/Control/ControlModule.cpp b/src/Control/ControlModule.cpp index ab51a9ace..a215c5428 100644 --- a/src/Control/ControlModule.cpp +++ b/src/Control/ControlModule.cpp @@ -4,22 +4,62 @@ #include "IO/StatsIO.hpp" #include -void poet::ControlModule::updateControlIteration(const uint32_t iter) { +void poet::ControlModule::UpdateControlIteration(const uint32_t &iter, + const bool &dht_enabled, + const bool &interp_enabled) { - global_iteration = iter; + /* dht_enabled and inter_enabled are user settings set before startig the + * simulation*/ if (control_interval == 0) { control_interval_enabled = false; return; } + // InitiateWarmupPhase(dht_enabled, interp_enabled); + global_iteration = iter; + + if (global_iteration <= control_interval) { + chem->SetWarmupEnabled(true); + chem->SetDhtEnabled(false); + chem->SetInterpEnabled(false); + MSG("Warmup enabled until first control interval at iteration " + + std::to_string(control_interval) + "."); + } else { + chem->SetWarmupEnabled(false); + chem->SetDhtEnabled(true); + chem->SetInterpEnabled(true); + } + + control_interval_enabled = + (control_interval > 0 && iter % control_interval == 0); - control_interval_enabled = (iter % control_interval == 0); if (control_interval_enabled) { MSG("[Control] Control interval enabled at iteration " + std::to_string(iter)); } } +void poet::ControlModule::InitiateWarmupPhase(bool dht_enabled, + bool interp_enabled) { + + // user requested DHT/INTEP? keep them disabled but enable warmup-phase so + // workers do prepareKeys/fillDHT/writePairs as required. + if (global_iteration < control_interval) { + /* warmup phase: keep dht and interp disabled, + workers do prepareKeys/fillDHT/writePairs*/ + chem->SetWarmupEnabled(true); + // chem->SetDhtEnabled(false); + // chem->SetInterpEnabled(false); + MSG("Warmup enabled until first control interval at iteration " + + std::to_string(control_interval) + "."); + } else { + /* after warmup phase: restore according to user's request*/ + chem->SetWarmupEnabled(false); + // chem->SetDhtEnabled(dht_enabled); + // chem->SetInterpEnabled(interp_enabled); + } +} + /* void poet::ControlModule::beginIteration() { if (rollback_enabled) { @@ -33,35 +73,33 @@ void poet::ControlModule::beginIteration() { } */ -void poet::ControlModule::endIteration(const uint32_t iter) { +void poet::ControlModule::EndIteration(const uint32_t iter) { if (!control_interval_enabled) { return; } /* Writing a checkpointing */ /* Control Logic*/ - if (control_interval_enabled && - checkpoint_interval > 0 /*&& !rollback_enabled*/) { - if (!chem) { - MSG("chem pointer is null — skipping checkpoint/stats write"); - } else { - MSG("Writing checkpoint of iteration " + std::to_string(iter)); - write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5", - {.field = chem->getField(), .iteration = iter}); - writeStatsToCSV(error_history, species_names, out_dir, "stats_overview"); + if (!chem) { + MSG("chem pointer is null — skipping checkpoint/stats write"); + } else { + MSG("Writing checkpoint of iteration " + std::to_string(iter)); + write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5", + {.field = chem->getField(), .iteration = iter}); + writeStatsToCSV(error_history, species_names, out_dir, "stats_overview"); - /* + // if() + /* - if (triggerRollbackIfExceeded(*chem, *params, iter)) { - rollback_enabled = true; - rollback_counter++; - sur_disabled_counter = control_interval; - MSG("Interpolation disabled for the next " + - std::to_string(control_interval) + "."); - } - - */ + if (triggerRollbackIfExceeded(*chem, *params, iter)) { + rollback_enabled = true; + rollback_counter++; + sur_disabled_counter = control_interval; + MSG("Interpolation disabled for the next " + + std::to_string(control_interval) + "."); } + + */ } } @@ -75,50 +113,45 @@ void poet::ControlModule::BCastControlFlags() { */ -/* -bool poet::ControlModule::triggerRollbackIfExceeded(ChemistryModule &chem, - RuntimeParameters ¶ms, - uint32_t &iter) { +bool poet::ControlModule::RollbackIfThresholdExceeded(ChemistryModule &chem) { + + /** if (error_history.empty()) { MSG("No error history yet; skipping rollback check."); return false; } - const auto &mape = chem.error_history.back().mape; - const auto &props = chem.getField().GetProps(); + const auto &mape = error_history.back().mape; - for (uint32_t i = 0; i < params.mape_threshold.size(); ++i) { - // Skip invalid entries + for (uint32_t i = 0; i < species_names.size(); ++i) { if (mape[i] == 0) { continue; } - bool mape_exceeded = mape[i] > params.mape_threshold[i]; - if (mape_exceeded) { - uint32_t rollback_iter = ((iter - 1) / params.checkpoint_interval) * - params.checkpoint_interval; + if (mape[i] > mape_threshold[i]) { + uint32_t rollback_iter = ((global_iteration - 1) / checkpoint_interval) * + checkpoint_interval; - MSG("[THRESHOLD EXCEEDED] " + props[i] + + MSG("[THRESHOLD EXCEEDED] " + species_names[i] + " has MAPE = " + std::to_string(mape[i]) + - " exceeding threshold = " + std::to_string(params.mape_threshold[i]) + " exceeding threshold = " + std::to_string(mape_threshold[i]) + " → rolling back to iteration " + std::to_string(rollback_iter)); Checkpoint_s checkpoint_read{.field = chem.getField()}; - read_checkpoint(params.out_dir, + read_checkpoint(out_dir, "checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read); - iter = checkpoint_read.iteration; + global_iteration = checkpoint_read.iteration; return true; } } - MSG("All species are within their MAPE and RRMSE thresholds."); - return - - false; + MSG("All species are within their MAPE thresholds."); + return false; + */ } -*/ -void poet::ControlModule::computeSpeciesErrors( + +void poet::ControlModule::ComputeSpeciesErrorMetrics( const std::vector &reference_values, const std::vector &surrogate_values, const uint32_t size_per_prop) { @@ -142,25 +175,12 @@ void poet::ControlModule::computeSpeciesErrors( return; } - int idxBa = -1, idxCl = -1; - for (size_t k = 0; k < this->species_names.size(); ++k) { - if (this->species_names[k] == "Ba") - idxBa = (int)k; - if (this->species_names[k] == "Cl") - idxCl = (int)k; - } - if (idxBa < 0 || idxCl < 0) { - std::cerr << "[CTRL DIAG] Ba/Cl indices not found: Ba=" << idxBa - << " Cl=" << idxCl << "\n"; - } - for (uint32_t i = 0; i < this->species_names.size(); ++i) { double err_sum = 0.0; double sqr_err_sum = 0.0; uint32_t base_idx = i * size_per_prop; uint32_t nan_count = 0; uint32_t valid_count = 0; - double ref_sum = 0.0, sur_sum = 0.0; for (uint32_t j = 0; j < size_per_prop; ++j) { const double ref_value = reference_values[base_idx + j]; @@ -172,14 +192,9 @@ void poet::ControlModule::computeSpeciesErrors( continue; } valid_count++; - ref_sum += ref_value; - sur_sum += sur_value; if (std::abs(ref_value) < ZERO_ABS) { if (std::abs(sur_value) >= ZERO_ABS) { - std::cerr << "[CTRL TRACE] species=" << this->species_names[i] - << " idx=" << i << " base_idx=" << base_idx << " j=" << j - << " sur_value=" << sur_value << "\n"; err_sum += 1.0; sqr_err_sum += 1.0; } @@ -200,37 +215,6 @@ void poet::ControlModule::computeSpeciesErrors( std::cerr << "[CTRL WARN] no valid samples for species " << i << " (" << this->species_names[i] << "), setting errors to 0\n"; } - /* - // sample printing (keeps previous behavior: species 5 and 6) - if (i == 5 || i == 6) { - std::cerr << "[CTRL SAMPLE] species_index=" << i - << " name=" << this->species_names[i] - << " base_idx=" << base_idx << " nan_count=" << nan_count - << " valid_count=" << valid_count << std::endl; - uint32_t N = std::min(size_per_prop, 20u); - std::cerr << "[CTRL SAMPLE] reference: "; - for (uint32_t j = 0; j < N; ++j) - std::cerr << reference_values[base_idx + j] - << (j + 1 == N ? "\n" : " "); - std::cerr << "[CTRL SAMPLE] surrogate: "; - for (uint32_t j = 0; j < N; ++j) - std::cerr << surrogate_values[base_idx + j] - << (j + 1 == N ? "\n" : " "); - } - */ - - // DEBUG: detailed diagnostics for Ba/Cl (or whichever indices) - if (this->species_names[i] == "Ba" || this->species_names[i] == "Cl") { - double mean_ref = (valid_count > 0) ? (ref_sum / valid_count) : 0.0; - double mean_sur = (valid_count > 0) ? (sur_sum / valid_count) : 0.0; - std::cerr << "[CTRL DIAG] species=" << this->species_names[i] - << " idx=" << i << " base_idx=" << base_idx - << " valid_count=" << valid_count << " nan_count=" << nan_count - << " err_sum=" << err_sum << " sqr_err_sum=" << sqr_err_sum - << " mean_ref=" << mean_ref << " mean_sur=" << mean_sur - << " computed_MAPE=" << species_error_stats.mape[i] - << " computed_RRMSE=" << species_error_stats.rrmse[i] << "\n"; - } } error_history.push_back(species_error_stats); } diff --git a/src/Control/ControlModule.hpp b/src/Control/ControlModule.hpp index ee90fe77a..56be36a95 100644 --- a/src/Control/ControlModule.hpp +++ b/src/Control/ControlModule.hpp @@ -22,20 +22,18 @@ public: // std::uint32_t sur_disabled_counter = 0; // std::uint32_t rollback_counter = 0; - void updateControlIteration(const uint32_t iter); + void UpdateControlIteration(const uint32_t &iter, const bool &dht_enabled, + const bool &interp_enaled); + + void InitiateWarmupPhase(bool dht_enabled, bool interp_enabled); auto GetGlobalIteration() const noexcept { return global_iteration; } // void beginIteration(); - void endIteration(const uint32_t iter); + // void BCastControlFlags(); - void setChemistryModule(poet::ChemistryModule *c) { chem = c; } - - // void BCastControlFlags(); - - //bool triggerRollbackIfExceeded(ChemistryModule &chem, - // RuntimeParameters ¶ms, uint32_t &iter); + bool RollbackIfThresholdExceeded(ChemistryModule &chem); struct SimulationErrorStats { std::vector mape; @@ -43,14 +41,15 @@ public: uint32_t iteration; // iterations in simulation after rollbacks uint32_t rollback_count; - SimulationErrorStats(uint32_t species_count, uint32_t iter, uint32_t counter) + SimulationErrorStats(uint32_t species_count, uint32_t iter, + uint32_t counter) : mape(species_count, 0.0), rrmse(species_count, 0.0), iteration(iter), rollback_count(counter) {} }; - void computeSpeciesErrors(const std::vector &reference_values, - const std::vector &surrogate_values, - const uint32_t size_per_prop); + void ComputeSpeciesErrorMetrics(const std::vector &reference_values, + const std::vector &surrogate_values, + const uint32_t size_per_prop); std::vector error_history; @@ -62,7 +61,7 @@ public: std::vector mape_threshold; }; - void enableControlLogic(const ControlSetup &setup) { + void EnableControlLogic(const ControlSetup &setup) { this->out_dir = setup.out_dir; this->checkpoint_interval = setup.checkpoint_interval; this->control_interval = setup.control_interval; @@ -74,6 +73,10 @@ public: return this->control_interval_enabled; } + void EndIteration(const uint32_t iter); + + void SetChemistryModule(poet::ChemistryModule *c) { chem = c; } + auto GetControlInterval() const { return this->control_interval; } std::vector GetMapeThreshold() const { return this->mape_threshold; } diff --git a/src/poet.cpp b/src/poet.cpp index 94fc72d51..6719f1390 100644 --- a/src/poet.cpp +++ b/src/poet.cpp @@ -315,7 +315,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, //control.beginIteration(iter); // params.global_iter = iter; - control.updateControlIteration(iter); + control.UpdateControlIteration(iter, params.use_dht, params.use_interp); // params.control_interval_enabled = (iter % params.control_interval == 0); double start_t = MPI_Wtime(); @@ -428,7 +428,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, MSG("End of *coupling* iteration " + std::to_string(iter) + "/" + std::to_string(maxiter)); - control.endIteration(iter); + control.EndIteration(iter); /* if (iter % params.checkpoint_interval == 0) { MSG("Writing checkpoint of iteration " + std::to_string(iter)); @@ -650,8 +650,8 @@ int main(int argc, char *argv[]) { init_list.getChemistryInit(), MPI_COMM_WORLD); ControlModule control; - chemistry.setControlModule(&control); - control.setChemistryModule(&chemistry); + chemistry.SetControlModule(&control); + control.SetChemistryModule(&chemistry); const ChemistryModule::SurrogateSetup surr_setup = { getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), @@ -676,7 +676,7 @@ int main(int argc, char *argv[]) { getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), run_params.mape_threshold}; - control.enableControlLogic(ctrl_setup); + control.EnableControlLogic(ctrl_setup); if (MY_RANK > 0) { chemistry.WorkerLoop();