From 1ef61cad333d0afc3efc7272deaac090212e1cd0 Mon Sep 17 00:00:00 2001 From: rastogi Date: Mon, 17 Nov 2025 19:25:02 +0100 Subject: [PATCH] Separate packing of work package outputs into MPI buffer and add worker-side control flush logic --- src/Chemistry/ChemistryModule.hpp | 49 +++++-- src/Chemistry/MasterFunctions.cpp | 50 +++++-- src/Chemistry/WorkerFunctions.cpp | 120 +++++++++------ src/Control/ControlModule.cpp | 234 ++++++++++++++++-------------- src/Control/ControlModule.hpp | 129 ++++++++-------- src/IO/StatsIO.cpp | 2 +- src/IO/StatsIO.hpp | 2 +- src/poet.cpp | 51 +++---- src/poet.hpp.in | 1 + 9 files changed, 359 insertions(+), 279 deletions(-) diff --git a/src/Chemistry/ChemistryModule.hpp b/src/Chemistry/ChemistryModule.hpp index dec7b2bdc..854d488a2 100644 --- a/src/Chemistry/ChemistryModule.hpp +++ b/src/Chemistry/ChemistryModule.hpp @@ -102,6 +102,8 @@ public: this->base_totals = setup.base_totals; + this->ctr_file_out_dir = setup.dht_out_dir; + if (this->dht_enabled || this->interp_enabled) { this->initializeDHT(setup.dht_size_mb, this->params.dht_species, setup.has_het_ids); @@ -257,7 +259,7 @@ public: std::vector ai_surrogate_validity_vector; - void SetControlModule(poet::ControlModule *ctrl) { control_module = ctrl; } + void SetControlModule(poet::ControlModule *ctrl) { control = ctrl; } void SetDhtEnabled(bool enabled) { dht_enabled = enabled; } bool GetDhtEnabled() const { return dht_enabled; } @@ -265,7 +267,7 @@ public: void SetInterpEnabled(bool enabled) { interp_enabled = enabled; } bool GetInterpEnabled() const { return interp_enabled; } - void SetWarmupEnabled(bool enabled) { warmup_enabled = enabled; } + void SetStabEnabled(bool enabled) { stab_enabled = enabled; } protected: void initializeDHT(uint32_t size_mb, @@ -280,13 +282,13 @@ protected: enum { CHEM_FIELD_INIT, - //CHEM_DHT_ENABLE, + CHEM_DHT_ENABLE, + CHEM_IP_ENABLE, + CHEM_CTRL_ENABLE, + CHEM_CTRL_FLAGS, CHEM_DHT_SIGNIF_VEC, CHEM_DHT_SNAPS, CHEM_DHT_READ_FILE, - //CHEM_WARMUP_PHASE, // Control flag - //CHEM_CTRL_ENABLE, // Control flag - //CHEM_IP_ENABLE, CHEM_IP_MIN_ENTRIES, CHEM_IP_SIGNIF_VEC, CHEM_WORK_LOOP, @@ -295,6 +297,9 @@ protected: CHEM_AI_BCAST_VALIDITY }; + /* broadcasted only every control iteration */ + enum { DHT_ENABLE = 1u << 0, IP_ENABLE = 1u << 1, STAB_ENABLE = 1u << 2 }; + enum { LOOP_WORK, LOOP_END, LOOP_CTRL }; enum { @@ -378,9 +383,11 @@ protected: void BCastStringVec(std::vector &io); - int packResultsIntoBuffer(std::vector &mpi_buffer, int base_count, - const WorkPackage &wp, - const WorkPackage &wp_control); + void copyPkgs(const WorkPackage &wp, std::vector &mpi_buffer, + std::size_t offset = 0); + + void copyCtrlPkgs(const WorkPackage &pqc_wp, const WorkPackage &surr_wp, + std::vector &mpi_bufffer, int &count); int comm_size, comm_rank; MPI_Comm group_comm; @@ -400,7 +407,7 @@ protected: bool ai_surrogate_enabled{false}; - static constexpr uint32_t BUFFER_OFFSET = 6; + static constexpr uint32_t BUFFER_OFFSET = 5; inline void ChemBCast(void *buf, int count, MPI_Datatype datatype) const { MPI_Bcast(buf, count, datatype, 0, this->group_comm); @@ -410,6 +417,22 @@ protected: ChemBCast(&type, 1, MPI_INT); } + std::string ctr_file_out_dir; + + inline int buildControlPacket(bool dht, bool interp, bool stab) { + int flags = 0; + + if (dht) + flags |= DHT_ENABLE; + if (interp) + flags |= IP_ENABLE; + if (stab) + flags |= STAB_ENABLE; + return flags; + } + + inline bool hasFlag(int flags, int type) { return (flags & type) != 0; } + double simtime = 0.; double idle_t = 0.; double seq_t = 0.; @@ -437,14 +460,12 @@ protected: std::unique_ptr pqc_runner; - poet::ControlModule *control_module = nullptr; + poet::ControlModule *control = nullptr; std::vector mpi_surr_buffer; bool control_enabled{false}; - bool warmup_enabled{false}; - - // std::vector sur_shuffled; + bool stab_enabled{false}; }; } // namespace poet diff --git a/src/Chemistry/MasterFunctions.cpp b/src/Chemistry/MasterFunctions.cpp index b731188ce..dd381e486 100644 --- a/src/Chemistry/MasterFunctions.cpp +++ b/src/Chemistry/MasterFunctions.cpp @@ -280,10 +280,11 @@ inline void poet::ChemistryModule::MasterSendPkgs( // current work package start location in field send_buffer[end_of_wp + 4] = wp_start_index; // control flags (bitmask) - int flags = (this->interp_enabled ? 1 : 0) | (this->dht_enabled ? 2 : 0) | - (this->warmup_enabled ? 4 : 0) | - (this->control_enabled ? 8 : 0); - send_buffer[end_of_wp + 5] = static_cast(flags); + + /* int flags = (this->interp_enabled ? 1 : 0) | (this->dht_enabled ? 2 : + 0) | (this->warmup_enabled ? 4 : 0) | (this->control_enabled ? 8 : 0); + send_buffer[end_of_wp + 5] = static_cast(flags); + */ /* ATTENTION Worker p has rank p+1 */ // MPI_Send(send_buffer, end_of_wp + BUFFER_OFFSET, MPI_DOUBLE, p + 1, @@ -445,16 +446,29 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { MPI_INT); } + /* broadcast control state once every iteration */ + ftype = CHEM_CTRL_ENABLE; + PropagateFunctionType(ftype); + int ctrl = + (this->control_enabled = this->control->getControlIntervalEnabled()) ? 1 + : 0; + ChemBCast(&ctrl, 1, MPI_INT); + + if (control->shouldBcastFlags()) { + int ftype = CHEM_CTRL_FLAGS; + PropagateFunctionType(ftype); + uint32_t ctrl_flags = buildControlPacket( + this->dht_enabled, this->interp_enabled, this->stab_enabled); + ChemBCast(&ctrl_flags, 1, MPI_INT); + + this->mpi_surr_buffer.assign(this->n_cells * this->prop_count, 0.0); + } + ftype = CHEM_WORK_LOOP; PropagateFunctionType(ftype); MPI_Barrier(this->group_comm); - this->control_enabled = this->control_module->getControlIntervalEnabled(); - if (this->control_enabled) { - this->mpi_surr_buffer.assign(this->n_cells * this->prop_count, 0.0); - } - static uint32_t iteration = 0; /* start time measurement of sequential part */ @@ -466,7 +480,10 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count, wp_sizes_vector.size()); - //this->mpi_surr_buffer.resize(mpi_buffer.size()); + // Only resize surrogate buffer if control is enabled + if (this->control_enabled) { + this->mpi_surr_buffer.resize(mpi_buffer.size()); + } /* setup local variables */ pkg_to_send = wp_sizes_vector.size(); @@ -541,9 +558,10 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { } metrics_a = MPI_Wtime(); - control_module->computeSpeciesErrorMetrics(out_vec, sur_unshuffled, - this->n_cells); + control->computeErrorMetrics(out_vec, sur_unshuffled, this->n_cells, + prop_names); metrics_b = MPI_Wtime(); + this->metrics_t += metrics_b - metrics_a; } @@ -556,9 +574,15 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { /* end time measurement of whole chemistry simulation */ + std::optional target = std::nullopt; + if (this->control_enabled) { + target = control->getRollbackTarget(prop_names); + } + int flush = (this->control_enabled && target.has_value()) ? 1 : 0; + /* advise workers to end chemistry iteration */ for (int i = 1; i < this->comm_size; i++) { - MPI_Send(NULL, 0, MPI_DOUBLE, i, LOOP_END, this->group_comm); + MPI_Send(&flush, 1, MPI_INT, i, LOOP_END, this->group_comm); } this->simtime += dt; diff --git a/src/Chemistry/WorkerFunctions.cpp b/src/Chemistry/WorkerFunctions.cpp index 7a5b9ad84..8975f6f6c 100644 --- a/src/Chemistry/WorkerFunctions.cpp +++ b/src/Chemistry/WorkerFunctions.cpp @@ -82,14 +82,22 @@ void poet::ChemistryModule::WorkerLoop() { std::cout << "Interp_enabled is " << this->interp_enabled << std::endl; break; } + */ case CHEM_CTRL_ENABLE: { - int control_flag = 0; - ChemBCast(&control_flag, 1, MPI_INT); - this->control_enabled = (control_flag == 1); - std::cout << "Control_enabled is " << this->control_enabled << std::endl; + int ctrl = 0; + ChemBCast(&ctrl, 1, MPI_INT); + this->control_enabled = (ctrl == 1); break; } - */ + case CHEM_CTRL_FLAGS: { + int flags = 0; + ChemBCast(&flags, 1, MPI_INT); + this->dht_enabled = hasFlag(flags, DHT_ENABLE); + this->interp_enabled = hasFlag(flags, IP_ENABLE); + this->stab_enabled = hasFlag(flags, STAB_ENABLE); + break; + } + case CHEM_WORK_LOOP: { WorkerProcessPkgs(timings, iteration); break; @@ -146,6 +154,38 @@ void poet::ChemistryModule::WorkerProcessPkgs(struct worker_s &timings, } } } +void poet::ChemistryModule::copyPkgs(const WorkPackage &wp, + std::vector &mpi_buffer, + std::size_t offset) { + for (std::size_t wp_i = 0; wp_i < wp.size; wp_i++) { + std::copy(wp.output[wp_i].begin(), wp.output[wp_i].end(), + mpi_buffer.begin() + offset + this->prop_count * wp_i); + } +} +void poet::ChemistryModule::copyCtrlPkgs(const WorkPackage &pqc_wp, + const WorkPackage &surr_wp, + std::vector &mpi_buffer, + int &count) { + std::size_t wp_offset = surr_wp.size * this->prop_count; + mpi_buffer.resize(count + wp_offset); + + copyPkgs(pqc_wp, mpi_buffer); + + // s_curr_wp only contains the interpolated data + // copy surrogate output after the the pqc output, mpi_buffer[pqc][interp] + + for (std::size_t wp_i = 0; wp_i < surr_wp.size; wp_i++) { + + if (surr_wp.mapping[wp_i] != CHEM_PQC) { + // only copy if surrogate was used + copyPkgs(surr_wp, mpi_buffer, wp_offset); + } else { + // if pqc was used, copy pqc results again + copyPkgs(pqc_wp, mpi_buffer, wp_offset); + } + } + count += wp_offset; +} void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, int double_count, @@ -190,17 +230,21 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, wp_start_index = mpi_buffer[count + 4]; // read packed control flags + /* flags = static_cast(mpi_buffer[count + 5]); this->interp_enabled = (flags & 1) != 0; this->dht_enabled = (flags & 2) != 0; this->warmup_enabled = (flags & 4) != 0; this->control_enabled = (flags & 8) != 0; + */ - /*std::cout << "warmup_enabled is " << warmup_enabled << ", control_enabled is - " - << control_enabled << ", dht_enabled is " - << dht_enabled << ", interp_enabled is " << interp_enabled - << std::endl;*/ + /* + + std::cout << "warmup_enabled is " << stab_enabled << ", control_enabled is " + << control_enabled << ", dht_enabled is " << dht_enabled + << ", interp_enabled is " << interp_enabled << std::endl; + + */ for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) { s_curr_wp.input[wp_i] = @@ -209,7 +253,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, } // std::cout << this->comm_rank << ":" << counter++ << std::endl; - if (dht_enabled || interp_enabled || warmup_enabled) { + if (dht_enabled || interp_enabled || stab_enabled) { dht->prepareKeys(s_curr_wp.input, dt); } @@ -259,39 +303,11 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, if (control_enabled) { ctrl_start = MPI_Wtime(); - std::size_t sur_wp_offset = s_curr_wp.size * this->prop_count; - - mpi_buffer.resize(count + sur_wp_offset); - - for (std::size_t wp_i = 0; wp_i < s_curr_wp_control.size; wp_i++) { - std::copy(s_curr_wp_control.output[wp_i].begin(), - s_curr_wp_control.output[wp_i].end(), - mpi_buffer.begin() + this->prop_count * wp_i); - } - - // s_curr_wp only contains the interpolated data - // copy surrogate output after the the pqc output, mpi_buffer[pqc][interp] - - for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) { - // only copy if surrogate was used - if (s_curr_wp.mapping[wp_i] != CHEM_PQC) { - std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(), - mpi_buffer.begin() + sur_wp_offset + this->prop_count * wp_i); - } else { - // if pqc was used, copy pqc results again - std::copy(s_curr_wp_control.output[wp_i].begin(), - s_curr_wp_control.output[wp_i].end(), - mpi_buffer.begin() + sur_wp_offset + this->prop_count * wp_i); - } - } - count += sur_wp_offset; + copyCtrlPkgs(s_curr_wp_control, s_curr_wp, mpi_buffer, count); ctrl_end = MPI_Wtime(); timings.ctrl_t += ctrl_end - ctrl_start; } else { - for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) { - std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(), - mpi_buffer.begin() + this->prop_count * wp_i); - } + copyPkgs(s_curr_wp, mpi_buffer); } /* send results to master */ @@ -301,13 +317,13 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, MPI_Isend(mpi_buffer.data(), count, MPI_DOUBLE, 0, mpi_tag, MPI_COMM_WORLD, &send_req); - if (dht_enabled || interp_enabled || warmup_enabled) { + if (dht_enabled || interp_enabled || stab_enabled) { /* write results to DHT */ dht_fill_start = MPI_Wtime(); dht->fillDHT(control_enabled ? s_curr_wp_control : s_curr_wp); dht_fill_end = MPI_Wtime(); - if (interp_enabled || warmup_enabled) { + if (interp_enabled || stab_enabled) { interp->writePairs(); } timings.dht_fill += dht_fill_end - dht_fill_start; @@ -317,10 +333,20 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, MPI_Wait(&send_req, MPI_STATUS_IGNORE); } -void poet::ChemistryModule::WorkerPostIter(MPI_Status &prope_status, +void poet::ChemistryModule::WorkerPostIter(MPI_Status &probe_status, uint32_t iteration) { - MPI_Recv(NULL, 0, MPI_DOUBLE, 0, LOOP_END, this->group_comm, - MPI_STATUS_IGNORE); + + int size, flush = 0; + + MPI_Get_count(&probe_status, MPI_INT, &size); + + if (size == 1) { + MPI_Recv(&flush, size, MPI_INT, probe_status.MPI_SOURCE, LOOP_END, + this->group_comm, MPI_STATUS_IGNORE); + } else { + MPI_Recv(NULL, 0, MPI_INT, probe_status.MPI_SOURCE, LOOP_END, + this->group_comm, MPI_STATUS_IGNORE); + } if (this->dht_enabled) { dht_hits.push_back(dht->getHits()); @@ -346,7 +372,7 @@ void poet::ChemistryModule::WorkerPostIter(MPI_Status &prope_status, const auto max_mean_idx = DHT_get_used_idx_factor(this->interp->getDHTObject(), 1); - if (max_mean_idx >= 2) { + if (max_mean_idx >= 2 || flush) { DHT_flush(this->interp->getDHTObject()); DHT_flush(this->dht->getDHT()); if (this->comm_rank == 2) { diff --git a/src/Control/ControlModule.cpp b/src/Control/ControlModule.cpp index 0c3b9ccd8..2829aab59 100644 --- a/src/Control/ControlModule.cpp +++ b/src/Control/ControlModule.cpp @@ -4,178 +4,166 @@ #include "IO/StatsIO.hpp" #include -void poet::ControlModule::updateControlIteration(const uint32_t &iter, - const bool &dht_enabled, - const bool &interp_enabled) { +poet::ControlModule::ControlModule(const ControlConfig &config_, + ChemistryModule *chem_) + : config(config_), chem(chem_) { + assert(chem && "ChemistryModule pointer must not be null"); +} + +void poet::ControlModule::beginIteration(const uint32_t &iter, + const bool &dht_enabled, + const bool &interp_enabled) { /* dht_enabled and inter_enabled are user settings set before startig the * simulation*/ double prep_a, prep_b; prep_a = MPI_Wtime(); - if (control_interval == 0) { + if (config.control_interval == 0) { control_interval_enabled = false; return; } global_iteration = iter; - initiateWarmupPhase(dht_enabled, interp_enabled); + + updateStabilizationPhase(dht_enabled, interp_enabled); control_interval_enabled = - (control_interval > 0 && iter % control_interval == 0); + (config.control_interval > 0 && (iter % config.control_interval == 0)); - if (control_interval_enabled) { - MSG("[Control] Control interval enabled at iteration " + - std::to_string(iter)); - } prep_b = MPI_Wtime(); this->prep_t += prep_b - prep_a; } -void poet::ControlModule::initiateWarmupPhase(bool dht_enabled, - bool interp_enabled) { - +void poet::ControlModule::updateStabilizationPhase(bool dht_enabled, + bool interp_enabled) { + if (rollback_enabled) { + if (disable_surr_counter > 0) { + --disable_surr_counter; + MSG("Rollback counter: " + std::to_string(disable_surr_counter)); + } else { + rollback_enabled = false; + } + } // user requested DHT/INTEP? keep them disabled but enable warmup-phase so - if (global_iteration <= control_interval || rollback_enabled) { - chem->SetWarmupEnabled(true); + if (global_iteration <= config.control_interval || rollback_enabled) { + chem->SetStabEnabled(true); chem->SetDhtEnabled(false); chem->SetInterpEnabled(false); - MSG("Warmup enabled until next control interval at iteration " + - std::to_string(control_interval) + "."); - - if (rollback_enabled) { - if (sur_disabled_counter > 0) { - --sur_disabled_counter; - MSG("Rollback counter: " + std::to_string(sur_disabled_counter)); - } else { - rollback_enabled = false; - } - } return; } - - chem->SetWarmupEnabled(false); + chem->SetStabEnabled(false); chem->SetDhtEnabled(dht_enabled); chem->SetInterpEnabled(interp_enabled); } -void poet::ControlModule::applyControlLogic(ChemistryModule &chem, - uint32_t &iter) { - if (!control_interval_enabled) { - return; - } - writeCheckpointAndMetrics(chem, iter); - - if (checkAndRollback(chem, iter) && rollback_count < 3) { - rollback_enabled = true; - rollback_count++; - sur_disabled_counter = control_interval; - - MSG("Interpolation disabled for the next " + - std::to_string(control_interval) + "."); - - } - -} - -void poet::ControlModule::writeCheckpointAndMetrics(ChemistryModule &chem, - uint32_t iter) { - - double w_check_a, w_check_b, stats_a, stats_b; - MSG("Writing checkpoint of iteration " + std::to_string(iter)); - +void poet::ControlModule::writeCheckpoint(uint32_t &iter, + const std::string &out_dir) { + double w_check_a, w_check_b; w_check_a = MPI_Wtime(); write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5", - {.field = chem.getField(), .iteration = iter}); + {.field = chem->getField(), .iteration = iter}); w_check_b = MPI_Wtime(); this->w_check_t += w_check_b - w_check_a; + last_checkpoint_written = iter; +} + +void poet::ControlModule::readCheckpoint(uint32_t ¤t_iter, + uint32_t rollback_iter, + const std::string &out_dir) { + double r_check_a, r_check_b; + r_check_a = MPI_Wtime(); + Checkpoint_s checkpoint_read{.field = chem->getField()}; + read_checkpoint(out_dir, + "checkpoint" + std::to_string(rollback_iter) + ".hdf5", + checkpoint_read); + current_iter = checkpoint_read.iteration; + r_check_b = MPI_Wtime(); + r_check_t += r_check_b - r_check_a; +} + +void poet::ControlModule::writeErrorMetrics( + const std::string &out_dir, const std::vector &species) { + double stats_a, stats_b; + stats_a = MPI_Wtime(); - writeStatsToCSV(metricsHistory, species_names, out_dir, "stats_overview"); + writeStatsToCSV(metrics_history, species, out_dir, "metrics_overview"); stats_b = MPI_Wtime(); this->stats_t += stats_b - stats_a; } -bool poet::ControlModule::checkAndRollback(ChemistryModule &chem, - uint32_t &iter) { +uint32_t poet::ControlModule::getRollbackIter() { + + uint32_t last_iter = ((global_iteration - 1) / config.checkpoint_interval) * + config.checkpoint_interval; + + uint32_t rollback_iter = (last_iter <= last_checkpoint_written) + ? last_iter + : last_checkpoint_written; + return rollback_iter; +} + +std::optional poet::ControlModule::getRollbackTarget( + const std::vector &species) { double r_check_a, r_check_b; - if (metricsHistory.empty()) { - MSG("No error history yet; skipping rollback check."); - return false; + if (metrics_history.empty()) { + MSG("No error history yet, skipping rollback check."); + flush_request = false; + return std::nullopt; + } + if (rollback_count > 3) { + MSG("Rollback limit reached, skipping rollback."); + flush_request = false; + return std::nullopt; } - const auto &mape = metricsHistory.back().mape; - - for (uint32_t i = 0; i < species_names.size(); ++i) { - if (mape[i] == 0) { + const auto &mape = metrics_history.back().mape; + for (uint32_t i = 0; i < species.size(); ++i) { + // skip Charge + if (i == 4 || mape[i] == 0) { continue; } + if (mape[i] > config.mape_threshold[i]) { + if (last_checkpoint_written == 0) { + MSG(" Threshold exceeded but no checkpoint exists yet."); + return std::nullopt; + } - if (mape[i] > mape_threshold[i]) { - uint32_t rollback_iter = - ((iter - 1) / checkpoint_interval) * checkpoint_interval; + flush_request = true; - MSG("[THRESHOLD EXCEEDED] " + species_names[i] + + MSG("T hreshold exceeded " + species[i] + " has MAPE = " + std::to_string(mape[i]) + - " exceeding threshold = " + std::to_string(mape_threshold[i]) + - " → rolling back to iteration " + std::to_string(rollback_iter)); - - r_check_a = MPI_Wtime(); - Checkpoint_s checkpoint_read{.field = chem.getField()}; - read_checkpoint(out_dir, - "checkpoint" + std::to_string(rollback_iter) + ".hdf5", - checkpoint_read); - iter = checkpoint_read.iteration; - r_check_b = MPI_Wtime(); - r_check_t += r_check_b - r_check_a; - return true; + " exceeding threshold = " + std::to_string(config.mape_threshold[i])); + return getRollbackIter(); } } MSG("All species are within their MAPE thresholds."); - - return false; + flush_request = false; + return std::nullopt; } -void poet::ControlModule::computeSpeciesErrorMetrics( +void poet::ControlModule::computeErrorMetrics( const std::vector &reference_values, - const std::vector &surrogate_values, const uint32_t size_per_prop) { + const std::vector &surrogate_values, const uint32_t size_per_prop, + const std::vector &species) { - SpeciesErrorMetrics metrics(this->species_names.size(), global_iteration, - rollback_count); + SpeciesErrorMetrics metrics(species.size(), global_iteration, rollback_count); - if (reference_values.size() != surrogate_values.size()) { - MSG(" Reference and surrogate vectors differ in size: " + - std::to_string(reference_values.size()) + " vs " + - std::to_string(surrogate_values.size())); - return; - } - - const std::size_t expected = - static_cast(this->species_names.size()) * size_per_prop; - if (reference_values.size() < expected) { - std::cerr << "[CTRL ERROR] input vectors too small: expected >= " - << expected << " entries, got " << reference_values.size() - << "\n"; - return; - } - - for (uint32_t i = 0; i < this->species_names.size(); ++i) { + for (uint32_t i = 0; i < species.size(); ++i) { double err_sum = 0.0; double sqr_err_sum = 0.0; uint32_t base_idx = i * size_per_prop; - int count = 0; - for (uint32_t j = 0; j < size_per_prop; ++j) { const double ref_value = reference_values[base_idx + j]; const double sur_value = surrogate_values[base_idx + j]; - const double ZERO_ABS = 1e-13; + const double ZERO_ABS = config.zero_abs; if (std::isnan(ref_value) || std::isnan(sur_value)) { continue; } - if (std::abs(ref_value) < ZERO_ABS) { if (std::abs(sur_value) >= ZERO_ABS) { err_sum += 1.0; @@ -192,5 +180,35 @@ void poet::ControlModule::computeSpeciesErrorMetrics( metrics.mape[i] = 100.0 * (err_sum / size_per_prop); metrics.rrmse[i] = std::sqrt(sqr_err_sum / size_per_prop); } - metricsHistory.push_back(metrics); + metrics_history.push_back(metrics); +} + +void poet::ControlModule::processCheckpoint( + uint32_t ¤t_iter, const std::string &out_dir, + const std::vector &species) { + + if (!control_interval_enabled) + return; + + if (flush_request && rollback_count < 3) { + uint32_t target = getRollbackIter(); + readCheckpoint(current_iter, target, out_dir); + + rollback_enabled = true; + rollback_count++; + disable_surr_counter = config.control_interval; + + MSG("Restored checkpoint " + std::to_string(target) + + ", surrogates disabled for " + std::to_string(config.control_interval)); + } else { + writeCheckpoint(global_iteration, out_dir); + } +} + +bool poet::ControlModule::shouldBcastFlags() const { + if (global_iteration == 1 || + global_iteration % config.control_interval == 1) { + return true; + } + return false; } diff --git a/src/Control/ControlModule.hpp b/src/Control/ControlModule.hpp index 01c76ccb9..29deec61c 100644 --- a/src/Control/ControlModule.hpp +++ b/src/Control/ControlModule.hpp @@ -4,8 +4,8 @@ #include "Base/Macros.hpp" #include "Chemistry/ChemistryModule.hpp" #include "poet.hpp" - #include +#include #include #include @@ -13,104 +13,93 @@ namespace poet { class ChemistryModule; +struct ControlConfig { + uint32_t control_interval = 0; + uint32_t checkpoint_interval = 0; + double zero_abs = 0.0; + std::vector mape_threshold; +}; + +struct SpeciesErrorMetrics { + std::vector mape; + std::vector rrmse; + uint32_t iteration = 0; + uint32_t rollback_count = 0; + + SpeciesErrorMetrics(uint32_t n_species, uint32_t iter, uint32_t rb_count) + : mape(n_species, 0.0), rrmse(n_species, 0.0), iteration(iter), + rollback_count(rb_count) {} +}; + class ControlModule { - public: - /* Control configuration*/ + explicit ControlModule(const ControlConfig &config, ChemistryModule *chem); - // std::uint32_t global_iter = 0; - // std::uint32_t sur_disabled_counter = 0; - // std::uint32_t rollback_counter = 0; + void beginIteration(const uint32_t &iter, const bool &dht_enabled, + const bool &interp_enaled); - void updateControlIteration(const uint32_t &iter, const bool &dht_enabled, - const bool &interp_enaled); + void writeCheckpoint(uint32_t &iter, const std::string &out_dir); - void initiateWarmupPhase(bool dht_enabled, bool interp_enabled); + void writeErrorMetrics(const std::string &out_dir, + const std::vector &species); - bool checkAndRollback(ChemistryModule &chem, uint32_t &iter); + std::optional getRollbackTarget(); - struct SpeciesErrorMetrics { - std::vector mape; - std::vector rrmse; - uint32_t iteration; // iterations in simulation after rollbacks - uint32_t rollback_count; + void computeErrorMetrics(const std::vector &reference_values, + const std::vector &surrogate_values, + const uint32_t size_per_prop, + const std::vector &species); - SpeciesErrorMetrics(uint32_t species_count, uint32_t iter, uint32_t counter) - : mape(species_count, 0.0), rrmse(species_count, 0.0), iteration(iter), - rollback_count(counter) {} - }; + void processCheckpoint(uint32_t ¤t_iter, + const std::string &out_dir, + const std::vector &species); - void computeSpeciesErrorMetrics(const std::vector &reference_values, - const std::vector &surrogate_values, - const uint32_t size_per_prop); - - std::vector metricsHistory; - - struct ControlSetup { - std::string out_dir; - std::uint32_t checkpoint_interval; - std::uint32_t control_interval; - std::vector species_names; - std::vector mape_threshold; - }; - - void enableControlLogic(const ControlSetup &setup) { - this->out_dir = setup.out_dir; - this->checkpoint_interval = setup.checkpoint_interval; - this->control_interval = setup.control_interval; - this->species_names = setup.species_names; - this->mape_threshold = setup.mape_threshold; - } + std::optional + getRollbackTarget(const std::vector &species); + bool shouldBcastFlags() const; bool getControlIntervalEnabled() const { return this->control_interval_enabled; } - void applyControlLogic(ChemistryModule &chem, uint32_t &iter); - - void writeCheckpointAndMetrics(ChemistryModule &chem, uint32_t iter); - - auto getGlobalIteration() const noexcept { return global_iteration; } - - void setChemistryModule(poet::ChemistryModule *c) { chem = c; } - - auto getControlInterval() const { return this->control_interval; } - - std::vector getMapeThreshold() const { return this->mape_threshold; } + bool getFlushRequest() const { return flush_request; } + void clearFlushRequest() { flush_request = false; } /* Profiling getters */ - auto getUpdateCtrlLogicTime() const { return this->prep_t; } - - auto getWriteCheckpointTime() const { return this->w_check_t; } - - auto getReadCheckpointTime() const { return this->r_check_t; } - - auto getWriteMetricsTime() const { return this->stats_t; } + double getUpdateCtrlLogicTime() const { return prep_t; } + double getWriteCheckpointTime() const { return w_check_t; } + double getReadCheckpointTime() const { return r_check_t; } + double getWriteMetricsTime() const { return stats_t; } private: - bool rollback_enabled = false; - bool control_interval_enabled = false; + void updateStabilizationPhase(bool dht_enabled, bool interp_enabled); - poet::ChemistryModule *chem = nullptr; + void readCheckpoint(uint32_t ¤t_iter, + uint32_t rollback_iter, const std::string &out_dir); + + uint32_t getRollbackIter(); + + ControlConfig config; + ChemistryModule *chem = nullptr; - std::uint32_t checkpoint_interval = 0; - std::uint32_t control_interval = 0; std::uint32_t global_iteration = 0; std::uint32_t rollback_count = 0; - std::uint32_t sur_disabled_counter = 0; - std::vector mape_threshold; + std::uint32_t disable_surr_counter = 0; + std::uint32_t last_checkpoint_written = 0; - std::vector species_names; - std::string out_dir; + bool rollback_enabled = false; + bool control_interval_enabled = false; + bool flush_request = false; + + bool bcast_flags = false; + + std::vector metrics_history; double prep_t = 0.; double r_check_t = 0.; double w_check_t = 0.; double stats_t = 0.; - - /* Buffer for shuffled surrogate data */ - std::vector sur_shuffled; }; } // namespace poet diff --git a/src/IO/StatsIO.cpp b/src/IO/StatsIO.cpp index 1b7d58d0c..06af6ae86 100644 --- a/src/IO/StatsIO.cpp +++ b/src/IO/StatsIO.cpp @@ -7,7 +7,7 @@ namespace poet { - void writeStatsToCSV(const std::vector &all_stats, + void writeStatsToCSV(const std::vector &all_stats, const std::vector &species_names, const std::string &out_dir, const std::string &filename) diff --git a/src/IO/StatsIO.hpp b/src/IO/StatsIO.hpp index 5333c4fd8..512ab0d4c 100644 --- a/src/IO/StatsIO.hpp +++ b/src/IO/StatsIO.hpp @@ -3,7 +3,7 @@ namespace poet { - void writeStatsToCSV(const std::vector &all_stats, + void writeStatsToCSV(const std::vector &all_stats, const std::vector &species_names, const std::string &out_dir, const std::string &filename); diff --git a/src/poet.cpp b/src/poet.cpp index 2d4ed4d71..bae9363f6 100644 --- a/src/poet.cpp +++ b/src/poet.cpp @@ -255,6 +255,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) { Rcpp::as(global_rt_setup->operator[]("checkpoint_interval")); params.mape_threshold = Rcpp::as>( global_rt_setup->operator[]("mape_threshold")); + params.zero_abs = Rcpp::as(global_rt_setup->operator[]("zero_abs")); } catch (const std::exception &e) { ERRMSG("Error while parsing R scripts: " + std::string(e.what())); return ParseRet::PARSER_ERROR; @@ -300,7 +301,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, double dSimTime{0}; for (uint32_t iter = 1; iter < maxiter + 1; iter++) { - control.updateControlIteration(iter, params.use_dht, params.use_interp); + control.beginIteration(iter, params.use_dht, params.use_interp); double start_t = MPI_Wtime(); @@ -410,7 +411,10 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, MSG("End of *coupling* iteration " + std::to_string(iter) + "/" + std::to_string(maxiter)); - control.applyControlLogic(chem, iter); + if (control.getControlIntervalEnabled()) { + control.processCheckpoint(iter, params.out_dir, chem.getField().GetProps()); + control.writeErrorMetrics(params.out_dir, chem.getField().GetProps()); + } // MSG(); } // END SIMULATION LOOP @@ -435,16 +439,13 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, ctrl_profiling["write_stats"] = control.getWriteMetricsTime(); ctrl_profiling["ctrl_logic_master"] = control.getUpdateCtrlLogicTime(); ctrl_profiling["recv_data_master"] = chem.GetMasterRecvCtrlDataTime(); - ctrl_profiling["worker"] = - Rcpp::wrap(chem.GetWorkerControlTimings()); - + ctrl_profiling["worker"] = Rcpp::wrap(chem.GetWorkerControlTimings()); - //if (params.use_dht) { - chem_profiling["dht_hits"] = Rcpp::wrap(chem.GetWorkerDHTHits()); - chem_profiling["dht_evictions"] = Rcpp::wrap(chem.GetWorkerDHTEvictions()); - chem_profiling["dht_get_time"] = Rcpp::wrap(chem.GetWorkerDHTGetTimings()); - chem_profiling["dht_fill_time"] = - Rcpp::wrap(chem.GetWorkerDHTFillTimings()); + // if (params.use_dht) { + chem_profiling["dht_hits"] = Rcpp::wrap(chem.GetWorkerDHTHits()); + chem_profiling["dht_evictions"] = Rcpp::wrap(chem.GetWorkerDHTEvictions()); + chem_profiling["dht_get_time"] = Rcpp::wrap(chem.GetWorkerDHTGetTimings()); + chem_profiling["dht_fill_time"] = Rcpp::wrap(chem.GetWorkerDHTFillTimings()); //} if (params.use_interp) { @@ -607,10 +608,10 @@ int main(int argc, char *argv[]) { ChemistryModule chemistry(run_params.work_package_size, init_list.getChemistryInit(), MPI_COMM_WORLD); - - ControlModule control; - chemistry.SetControlModule(&control); - control.setChemistryModule(&chemistry); + + // ControlModule control; + // chemistry.SetControlModule(&control); + // control.setChemistryModule(&chemistry); const ChemistryModule::SurrogateSetup surr_setup = { getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), @@ -628,15 +629,6 @@ int main(int argc, char *argv[]) { chemistry.masterEnableSurrogates(surr_setup); - const ControlModule::ControlSetup ctrl_setup = { - run_params.out_dir, // added - run_params.checkpoint_interval, - run_params.control_interval, - getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), - run_params.mape_threshold}; - - control.enableControlLogic(ctrl_setup); - if (MY_RANK > 0) { chemistry.WorkerLoop(); } else { @@ -680,7 +672,16 @@ int main(int argc, char *argv[]) { chemistry.masterSetField(init_list.getInitialGrid()); - Rcpp::List profiling = RunMasterLoop(R, run_params, diffusion, chemistry, control); + ControlConfig config(run_params.control_interval, + run_params.checkpoint_interval, run_params.zero_abs, + run_params.mape_threshold); + + ControlModule control(config, &chemistry); + + chemistry.SetControlModule(&control); + + Rcpp::List profiling = + RunMasterLoop(R, run_params, diffusion, chemistry, control); MSG("finished simulation loop"); diff --git a/src/poet.hpp.in b/src/poet.hpp.in index b5f807c1c..bd9bd9ce3 100644 --- a/src/poet.hpp.in +++ b/src/poet.hpp.in @@ -54,6 +54,7 @@ struct RuntimeParameters { std::uint32_t checkpoint_interval = 0; std::uint32_t control_interval = 0; std::vector mape_threshold; + double zero_abs = 0.0; static constexpr std::uint32_t WORK_PACKAGE_SIZE_DEFAULT = 32; std::uint32_t work_package_size = WORK_PACKAGE_SIZE_DEFAULT;