diff --git a/bin/dolo_fgcs_3_rt.R b/bin/dolo_fgcs_3_rt.R index 92903ac3f..b35c395d1 100644 --- a/bin/dolo_fgcs_3_rt.R +++ b/bin/dolo_fgcs_3_rt.R @@ -1,21 +1,24 @@ -iterations <- 15000 +iterations <- 10000 dt <- 200 -checkpoint_interval <- 100 -control_interval <- 100 +chkpt_interval <- 100 +ctrl_interval <- 100 mape_threshold <- rep(0.0035, 13) -zero_abs <- 0.0 -#mape_threshold[5] <- 1 #Charge +mape_threshold[5] <- 1 #Charge +zero_abs <- 1e-13 +rb_limit <- 3 + #ctrl_cell_ids <- seq(0, (400*400)/2 - 1, by = 401) #out_save <- seq(500, iterations, by = 500) -#out_save = c(seq(1, 10), seq(10, 100, by= 10), seq(200, iterations, by=100)) +out_save = c(seq(1, 10), seq(10, 100, by= 10), seq(200, iterations, by=100)) list( timesteps = rep(dt, iterations), - store_result = FALSE, - #out_save = out_save, - checkpoint_interval = checkpoint_interval, - control_interval = control_interval, + store_result = TRUE, + out_save = out_save, + chkpt_interval = chkpt_interval, + ctrl_interval = ctrl_interval, mape_threshold = mape_threshold, - zero_abs = zero_abs + zero_abs = zero_abs, + rb_limit = rb_limit ) \ No newline at end of file diff --git a/bin/plot_metrics.R b/bin/plot_metrics.R index 9e351cf24..3375213e2 100644 --- a/bin/plot_metrics.R +++ b/bin/plot_metrics.R @@ -58,7 +58,7 @@ all_data <- lapply(args, function(stats_file) { }) combined_data <- bind_rows(all_data) %>% - filter(Iteration <= 3000) %>% + filter(Iteration >= 3000 & Iteration <= 8000) %>% filter(is.finite(MedianMAPE) & MedianMAPE > 0) %>% filter(is.finite(MaxMAPE) & MaxMAPE > 0) diff --git a/bin/run_poet.sh b/bin/run_poet.sh index e5fe7591c..5f256fbb2 100644 --- a/bin/run_poet.sh +++ b/bin/run_poet.sh @@ -1,7 +1,7 @@ #!/bin/bash -#SBATCH --job-name=proto1_only_interp_zeroabs -#SBATCH --output=proto1_only_interp_zeroabs_%j.out -#SBATCH --error=proto1_only_interp_zeroabs_%j.err +#SBATCH --job-name=p1_eps0035_v2 +#SBATCH --output=p1_eps0035_v2_%j.out +#SBATCH --error=p1_eps0035_v2_%j.err #SBATCH --partition=long #SBATCH --nodes=6 #SBATCH --ntasks-per-node=24 @@ -15,5 +15,5 @@ module purge module load cmake gcc openmpi #mpirun -n 144 ./poet dolo_fgcs_3.R dolo_fgcs_3.qs2 dolo_only_pqc -mpirun -n 144 ./poet --interp dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 proto1_only_interp_zeroabs +mpirun -n 144 ./poet --interp dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 p1_eps0035_v2 #mpirun -n 144 ./poet --interp barite_fgcs_4_new/barite_fgcs_4_new_rt.R barite_fgcs_4_new/barite_fgcs_4_new.qs2 barite \ No newline at end of file diff --git a/src/Chemistry/ChemistryModule.hpp b/src/Chemistry/ChemistryModule.hpp index 854d488a2..d0680d2f0 100644 --- a/src/Chemistry/ChemistryModule.hpp +++ b/src/Chemistry/ChemistryModule.hpp @@ -102,8 +102,6 @@ public: this->base_totals = setup.base_totals; - this->ctr_file_out_dir = setup.dht_out_dir; - if (this->dht_enabled || this->interp_enabled) { this->initializeDHT(setup.dht_size_mb, this->params.dht_species, setup.has_het_ids); @@ -269,6 +267,8 @@ public: void SetStabEnabled(bool enabled) { stab_enabled = enabled; } + inline uint32_t buildCtrlFlags(bool dht, bool interp, bool stab); + protected: void initializeDHT(uint32_t size_mb, const NamedVector &key_species, @@ -384,10 +384,10 @@ protected: void BCastStringVec(std::vector &io); void copyPkgs(const WorkPackage &wp, std::vector &mpi_buffer, - std::size_t offset = 0); + std::size_t offset = 0); void copyCtrlPkgs(const WorkPackage &pqc_wp, const WorkPackage &surr_wp, - std::vector &mpi_bufffer, int &count); + std::vector &mpi_bufffer, int &count); int comm_size, comm_rank; MPI_Comm group_comm; @@ -417,20 +417,6 @@ protected: ChemBCast(&type, 1, MPI_INT); } - std::string ctr_file_out_dir; - - inline int buildControlPacket(bool dht, bool interp, bool stab) { - int flags = 0; - - if (dht) - flags |= DHT_ENABLE; - if (interp) - flags |= IP_ENABLE; - if (stab) - flags |= STAB_ENABLE; - return flags; - } - inline bool hasFlag(int flags, int type) { return (flags & type) != 0; } double simtime = 0.; @@ -464,7 +450,7 @@ protected: std::vector mpi_surr_buffer; - bool control_enabled{false}; + bool ctrl_enabled{false}; bool stab_enabled{false}; }; } // namespace poet diff --git a/src/Chemistry/MasterFunctions.cpp b/src/Chemistry/MasterFunctions.cpp index dd381e486..a72419aaa 100644 --- a/src/Chemistry/MasterFunctions.cpp +++ b/src/Chemistry/MasterFunctions.cpp @@ -232,6 +232,17 @@ inline void printProgressbar(int count_pkgs, int n_wp, int barWidth = 70) { /* end visual progress */ } +inline uint32_t poet::ChemistryModule::buildCtrlFlags(bool dht, bool interp, bool stab) { + uint32_t flags = 0; + if (dht) + flags |= DHT_ENABLE; + if (interp) + flags |= IP_ENABLE; + if (stab) + flags |= STAB_ENABLE; + return flags; +} + inline void poet::ChemistryModule::MasterSendPkgs( worker_list_t &w_list, workpointer_t &work_pointer, workpointer_t &sur_pointer, int &pkg_to_send, int &count_pkgs, @@ -257,7 +268,7 @@ inline void poet::ChemistryModule::MasterSendPkgs( /* note current processed work package in workerlist */ w_list[p].send_addr = work_pointer.base(); w_list[p].surrogate_addr = sur_pointer.base(); - // this->control_enabled ? sur_pointer.base() : w_list[p].surrogate_addr = + // this->ctrl_enabled ? sur_pointer.base() : w_list[p].surrogate_addr = // nullptr; /* push work pointer to next work package */ @@ -282,7 +293,7 @@ inline void poet::ChemistryModule::MasterSendPkgs( // control flags (bitmask) /* int flags = (this->interp_enabled ? 1 : 0) | (this->dht_enabled ? 2 : - 0) | (this->warmup_enabled ? 4 : 0) | (this->control_enabled ? 8 : 0); + 0) | (this->warmup_enabled ? 4 : 0) | (this->ctrl_enabled ? 8 : 0); send_buffer[end_of_wp + 5] = static_cast(flags); */ @@ -449,19 +460,17 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { /* broadcast control state once every iteration */ ftype = CHEM_CTRL_ENABLE; PropagateFunctionType(ftype); - int ctrl = - (this->control_enabled = this->control->getControlIntervalEnabled()) ? 1 - : 0; - ChemBCast(&ctrl, 1, MPI_INT); + uint32_t ctrl = (ctrl_enabled = control->isCtrlIntervalActive()) ? 1 : 0; + ChemBCast(&ctrl, 1, MPI_UINT32_T); - if (control->shouldBcastFlags()) { + if (control->needsFlagBcast()) { int ftype = CHEM_CTRL_FLAGS; PropagateFunctionType(ftype); - uint32_t ctrl_flags = buildControlPacket( - this->dht_enabled, this->interp_enabled, this->stab_enabled); - ChemBCast(&ctrl_flags, 1, MPI_INT); + uint32_t ctrl_flags = + buildCtrlFlags(dht_enabled, interp_enabled, stab_enabled); + ChemBCast(&ctrl_flags, 1, MPI_UINT32_T); - this->mpi_surr_buffer.assign(this->n_cells * this->prop_count, 0.0); + mpi_surr_buffer.assign(n_cells * prop_count, 0.0); } ftype = CHEM_WORK_LOOP; @@ -481,8 +490,8 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { wp_sizes_vector.size()); // Only resize surrogate buffer if control is enabled - if (this->control_enabled) { - this->mpi_surr_buffer.resize(mpi_buffer.size()); + if (ctrl_enabled) { + mpi_surr_buffer.resize(mpi_buffer.size()); } /* setup local variables */ @@ -490,9 +499,8 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { pkg_to_recv = wp_sizes_vector.size(); workpointer_t work_pointer = mpi_buffer.begin(); - workpointer_t sur_pointer = this->mpi_surr_buffer.begin(); - //(this->control_enabled ? this->mpi_surr_buffer.begin() - // : mpi_buffer.end()); + workpointer_t sur_pointer = mpi_surr_buffer.begin(); + worker_list_t worker_list(this->comm_size - 1); free_workers = this->comm_size - 1; @@ -540,14 +548,12 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { chem_field = out_vec; /* do master stuff */ - if (this->control_enabled) { - std::cout << "[Master] Control logic enabled for this iteration." - << std::endl; + if (ctrl_enabled) { std::vector sur_unshuffled{mpi_surr_buffer}; shuf_a = MPI_Wtime(); - unshuffleField(this->mpi_surr_buffer, this->n_cells, this->prop_count, - wp_sizes_vector.size(), sur_unshuffled); + unshuffleField(mpi_surr_buffer, n_cells, prop_count, wp_sizes_vector.size(), + sur_unshuffled); shuf_b = MPI_Wtime(); this->shuf_t += shuf_b - shuf_a; @@ -558,8 +564,7 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { } metrics_a = MPI_Wtime(); - control->computeErrorMetrics(out_vec, sur_unshuffled, this->n_cells, - prop_names); + control->computeMetrics(out_vec, sur_unshuffled, n_cells, prop_names); metrics_b = MPI_Wtime(); this->metrics_t += metrics_b - metrics_a; @@ -575,10 +580,10 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { /* end time measurement of whole chemistry simulation */ std::optional target = std::nullopt; - if (this->control_enabled) { - target = control->getRollbackTarget(prop_names); + if (ctrl_enabled) { + target = control->findRbTarget(prop_names); } - int flush = (this->control_enabled && target.has_value()) ? 1 : 0; + int flush = (ctrl_enabled && target.has_value()) ? 1 : 0; /* advise workers to end chemistry iteration */ for (int i = 1; i < this->comm_size; i++) { diff --git a/src/Chemistry/WorkerFunctions.cpp b/src/Chemistry/WorkerFunctions.cpp index 9578f9dde..ecb5d4f30 100644 --- a/src/Chemistry/WorkerFunctions.cpp +++ b/src/Chemistry/WorkerFunctions.cpp @@ -60,17 +60,17 @@ void poet::ChemistryModule::WorkerLoop() { break; } case CHEM_CTRL_ENABLE: { - int ctrl = 0; - ChemBCast(&ctrl, 1, MPI_INT); - this->control_enabled = (ctrl == 1); + uint32_t ctrl = 0; + ChemBCast(&ctrl, 1, MPI_UINT32_T); + ctrl_enabled = (ctrl == 1); break; } case CHEM_CTRL_FLAGS: { - int flags = 0; - ChemBCast(&flags, 1, MPI_INT); - this->dht_enabled = hasFlag(flags, DHT_ENABLE); - this->interp_enabled = hasFlag(flags, IP_ENABLE); - this->stab_enabled = hasFlag(flags, STAB_ENABLE); + uint32_t flags = 0; + ChemBCast(&flags, 1, MPI_UINT32_T); + dht_enabled = hasFlag(flags, DHT_ENABLE); + interp_enabled = hasFlag(flags, IP_ENABLE); + stab_enabled = hasFlag(flags, STAB_ENABLE); break; } case CHEM_WORK_LOOP: { @@ -204,21 +204,10 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, // current work package start location in field wp_start_index = mpi_buffer[count + 4]; - // read packed control flags /* - flags = static_cast(mpi_buffer[count + 5]); - this->interp_enabled = (flags & 1) != 0; - this->dht_enabled = (flags & 2) != 0; - this->warmup_enabled = (flags & 4) != 0; - this->control_enabled = (flags & 8) != 0; - */ - - /* - - std::cout << "warmup_enabled is " << stab_enabled << ", control_enabled is " - << control_enabled << ", dht_enabled is " << dht_enabled + std::cout << "warmup_enabled is " << stab_enabled << ", ctrl_enabled is " + << ctrl_enabled << ", dht_enabled is " << dht_enabled << ", interp_enabled is " << interp_enabled << std::endl; - */ for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) { @@ -258,7 +247,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, poet::WorkPackage s_curr_wp_control = s_curr_wp; - if (control_enabled) { + if (ctrl_enabled) { ctrl_cp_start = MPI_Wtime(); for (std::size_t wp_i = 0; wp_i < s_curr_wp_control.size; wp_i++) { s_curr_wp_control.output[wp_i] = @@ -271,12 +260,12 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, phreeqc_time_start = MPI_Wtime(); - WorkerRunWorkPackage(control_enabled ? s_curr_wp_control : s_curr_wp, + WorkerRunWorkPackage(ctrl_enabled ? s_curr_wp_control : s_curr_wp, current_sim_time, dt); phreeqc_time_end = MPI_Wtime(); - if (control_enabled) { + if (ctrl_enabled) { ctrl_start = MPI_Wtime(); copyCtrlPkgs(s_curr_wp_control, s_curr_wp, mpi_buffer, count); ctrl_end = MPI_Wtime(); @@ -288,14 +277,14 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, /* send results to master */ MPI_Request send_req; - int mpi_tag = control_enabled ? LOOP_CTRL : LOOP_WORK; + int mpi_tag = ctrl_enabled ? LOOP_CTRL : LOOP_WORK; MPI_Isend(mpi_buffer.data(), count, MPI_DOUBLE, 0, mpi_tag, MPI_COMM_WORLD, &send_req); if (dht_enabled || interp_enabled || stab_enabled) { /* write results to DHT */ dht_fill_start = MPI_Wtime(); - dht->fillDHT(control_enabled ? s_curr_wp_control : s_curr_wp); + dht->fillDHT(ctrl_enabled ? s_curr_wp_control : s_curr_wp); dht_fill_end = MPI_Wtime(); if (interp_enabled || stab_enabled) { diff --git a/src/Control/ControlModule.cpp b/src/Control/ControlModule.cpp index 15cd21c90..ebdae31ae 100644 --- a/src/Control/ControlModule.cpp +++ b/src/Control/ControlModule.cpp @@ -4,60 +4,67 @@ #include "IO/StatsIO.hpp" #include -poet::ControlModule::ControlModule(const ControlConfig &config_, - ChemistryModule *chem_) +poet::ControlModule::ControlModule(const ControlConfig &config_, ChemistryModule *chem_) : config(config_), chem(chem_) { assert(chem && "ChemistryModule pointer must not be null"); } -void poet::ControlModule::beginIteration(const uint32_t &iter, - const bool &dht_enabled, +void poet::ControlModule::beginIteration(const uint32_t &iter, const bool &dht_enabled, const bool &interp_enabled) { - - /* dht_enabled and inter_enabled are user settings set before startig the - * simulation*/ double prep_a, prep_b; prep_a = MPI_Wtime(); - if (config.control_interval == 0) { - control_interval_enabled = false; + if (config.ctrl_interval == 0) { + ctrl_active = false; return; } - global_iteration = iter; + global_iter = iter; - updateStabilizationPhase(dht_enabled, interp_enabled); + updateSurrState(dht_enabled, interp_enabled); - control_interval_enabled = - (config.control_interval > 0 && (iter % config.control_interval == 0)); + ctrl_active = (config.ctrl_interval > 0 && (iter % config.ctrl_interval == 0)); prep_b = MPI_Wtime(); this->prep_t += prep_b - prep_a; } -void poet::ControlModule::updateStabilizationPhase(bool dht_enabled, - bool interp_enabled) { - if (rollback_enabled) { - if (disable_surr_counter > 0) { - --disable_surr_counter; - MSG("Rollback counter: " + std::to_string(disable_surr_counter)); - } else { - rollback_enabled = false; +/* manages the overall surrogate state, by enabling/disabling state based on + * warmup logic and rollback conditions*/ +void poet::ControlModule::updateSurrState(bool dht_enabled, bool interp_enabled) { + + bool in_warmup = (global_iter <= config.ctrl_interval); + bool rb_limit_reached = (rb_count >= config.rb_limit); + + if (rb_enabled && stab_countdown > 0 && !rb_limit_reached) { + --stab_countdown; + std::cout << "Rollback counter: " << stab_countdown << std::endl; + if (stab_countdown == 0) { + rb_enabled = false; } + flush_request = false; } - // user requested DHT/INTEP? keep them disabled but enable warmup-phase so - if (global_iteration <= config.control_interval || rollback_enabled) { - chem->SetStabEnabled(true); + /* disable surrogates during warmup, active rollback or after limit */ + if (in_warmup || rb_enabled || rb_limit_reached) { + chem->SetStabEnabled(!rb_limit_reached); chem->SetDhtEnabled(false); chem->SetInterpEnabled(false); + + if (rb_limit_reached) { + std::cout << "Interpolation completly disabled." << std::endl; + } else { + std::cout << "In stabilization phase." << std::endl; + } + return; } + /* enable user-requested surrogates */ chem->SetStabEnabled(false); chem->SetDhtEnabled(dht_enabled); chem->SetInterpEnabled(interp_enabled); + std::cout << "Interpolating." << std::endl; } -void poet::ControlModule::writeCheckpoint(uint32_t &iter, - const std::string &out_dir) { +void poet::ControlModule::writeCheckpoint(uint32_t &iter, const std::string &out_dir) { double w_check_a, w_check_b; w_check_a = MPI_Wtime(); write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5", @@ -65,25 +72,25 @@ void poet::ControlModule::writeCheckpoint(uint32_t &iter, w_check_b = MPI_Wtime(); this->w_check_t += w_check_b - w_check_a; - last_checkpoint_written = iter; + last_chkpt_written = iter; } -void poet::ControlModule::readCheckpoint(uint32_t ¤t_iter, - uint32_t rollback_iter, +void poet::ControlModule::readCheckpoint(uint32_t ¤t_iter, uint32_t rollback_iter, const std::string &out_dir) { double r_check_a, r_check_b; r_check_a = MPI_Wtime(); Checkpoint_s checkpoint_read{.field = chem->getField()}; - read_checkpoint(out_dir, - "checkpoint" + std::to_string(rollback_iter) + ".hdf5", - checkpoint_read); + read_checkpoint(out_dir, "checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read); current_iter = checkpoint_read.iteration; r_check_b = MPI_Wtime(); r_check_t += r_check_b - r_check_a; } -void poet::ControlModule::writeErrorMetrics( - const std::string &out_dir, const std::vector &species) { +void poet::ControlModule::writeMetrics(const std::string &out_dir, + const std::vector &species) { + if (rb_count > config.rb_limit) { + return; + } double stats_a, stats_b; stats_a = MPI_Wtime(); @@ -93,63 +100,68 @@ void poet::ControlModule::writeErrorMetrics( this->stats_t += stats_b - stats_a; } -uint32_t poet::ControlModule::getRollbackIter() { +uint32_t poet::ControlModule::calcRbIter() { - uint32_t last_iter = ((global_iteration - 1) / config.checkpoint_interval) * - config.checkpoint_interval; + uint32_t last_iter = ((global_iter - 1) / config.chkpt_interval) * config.chkpt_interval; - uint32_t rollback_iter = (last_iter <= last_checkpoint_written) - ? last_iter - : last_checkpoint_written; - return rollback_iter; + uint32_t rb_iter = (last_iter <= last_chkpt_written) ? last_iter : last_chkpt_written; + return rb_iter; } -std::optional poet::ControlModule::getRollbackTarget( - const std::vector &species) { - double r_check_a, r_check_b; +std::optional poet::ControlModule::findRbTarget(const std::vector &species) { if (metrics_history.empty()) { - MSG("No error history yet, skipping rollback check."); + std::cout << "No error history yet, skipping rollback check." << std::endl; flush_request = false; return std::nullopt; } - if (rollback_count > 3) { - MSG("Rollback limit reached, skipping rollback."); + if (rb_count > config.rb_limit) { + std::cout << "Rollback limit reached, skipping control logic." << std::endl; flush_request = false; return std::nullopt; } + std::cout << "findRbTarget called at iter " << global_iter << ", rb_count=" << rb_count + << ", rb_limit=" << config.rb_limit << std::endl; + + double r_check_a, r_check_b; const auto &mape = metrics_history.back().mape; for (uint32_t i = 0; i < species.size(); ++i) { - // skip Charge - if (i == 4 || mape[i] == 0) { + + if (mape[i] == 0) { continue; } if (mape[i] > config.mape_threshold[i]) { - if (last_checkpoint_written == 0) { - MSG(" Threshold exceeded but no checkpoint exists yet."); + std::cout << "Species " << species[i] << " MAPE=" << mape[i] + << " threshold=" << config.mape_threshold[i] << std::endl; + + if (last_chkpt_written == 0) { + std::cout << " Threshold exceeded but no checkpoint exists yet." << std::endl; return std::nullopt; } - + // rb_enabled = true; flush_request = true; - - MSG("T hreshold exceeded " + species[i] + - " has MAPE = " + std::to_string(mape[i]) + - " exceeding threshold = " + std::to_string(config.mape_threshold[i])); - return getRollbackIter(); + std::cout << "Threshold exceeded " << species[i] << " has MAPE = " << std::to_string(mape[i]) + << " exceeding threshold = " << std::to_string(config.mape_threshold[i]) + << std::endl; + return calcRbIter(); } } - MSG("All species are within their MAPE thresholds."); + // std::cout << "All species are within their MAPE thresholds." << std::endl; flush_request = false; return std::nullopt; } -void poet::ControlModule::computeErrorMetrics( - const std::vector &reference_values, - const std::vector &surrogate_values, const uint32_t size_per_prop, - const std::vector &species) { +void poet::ControlModule::computeMetrics(const std::vector &reference_values, + const std::vector &surrogate_values, + const uint32_t size_per_prop, + const std::vector &species) { - SpeciesErrorMetrics metrics(species.size(), global_iteration, rollback_count); + if (rb_count > config.rb_limit) { + return; + } + + SpeciesMetrics metrics(species.size(), global_iter, rb_count); for (uint32_t i = 0; i < species.size(); ++i) { double err_sum = 0.0; @@ -164,60 +176,50 @@ void poet::ControlModule::computeErrorMetrics( if (std::isnan(ref_value) || std::isnan(sur_value)) { continue; } - - if (!std::isfinite(ref_value) || !std::isfinite(sur_value)) { - continue; - } - - if (std::abs(ref_value) == ZERO_ABS) { - if (std::abs(sur_value) != ZERO_ABS) { + if (std::abs(ref_value) < ZERO_ABS) { + if (std::abs(sur_value) >= ZERO_ABS) { err_sum += 1.0; sqr_err_sum += 1.0; } - } - // Both zero: skip - else { + } else { double alpha = 1.0 - (sur_value / ref_value); - if (!std::isfinite(alpha)) { - continue; // protects against inf/NaN due to extreme values - } - err_sum += std::abs(alpha); sqr_err_sum += alpha * alpha; } } - metrics.mape[i] = 100.0 * (err_sum / static_cast(size_per_prop)); - metrics.rrmse[i] = - std::sqrt(sqr_err_sum / static_cast(size_per_prop)); + metrics.mape[i] = 100.0 * (err_sum / size_per_prop); + metrics.rrmse[i] = std::sqrt(sqr_err_sum / size_per_prop); } metrics_history.push_back(metrics); } -void poet::ControlModule::processCheckpoint( - uint32_t ¤t_iter, const std::string &out_dir, - const std::vector &species) { +void poet::ControlModule::processCheckpoint(uint32_t ¤t_iter, const std::string &out_dir, + const std::vector &species) { - if (!control_interval_enabled) + if (!ctrl_active || rb_count > config.rb_limit) { return; + } - if (flush_request && rollback_count < 3) { - uint32_t target = getRollbackIter(); + if (flush_request) { + uint32_t target = calcRbIter(); readCheckpoint(current_iter, target, out_dir); - rollback_enabled = true; - rollback_count++; - disable_surr_counter = config.control_interval; + rb_enabled = true; + rb_count++; + stab_countdown = config.ctrl_interval; - MSG("Restored checkpoint " + std::to_string(target) + - ", surrogates disabled for " + std::to_string(config.control_interval)); + std::cout << "Restored checkpoint " << std::to_string(target) << ", surrogates disabled for " + << config.ctrl_interval << std::endl; } else { - writeCheckpoint(global_iteration, out_dir); + writeCheckpoint(global_iter, out_dir); } } -bool poet::ControlModule::shouldBcastFlags() const { - if (global_iteration == 1 || - global_iteration % config.control_interval == 1) { +bool poet::ControlModule::needsFlagBcast() const { + if (rb_count > config.rb_limit) { + return false; + } + if (global_iter == 1 || global_iter % config.ctrl_interval == 1) { return true; } return false; diff --git a/src/Control/ControlModule.hpp b/src/Control/ControlModule.hpp index 29deec61c..ea8222c7e 100644 --- a/src/Control/ControlModule.hpp +++ b/src/Control/ControlModule.hpp @@ -14,21 +14,22 @@ namespace poet { class ChemistryModule; struct ControlConfig { - uint32_t control_interval = 0; - uint32_t checkpoint_interval = 0; + uint32_t ctrl_interval = 0; + uint32_t chkpt_interval = 0; + uint32_t rb_limit = 0; double zero_abs = 0.0; std::vector mape_threshold; }; -struct SpeciesErrorMetrics { +struct SpeciesMetrics { std::vector mape; std::vector rrmse; uint32_t iteration = 0; - uint32_t rollback_count = 0; + uint32_t rb_count = 0; - SpeciesErrorMetrics(uint32_t n_species, uint32_t iter, uint32_t rb_count) + SpeciesMetrics(uint32_t n_species, uint32_t iter, uint32_t count) : mape(n_species, 0.0), rrmse(n_species, 0.0), iteration(iter), - rollback_count(rb_count) {} + rb_count(count) {} }; class ControlModule { @@ -40,12 +41,12 @@ public: void writeCheckpoint(uint32_t &iter, const std::string &out_dir); - void writeErrorMetrics(const std::string &out_dir, + void writeMetrics(const std::string &out_dir, const std::vector &species); - std::optional getRollbackTarget(); + std::optional findRbTarget(); - void computeErrorMetrics(const std::vector &reference_values, + void computeMetrics(const std::vector &reference_values, const std::vector &surrogate_values, const uint32_t size_per_prop, const std::vector &species); @@ -55,11 +56,11 @@ public: const std::vector &species); std::optional - getRollbackTarget(const std::vector &species); + findRbTarget(const std::vector &species); - bool shouldBcastFlags() const; - bool getControlIntervalEnabled() const { - return this->control_interval_enabled; + bool needsFlagBcast() const; + bool isCtrlIntervalActive() const { + return this->ctrl_active; } bool getFlushRequest() const { return flush_request; } @@ -67,34 +68,32 @@ public: /* Profiling getters */ - double getUpdateCtrlLogicTime() const { return prep_t; } - double getWriteCheckpointTime() const { return w_check_t; } - double getReadCheckpointTime() const { return r_check_t; } - double getWriteMetricsTime() const { return stats_t; } + double getCtrlLogicTime() const { return prep_t; } + double getChkptWriteTime() const { return w_check_t; } + double getChkptReadTime() const { return r_check_t; } + double getMetricsWriteTime() const { return stats_t; } private: - void updateStabilizationPhase(bool dht_enabled, bool interp_enabled); + void updateSurrState(bool dht_enabled, bool interp_enabled); void readCheckpoint(uint32_t ¤t_iter, uint32_t rollback_iter, const std::string &out_dir); - uint32_t getRollbackIter(); + uint32_t calcRbIter(); ControlConfig config; ChemistryModule *chem = nullptr; - std::uint32_t global_iteration = 0; - std::uint32_t rollback_count = 0; - std::uint32_t disable_surr_counter = 0; - std::uint32_t last_checkpoint_written = 0; + std::uint32_t global_iter = 0; + std::uint32_t rb_count = 0; + std::uint32_t stab_countdown = 0; + std::uint32_t last_chkpt_written = 0; - bool rollback_enabled = false; - bool control_interval_enabled = false; + bool rb_enabled = false; + bool ctrl_active = false; bool flush_request = false; - bool bcast_flags = false; - - std::vector metrics_history; + std::vector metrics_history; double prep_t = 0.; double r_check_t = 0.; diff --git a/src/IO/StatsIO.cpp b/src/IO/StatsIO.cpp index 06af6ae86..ad3cd4a9d 100644 --- a/src/IO/StatsIO.cpp +++ b/src/IO/StatsIO.cpp @@ -7,7 +7,7 @@ namespace poet { - void writeStatsToCSV(const std::vector &all_stats, + void writeStatsToCSV(const std::vector &all_stats, const std::vector &species_names, const std::string &out_dir, const std::string &filename) @@ -37,7 +37,7 @@ namespace poet { out << std::left << std::setw(15) << all_stats[i].iteration - << std::setw(15) << all_stats[i].rollback_count + << std::setw(15) << all_stats[i].rb_count << std::setw(15) << species_names[j] << std::setw(15) << all_stats[i].mape[j] << std::setw(15) << all_stats[i].rrmse[j] diff --git a/src/IO/StatsIO.hpp b/src/IO/StatsIO.hpp index 512ab0d4c..eaa8d3675 100644 --- a/src/IO/StatsIO.hpp +++ b/src/IO/StatsIO.hpp @@ -1,10 +1,7 @@ -#include #include "Control/ControlModule.hpp" +#include -namespace poet -{ - void writeStatsToCSV(const std::vector &all_stats, - const std::vector &species_names, - const std::string &out_dir, - const std::string &filename); +namespace poet { +void writeStatsToCSV(const std::vector &all_stats, const std::vector &species_names, + const std::string &out_dir, const std::string &filename); } // namespace poet diff --git a/src/poet.cpp b/src/poet.cpp index bae9363f6..3db547fc2 100644 --- a/src/poet.cpp +++ b/src/poet.cpp @@ -249,10 +249,12 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) { params.timesteps = Rcpp::as>(global_rt_setup->operator[]("timesteps")); - params.control_interval = - Rcpp::as(global_rt_setup->operator[]("control_interval")); - params.checkpoint_interval = - Rcpp::as(global_rt_setup->operator[]("checkpoint_interval")); + params.ctrl_interval = + Rcpp::as(global_rt_setup->operator[]("ctrl_interval")); + params.chkpt_interval = + Rcpp::as(global_rt_setup->operator[]("chkpt_interval")); + params.rb_limit = + Rcpp::as(global_rt_setup->operator[]("rb_limit")); params.mape_threshold = Rcpp::as>( global_rt_setup->operator[]("mape_threshold")); params.zero_abs = Rcpp::as(global_rt_setup->operator[]("zero_abs")); @@ -411,9 +413,10 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, MSG("End of *coupling* iteration " + std::to_string(iter) + "/" + std::to_string(maxiter)); - if (control.getControlIntervalEnabled()) { - control.processCheckpoint(iter, params.out_dir, chem.getField().GetProps()); - control.writeErrorMetrics(params.out_dir, chem.getField().GetProps()); + if (control.isCtrlIntervalActive()) { + control.processCheckpoint(iter, params.out_dir, + chem.getField().GetProps()); + control.writeMetrics(params.out_dir, chem.getField().GetProps()); } // MSG(); } // END SIMULATION LOOP @@ -434,10 +437,10 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, Rcpp::List ctrl_profiling; ctrl_profiling["compute_metrics_master"] = chem.GetMasterCtrlMetricsTime(); ctrl_profiling["unshuffle_field_master"] = chem.GetMasterUnshuffleTime(); - ctrl_profiling["w_checkpoint_master"] = control.getWriteCheckpointTime(); - ctrl_profiling["r_checkpoint_master"] = control.getReadCheckpointTime(); - ctrl_profiling["write_stats"] = control.getWriteMetricsTime(); - ctrl_profiling["ctrl_logic_master"] = control.getUpdateCtrlLogicTime(); + ctrl_profiling["w_checkpoint_master"] = control.getChkptWriteTime(); + ctrl_profiling["r_checkpoint_master"] = control.getChkptReadTime(); + ctrl_profiling["write_stats"] = control.getMetricsWriteTime(); + ctrl_profiling["ctrl_logic_master"] = control.getCtrlLogicTime(); ctrl_profiling["recv_data_master"] = chem.GetMasterRecvCtrlDataTime(); ctrl_profiling["worker"] = Rcpp::wrap(chem.GetWorkerControlTimings()); @@ -629,6 +632,14 @@ int main(int argc, char *argv[]) { chemistry.masterEnableSurrogates(surr_setup); + ControlConfig config(run_params.ctrl_interval, run_params.chkpt_interval, + run_params.rb_limit, run_params.zero_abs, + run_params.mape_threshold); + + ControlModule control(config, &chemistry); + + chemistry.SetControlModule(&control); + if (MY_RANK > 0) { chemistry.WorkerLoop(); } else { @@ -672,14 +683,6 @@ int main(int argc, char *argv[]) { chemistry.masterSetField(init_list.getInitialGrid()); - ControlConfig config(run_params.control_interval, - run_params.checkpoint_interval, run_params.zero_abs, - run_params.mape_threshold); - - ControlModule control(config, &chemistry); - - chemistry.SetControlModule(&control); - Rcpp::List profiling = RunMasterLoop(R, run_params, diffusion, chemistry, control); diff --git a/src/poet.hpp.in b/src/poet.hpp.in index bd9bd9ce3..678aaafbc 100644 --- a/src/poet.hpp.in +++ b/src/poet.hpp.in @@ -51,8 +51,9 @@ struct RuntimeParameters { bool print_progress = false; - std::uint32_t checkpoint_interval = 0; - std::uint32_t control_interval = 0; + std::uint32_t chkpt_interval = 0; + std::uint32_t ctrl_interval = 0; + std::uint32_t rb_limit = 0; std::vector mape_threshold; double zero_abs = 0.0;