added rb_limit, corrected updateSurrState logic, added early return after rb_limit reached

This commit is contained in:
rastogi 2025-11-28 12:58:26 +01:00
parent 9087393f61
commit 97076cb7cd
12 changed files with 225 additions and 240 deletions

View File

@ -1,21 +1,24 @@
iterations <- 15000 iterations <- 10000
dt <- 200 dt <- 200
checkpoint_interval <- 100 chkpt_interval <- 100
control_interval <- 100 ctrl_interval <- 100
mape_threshold <- rep(0.0035, 13) mape_threshold <- rep(0.0035, 13)
zero_abs <- 0.0 mape_threshold[5] <- 1 #Charge
#mape_threshold[5] <- 1 #Charge zero_abs <- 1e-13
rb_limit <- 3
#ctrl_cell_ids <- seq(0, (400*400)/2 - 1, by = 401) #ctrl_cell_ids <- seq(0, (400*400)/2 - 1, by = 401)
#out_save <- seq(500, iterations, by = 500) #out_save <- seq(500, iterations, by = 500)
#out_save = c(seq(1, 10), seq(10, 100, by= 10), seq(200, iterations, by=100)) out_save = c(seq(1, 10), seq(10, 100, by= 10), seq(200, iterations, by=100))
list( list(
timesteps = rep(dt, iterations), timesteps = rep(dt, iterations),
store_result = FALSE, store_result = TRUE,
#out_save = out_save, out_save = out_save,
checkpoint_interval = checkpoint_interval, chkpt_interval = chkpt_interval,
control_interval = control_interval, ctrl_interval = ctrl_interval,
mape_threshold = mape_threshold, mape_threshold = mape_threshold,
zero_abs = zero_abs zero_abs = zero_abs,
rb_limit = rb_limit
) )

View File

@ -58,7 +58,7 @@ all_data <- lapply(args, function(stats_file) {
}) })
combined_data <- bind_rows(all_data) %>% combined_data <- bind_rows(all_data) %>%
filter(Iteration <= 3000) %>% filter(Iteration >= 3000 & Iteration <= 8000) %>%
filter(is.finite(MedianMAPE) & MedianMAPE > 0) %>% filter(is.finite(MedianMAPE) & MedianMAPE > 0) %>%
filter(is.finite(MaxMAPE) & MaxMAPE > 0) filter(is.finite(MaxMAPE) & MaxMAPE > 0)

View File

@ -1,7 +1,7 @@
#!/bin/bash #!/bin/bash
#SBATCH --job-name=proto1_only_interp_zeroabs #SBATCH --job-name=p1_eps0035_v2
#SBATCH --output=proto1_only_interp_zeroabs_%j.out #SBATCH --output=p1_eps0035_v2_%j.out
#SBATCH --error=proto1_only_interp_zeroabs_%j.err #SBATCH --error=p1_eps0035_v2_%j.err
#SBATCH --partition=long #SBATCH --partition=long
#SBATCH --nodes=6 #SBATCH --nodes=6
#SBATCH --ntasks-per-node=24 #SBATCH --ntasks-per-node=24
@ -15,5 +15,5 @@ module purge
module load cmake gcc openmpi module load cmake gcc openmpi
#mpirun -n 144 ./poet dolo_fgcs_3.R dolo_fgcs_3.qs2 dolo_only_pqc #mpirun -n 144 ./poet dolo_fgcs_3.R dolo_fgcs_3.qs2 dolo_only_pqc
mpirun -n 144 ./poet --interp dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 proto1_only_interp_zeroabs mpirun -n 144 ./poet --interp dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 p1_eps0035_v2
#mpirun -n 144 ./poet --interp barite_fgcs_4_new/barite_fgcs_4_new_rt.R barite_fgcs_4_new/barite_fgcs_4_new.qs2 barite #mpirun -n 144 ./poet --interp barite_fgcs_4_new/barite_fgcs_4_new_rt.R barite_fgcs_4_new/barite_fgcs_4_new.qs2 barite

View File

@ -102,8 +102,6 @@ public:
this->base_totals = setup.base_totals; this->base_totals = setup.base_totals;
this->ctr_file_out_dir = setup.dht_out_dir;
if (this->dht_enabled || this->interp_enabled) { if (this->dht_enabled || this->interp_enabled) {
this->initializeDHT(setup.dht_size_mb, this->params.dht_species, this->initializeDHT(setup.dht_size_mb, this->params.dht_species,
setup.has_het_ids); setup.has_het_ids);
@ -269,6 +267,8 @@ public:
void SetStabEnabled(bool enabled) { stab_enabled = enabled; } void SetStabEnabled(bool enabled) { stab_enabled = enabled; }
inline uint32_t buildCtrlFlags(bool dht, bool interp, bool stab);
protected: protected:
void initializeDHT(uint32_t size_mb, void initializeDHT(uint32_t size_mb,
const NamedVector<std::uint32_t> &key_species, const NamedVector<std::uint32_t> &key_species,
@ -384,10 +384,10 @@ protected:
void BCastStringVec(std::vector<std::string> &io); void BCastStringVec(std::vector<std::string> &io);
void copyPkgs(const WorkPackage &wp, std::vector<double> &mpi_buffer, void copyPkgs(const WorkPackage &wp, std::vector<double> &mpi_buffer,
std::size_t offset = 0); std::size_t offset = 0);
void copyCtrlPkgs(const WorkPackage &pqc_wp, const WorkPackage &surr_wp, void copyCtrlPkgs(const WorkPackage &pqc_wp, const WorkPackage &surr_wp,
std::vector<double> &mpi_bufffer, int &count); std::vector<double> &mpi_bufffer, int &count);
int comm_size, comm_rank; int comm_size, comm_rank;
MPI_Comm group_comm; MPI_Comm group_comm;
@ -417,20 +417,6 @@ protected:
ChemBCast(&type, 1, MPI_INT); ChemBCast(&type, 1, MPI_INT);
} }
std::string ctr_file_out_dir;
inline int buildControlPacket(bool dht, bool interp, bool stab) {
int flags = 0;
if (dht)
flags |= DHT_ENABLE;
if (interp)
flags |= IP_ENABLE;
if (stab)
flags |= STAB_ENABLE;
return flags;
}
inline bool hasFlag(int flags, int type) { return (flags & type) != 0; } inline bool hasFlag(int flags, int type) { return (flags & type) != 0; }
double simtime = 0.; double simtime = 0.;
@ -464,7 +450,7 @@ protected:
std::vector<double> mpi_surr_buffer; std::vector<double> mpi_surr_buffer;
bool control_enabled{false}; bool ctrl_enabled{false};
bool stab_enabled{false}; bool stab_enabled{false};
}; };
} // namespace poet } // namespace poet

View File

@ -232,6 +232,17 @@ inline void printProgressbar(int count_pkgs, int n_wp, int barWidth = 70) {
/* end visual progress */ /* end visual progress */
} }
/* Builds the control bitmask from the three surrogate switches: every argument
 * that is true contributes its flag constant (DHT_ENABLE, IP_ENABLE,
 * STAB_ENABLE). If all three arguments are false the result is 0. */
inline uint32_t poet::ChemistryModule::buildCtrlFlags(bool dht, bool interp, bool stab) {
  const uint32_t dht_bit = dht ? static_cast<uint32_t>(DHT_ENABLE) : 0u;
  const uint32_t interp_bit = interp ? static_cast<uint32_t>(IP_ENABLE) : 0u;
  const uint32_t stab_bit = stab ? static_cast<uint32_t>(STAB_ENABLE) : 0u;
  return dht_bit | interp_bit | stab_bit;
}
inline void poet::ChemistryModule::MasterSendPkgs( inline void poet::ChemistryModule::MasterSendPkgs(
worker_list_t &w_list, workpointer_t &work_pointer, worker_list_t &w_list, workpointer_t &work_pointer,
workpointer_t &sur_pointer, int &pkg_to_send, int &count_pkgs, workpointer_t &sur_pointer, int &pkg_to_send, int &count_pkgs,
@ -257,7 +268,7 @@ inline void poet::ChemistryModule::MasterSendPkgs(
/* note current processed work package in workerlist */ /* note current processed work package in workerlist */
w_list[p].send_addr = work_pointer.base(); w_list[p].send_addr = work_pointer.base();
w_list[p].surrogate_addr = sur_pointer.base(); w_list[p].surrogate_addr = sur_pointer.base();
// this->control_enabled ? sur_pointer.base() : w_list[p].surrogate_addr = // this->ctrl_enabled ? sur_pointer.base() : w_list[p].surrogate_addr =
// nullptr; // nullptr;
/* push work pointer to next work package */ /* push work pointer to next work package */
@ -282,7 +293,7 @@ inline void poet::ChemistryModule::MasterSendPkgs(
// control flags (bitmask) // control flags (bitmask)
/* int flags = (this->interp_enabled ? 1 : 0) | (this->dht_enabled ? 2 : /* int flags = (this->interp_enabled ? 1 : 0) | (this->dht_enabled ? 2 :
0) | (this->warmup_enabled ? 4 : 0) | (this->control_enabled ? 8 : 0); 0) | (this->warmup_enabled ? 4 : 0) | (this->ctrl_enabled ? 8 : 0);
send_buffer[end_of_wp + 5] = static_cast<double>(flags); send_buffer[end_of_wp + 5] = static_cast<double>(flags);
*/ */
@ -449,19 +460,17 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
/* broadcast control state once every iteration */ /* broadcast control state once every iteration */
ftype = CHEM_CTRL_ENABLE; ftype = CHEM_CTRL_ENABLE;
PropagateFunctionType(ftype); PropagateFunctionType(ftype);
int ctrl = uint32_t ctrl = (ctrl_enabled = control->isCtrlIntervalActive()) ? 1 : 0;
(this->control_enabled = this->control->getControlIntervalEnabled()) ? 1 ChemBCast(&ctrl, 1, MPI_UINT32_T);
: 0;
ChemBCast(&ctrl, 1, MPI_INT);
if (control->shouldBcastFlags()) { if (control->needsFlagBcast()) {
int ftype = CHEM_CTRL_FLAGS; int ftype = CHEM_CTRL_FLAGS;
PropagateFunctionType(ftype); PropagateFunctionType(ftype);
uint32_t ctrl_flags = buildControlPacket( uint32_t ctrl_flags =
this->dht_enabled, this->interp_enabled, this->stab_enabled); buildCtrlFlags(dht_enabled, interp_enabled, stab_enabled);
ChemBCast(&ctrl_flags, 1, MPI_INT); ChemBCast(&ctrl_flags, 1, MPI_UINT32_T);
this->mpi_surr_buffer.assign(this->n_cells * this->prop_count, 0.0); mpi_surr_buffer.assign(n_cells * prop_count, 0.0);
} }
ftype = CHEM_WORK_LOOP; ftype = CHEM_WORK_LOOP;
@ -481,8 +490,8 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
wp_sizes_vector.size()); wp_sizes_vector.size());
// Only resize surrogate buffer if control is enabled // Only resize surrogate buffer if control is enabled
if (this->control_enabled) { if (ctrl_enabled) {
this->mpi_surr_buffer.resize(mpi_buffer.size()); mpi_surr_buffer.resize(mpi_buffer.size());
} }
/* setup local variables */ /* setup local variables */
@ -490,9 +499,8 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
pkg_to_recv = wp_sizes_vector.size(); pkg_to_recv = wp_sizes_vector.size();
workpointer_t work_pointer = mpi_buffer.begin(); workpointer_t work_pointer = mpi_buffer.begin();
workpointer_t sur_pointer = this->mpi_surr_buffer.begin(); workpointer_t sur_pointer = mpi_surr_buffer.begin();
//(this->control_enabled ? this->mpi_surr_buffer.begin()
// : mpi_buffer.end());
worker_list_t worker_list(this->comm_size - 1); worker_list_t worker_list(this->comm_size - 1);
free_workers = this->comm_size - 1; free_workers = this->comm_size - 1;
@ -540,14 +548,12 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
chem_field = out_vec; chem_field = out_vec;
/* do master stuff */ /* do master stuff */
if (this->control_enabled) { if (ctrl_enabled) {
std::cout << "[Master] Control logic enabled for this iteration."
<< std::endl;
std::vector<double> sur_unshuffled{mpi_surr_buffer}; std::vector<double> sur_unshuffled{mpi_surr_buffer};
shuf_a = MPI_Wtime(); shuf_a = MPI_Wtime();
unshuffleField(this->mpi_surr_buffer, this->n_cells, this->prop_count, unshuffleField(mpi_surr_buffer, n_cells, prop_count, wp_sizes_vector.size(),
wp_sizes_vector.size(), sur_unshuffled); sur_unshuffled);
shuf_b = MPI_Wtime(); shuf_b = MPI_Wtime();
this->shuf_t += shuf_b - shuf_a; this->shuf_t += shuf_b - shuf_a;
@ -558,8 +564,7 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
} }
metrics_a = MPI_Wtime(); metrics_a = MPI_Wtime();
control->computeErrorMetrics(out_vec, sur_unshuffled, this->n_cells, control->computeMetrics(out_vec, sur_unshuffled, n_cells, prop_names);
prop_names);
metrics_b = MPI_Wtime(); metrics_b = MPI_Wtime();
this->metrics_t += metrics_b - metrics_a; this->metrics_t += metrics_b - metrics_a;
@ -575,10 +580,10 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
/* end time measurement of whole chemistry simulation */ /* end time measurement of whole chemistry simulation */
std::optional<uint32_t> target = std::nullopt; std::optional<uint32_t> target = std::nullopt;
if (this->control_enabled) { if (ctrl_enabled) {
target = control->getRollbackTarget(prop_names); target = control->findRbTarget(prop_names);
} }
int flush = (this->control_enabled && target.has_value()) ? 1 : 0; int flush = (ctrl_enabled && target.has_value()) ? 1 : 0;
/* advise workers to end chemistry iteration */ /* advise workers to end chemistry iteration */
for (int i = 1; i < this->comm_size; i++) { for (int i = 1; i < this->comm_size; i++) {

View File

@ -60,17 +60,17 @@ void poet::ChemistryModule::WorkerLoop() {
break; break;
} }
case CHEM_CTRL_ENABLE: { case CHEM_CTRL_ENABLE: {
int ctrl = 0; uint32_t ctrl = 0;
ChemBCast(&ctrl, 1, MPI_INT); ChemBCast(&ctrl, 1, MPI_UINT32_T);
this->control_enabled = (ctrl == 1); ctrl_enabled = (ctrl == 1);
break; break;
} }
case CHEM_CTRL_FLAGS: { case CHEM_CTRL_FLAGS: {
int flags = 0; uint32_t flags = 0;
ChemBCast(&flags, 1, MPI_INT); ChemBCast(&flags, 1, MPI_UINT32_T);
this->dht_enabled = hasFlag(flags, DHT_ENABLE); dht_enabled = hasFlag(flags, DHT_ENABLE);
this->interp_enabled = hasFlag(flags, IP_ENABLE); interp_enabled = hasFlag(flags, IP_ENABLE);
this->stab_enabled = hasFlag(flags, STAB_ENABLE); stab_enabled = hasFlag(flags, STAB_ENABLE);
break; break;
} }
case CHEM_WORK_LOOP: { case CHEM_WORK_LOOP: {
@ -204,21 +204,10 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
// current work package start location in field // current work package start location in field
wp_start_index = mpi_buffer[count + 4]; wp_start_index = mpi_buffer[count + 4];
// read packed control flags
/* /*
flags = static_cast<int>(mpi_buffer[count + 5]); std::cout << "warmup_enabled is " << stab_enabled << ", ctrl_enabled is "
this->interp_enabled = (flags & 1) != 0; << ctrl_enabled << ", dht_enabled is " << dht_enabled
this->dht_enabled = (flags & 2) != 0;
this->warmup_enabled = (flags & 4) != 0;
this->control_enabled = (flags & 8) != 0;
*/
/*
std::cout << "warmup_enabled is " << stab_enabled << ", control_enabled is "
<< control_enabled << ", dht_enabled is " << dht_enabled
<< ", interp_enabled is " << interp_enabled << std::endl; << ", interp_enabled is " << interp_enabled << std::endl;
*/ */
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) { for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
@ -258,7 +247,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
poet::WorkPackage s_curr_wp_control = s_curr_wp; poet::WorkPackage s_curr_wp_control = s_curr_wp;
if (control_enabled) { if (ctrl_enabled) {
ctrl_cp_start = MPI_Wtime(); ctrl_cp_start = MPI_Wtime();
for (std::size_t wp_i = 0; wp_i < s_curr_wp_control.size; wp_i++) { for (std::size_t wp_i = 0; wp_i < s_curr_wp_control.size; wp_i++) {
s_curr_wp_control.output[wp_i] = s_curr_wp_control.output[wp_i] =
@ -271,12 +260,12 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
phreeqc_time_start = MPI_Wtime(); phreeqc_time_start = MPI_Wtime();
WorkerRunWorkPackage(control_enabled ? s_curr_wp_control : s_curr_wp, WorkerRunWorkPackage(ctrl_enabled ? s_curr_wp_control : s_curr_wp,
current_sim_time, dt); current_sim_time, dt);
phreeqc_time_end = MPI_Wtime(); phreeqc_time_end = MPI_Wtime();
if (control_enabled) { if (ctrl_enabled) {
ctrl_start = MPI_Wtime(); ctrl_start = MPI_Wtime();
copyCtrlPkgs(s_curr_wp_control, s_curr_wp, mpi_buffer, count); copyCtrlPkgs(s_curr_wp_control, s_curr_wp, mpi_buffer, count);
ctrl_end = MPI_Wtime(); ctrl_end = MPI_Wtime();
@ -288,14 +277,14 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
/* send results to master */ /* send results to master */
MPI_Request send_req; MPI_Request send_req;
int mpi_tag = control_enabled ? LOOP_CTRL : LOOP_WORK; int mpi_tag = ctrl_enabled ? LOOP_CTRL : LOOP_WORK;
MPI_Isend(mpi_buffer.data(), count, MPI_DOUBLE, 0, mpi_tag, MPI_COMM_WORLD, MPI_Isend(mpi_buffer.data(), count, MPI_DOUBLE, 0, mpi_tag, MPI_COMM_WORLD,
&send_req); &send_req);
if (dht_enabled || interp_enabled || stab_enabled) { if (dht_enabled || interp_enabled || stab_enabled) {
/* write results to DHT */ /* write results to DHT */
dht_fill_start = MPI_Wtime(); dht_fill_start = MPI_Wtime();
dht->fillDHT(control_enabled ? s_curr_wp_control : s_curr_wp); dht->fillDHT(ctrl_enabled ? s_curr_wp_control : s_curr_wp);
dht_fill_end = MPI_Wtime(); dht_fill_end = MPI_Wtime();
if (interp_enabled || stab_enabled) { if (interp_enabled || stab_enabled) {

View File

@ -4,60 +4,67 @@
#include "IO/StatsIO.hpp" #include "IO/StatsIO.hpp"
#include <cmath> #include <cmath>
poet::ControlModule::ControlModule(const ControlConfig &config_, poet::ControlModule::ControlModule(const ControlConfig &config_, ChemistryModule *chem_)
ChemistryModule *chem_)
: config(config_), chem(chem_) { : config(config_), chem(chem_) {
assert(chem && "ChemistryModule pointer must not be null"); assert(chem && "ChemistryModule pointer must not be null");
} }
void poet::ControlModule::beginIteration(const uint32_t &iter, void poet::ControlModule::beginIteration(const uint32_t &iter, const bool &dht_enabled,
const bool &dht_enabled,
const bool &interp_enabled) { const bool &interp_enabled) {
/* dht_enabled and interp_enabled are user settings set before starting the
* simulation*/
double prep_a, prep_b; double prep_a, prep_b;
prep_a = MPI_Wtime(); prep_a = MPI_Wtime();
if (config.control_interval == 0) { if (config.ctrl_interval == 0) {
control_interval_enabled = false; ctrl_active = false;
return; return;
} }
global_iteration = iter; global_iter = iter;
updateStabilizationPhase(dht_enabled, interp_enabled); updateSurrState(dht_enabled, interp_enabled);
control_interval_enabled = ctrl_active = (config.ctrl_interval > 0 && (iter % config.ctrl_interval == 0));
(config.control_interval > 0 && (iter % config.control_interval == 0));
prep_b = MPI_Wtime(); prep_b = MPI_Wtime();
this->prep_t += prep_b - prep_a; this->prep_t += prep_b - prep_a;
} }
void poet::ControlModule::updateStabilizationPhase(bool dht_enabled, /* manages the overall surrogate state, by enabling/disabling state based on
bool interp_enabled) { * warmup logic and rollback conditions*/
if (rollback_enabled) { void poet::ControlModule::updateSurrState(bool dht_enabled, bool interp_enabled) {
if (disable_surr_counter > 0) {
--disable_surr_counter; bool in_warmup = (global_iter <= config.ctrl_interval);
MSG("Rollback counter: " + std::to_string(disable_surr_counter)); bool rb_limit_reached = (rb_count >= config.rb_limit);
} else {
rollback_enabled = false; if (rb_enabled && stab_countdown > 0 && !rb_limit_reached) {
--stab_countdown;
std::cout << "Rollback counter: " << stab_countdown << std::endl;
if (stab_countdown == 0) {
rb_enabled = false;
} }
flush_request = false;
} }
// user requested DHT/INTEP? keep them disabled but enable warmup-phase so /* disable surrogates during warmup, active rollback or after limit */
if (global_iteration <= config.control_interval || rollback_enabled) { if (in_warmup || rb_enabled || rb_limit_reached) {
chem->SetStabEnabled(true); chem->SetStabEnabled(!rb_limit_reached);
chem->SetDhtEnabled(false); chem->SetDhtEnabled(false);
chem->SetInterpEnabled(false); chem->SetInterpEnabled(false);
if (rb_limit_reached) {
std::cout << "Interpolation completly disabled." << std::endl;
} else {
std::cout << "In stabilization phase." << std::endl;
}
return; return;
} }
/* enable user-requested surrogates */
chem->SetStabEnabled(false); chem->SetStabEnabled(false);
chem->SetDhtEnabled(dht_enabled); chem->SetDhtEnabled(dht_enabled);
chem->SetInterpEnabled(interp_enabled); chem->SetInterpEnabled(interp_enabled);
std::cout << "Interpolating." << std::endl;
} }
void poet::ControlModule::writeCheckpoint(uint32_t &iter, void poet::ControlModule::writeCheckpoint(uint32_t &iter, const std::string &out_dir) {
const std::string &out_dir) {
double w_check_a, w_check_b; double w_check_a, w_check_b;
w_check_a = MPI_Wtime(); w_check_a = MPI_Wtime();
write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5", write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5",
@ -65,25 +72,25 @@ void poet::ControlModule::writeCheckpoint(uint32_t &iter,
w_check_b = MPI_Wtime(); w_check_b = MPI_Wtime();
this->w_check_t += w_check_b - w_check_a; this->w_check_t += w_check_b - w_check_a;
last_checkpoint_written = iter; last_chkpt_written = iter;
} }
void poet::ControlModule::readCheckpoint(uint32_t &current_iter, void poet::ControlModule::readCheckpoint(uint32_t &current_iter, uint32_t rollback_iter,
uint32_t rollback_iter,
const std::string &out_dir) { const std::string &out_dir) {
double r_check_a, r_check_b; double r_check_a, r_check_b;
r_check_a = MPI_Wtime(); r_check_a = MPI_Wtime();
Checkpoint_s checkpoint_read{.field = chem->getField()}; Checkpoint_s checkpoint_read{.field = chem->getField()};
read_checkpoint(out_dir, read_checkpoint(out_dir, "checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read);
"checkpoint" + std::to_string(rollback_iter) + ".hdf5",
checkpoint_read);
current_iter = checkpoint_read.iteration; current_iter = checkpoint_read.iteration;
r_check_b = MPI_Wtime(); r_check_b = MPI_Wtime();
r_check_t += r_check_b - r_check_a; r_check_t += r_check_b - r_check_a;
} }
void poet::ControlModule::writeErrorMetrics( void poet::ControlModule::writeMetrics(const std::string &out_dir,
const std::string &out_dir, const std::vector<std::string> &species) { const std::vector<std::string> &species) {
if (rb_count > config.rb_limit) {
return;
}
double stats_a, stats_b; double stats_a, stats_b;
stats_a = MPI_Wtime(); stats_a = MPI_Wtime();
@ -93,63 +100,68 @@ void poet::ControlModule::writeErrorMetrics(
this->stats_t += stats_b - stats_a; this->stats_t += stats_b - stats_a;
} }
uint32_t poet::ControlModule::getRollbackIter() { uint32_t poet::ControlModule::calcRbIter() {
uint32_t last_iter = ((global_iteration - 1) / config.checkpoint_interval) * uint32_t last_iter = ((global_iter - 1) / config.chkpt_interval) * config.chkpt_interval;
config.checkpoint_interval;
uint32_t rollback_iter = (last_iter <= last_checkpoint_written) uint32_t rb_iter = (last_iter <= last_chkpt_written) ? last_iter : last_chkpt_written;
? last_iter return rb_iter;
: last_checkpoint_written;
return rollback_iter;
} }
std::optional<uint32_t> poet::ControlModule::getRollbackTarget( std::optional<uint32_t> poet::ControlModule::findRbTarget(const std::vector<std::string> &species) {
const std::vector<std::string> &species) {
double r_check_a, r_check_b;
if (metrics_history.empty()) { if (metrics_history.empty()) {
MSG("No error history yet, skipping rollback check."); std::cout << "No error history yet, skipping rollback check." << std::endl;
flush_request = false; flush_request = false;
return std::nullopt; return std::nullopt;
} }
if (rollback_count > 3) { if (rb_count > config.rb_limit) {
MSG("Rollback limit reached, skipping rollback."); std::cout << "Rollback limit reached, skipping control logic." << std::endl;
flush_request = false; flush_request = false;
return std::nullopt; return std::nullopt;
} }
std::cout << "findRbTarget called at iter " << global_iter << ", rb_count=" << rb_count
<< ", rb_limit=" << config.rb_limit << std::endl;
double r_check_a, r_check_b;
const auto &mape = metrics_history.back().mape; const auto &mape = metrics_history.back().mape;
for (uint32_t i = 0; i < species.size(); ++i) { for (uint32_t i = 0; i < species.size(); ++i) {
// skip Charge
if (i == 4 || mape[i] == 0) { if (mape[i] == 0) {
continue; continue;
} }
if (mape[i] > config.mape_threshold[i]) { if (mape[i] > config.mape_threshold[i]) {
if (last_checkpoint_written == 0) { std::cout << "Species " << species[i] << " MAPE=" << mape[i]
MSG(" Threshold exceeded but no checkpoint exists yet."); << " threshold=" << config.mape_threshold[i] << std::endl;
if (last_chkpt_written == 0) {
std::cout << " Threshold exceeded but no checkpoint exists yet." << std::endl;
return std::nullopt; return std::nullopt;
} }
// rb_enabled = true;
flush_request = true; flush_request = true;
std::cout << "Threshold exceeded " << species[i] << " has MAPE = " << std::to_string(mape[i])
MSG("T hreshold exceeded " + species[i] + << " exceeding threshold = " << std::to_string(config.mape_threshold[i])
" has MAPE = " + std::to_string(mape[i]) + << std::endl;
" exceeding threshold = " + std::to_string(config.mape_threshold[i])); return calcRbIter();
return getRollbackIter();
} }
} }
MSG("All species are within their MAPE thresholds."); // std::cout << "All species are within their MAPE thresholds." << std::endl;
flush_request = false; flush_request = false;
return std::nullopt; return std::nullopt;
} }
void poet::ControlModule::computeErrorMetrics( void poet::ControlModule::computeMetrics(const std::vector<double> &reference_values,
const std::vector<double> &reference_values, const std::vector<double> &surrogate_values,
const std::vector<double> &surrogate_values, const uint32_t size_per_prop, const uint32_t size_per_prop,
const std::vector<std::string> &species) { const std::vector<std::string> &species) {
SpeciesErrorMetrics metrics(species.size(), global_iteration, rollback_count); if (rb_count > config.rb_limit) {
return;
}
SpeciesMetrics metrics(species.size(), global_iter, rb_count);
for (uint32_t i = 0; i < species.size(); ++i) { for (uint32_t i = 0; i < species.size(); ++i) {
double err_sum = 0.0; double err_sum = 0.0;
@ -164,60 +176,50 @@ void poet::ControlModule::computeErrorMetrics(
if (std::isnan(ref_value) || std::isnan(sur_value)) { if (std::isnan(ref_value) || std::isnan(sur_value)) {
continue; continue;
} }
if (std::abs(ref_value) < ZERO_ABS) {
if (!std::isfinite(ref_value) || !std::isfinite(sur_value)) { if (std::abs(sur_value) >= ZERO_ABS) {
continue;
}
if (std::abs(ref_value) == ZERO_ABS) {
if (std::abs(sur_value) != ZERO_ABS) {
err_sum += 1.0; err_sum += 1.0;
sqr_err_sum += 1.0; sqr_err_sum += 1.0;
} }
} } else {
// Both zero: skip
else {
double alpha = 1.0 - (sur_value / ref_value); double alpha = 1.0 - (sur_value / ref_value);
if (!std::isfinite(alpha)) {
continue; // protects against inf/NaN due to extreme values
}
err_sum += std::abs(alpha); err_sum += std::abs(alpha);
sqr_err_sum += alpha * alpha; sqr_err_sum += alpha * alpha;
} }
} }
metrics.mape[i] = 100.0 * (err_sum / static_cast<double>(size_per_prop)); metrics.mape[i] = 100.0 * (err_sum / size_per_prop);
metrics.rrmse[i] = metrics.rrmse[i] = std::sqrt(sqr_err_sum / size_per_prop);
std::sqrt(sqr_err_sum / static_cast<double>(size_per_prop));
} }
metrics_history.push_back(metrics); metrics_history.push_back(metrics);
} }
void poet::ControlModule::processCheckpoint( void poet::ControlModule::processCheckpoint(uint32_t &current_iter, const std::string &out_dir,
uint32_t &current_iter, const std::string &out_dir, const std::vector<std::string> &species) {
const std::vector<std::string> &species) {
if (!control_interval_enabled) if (!ctrl_active || rb_count > config.rb_limit) {
return; return;
}
if (flush_request && rollback_count < 3) { if (flush_request) {
uint32_t target = getRollbackIter(); uint32_t target = calcRbIter();
readCheckpoint(current_iter, target, out_dir); readCheckpoint(current_iter, target, out_dir);
rollback_enabled = true; rb_enabled = true;
rollback_count++; rb_count++;
disable_surr_counter = config.control_interval; stab_countdown = config.ctrl_interval;
MSG("Restored checkpoint " + std::to_string(target) + std::cout << "Restored checkpoint " << std::to_string(target) << ", surrogates disabled for "
", surrogates disabled for " + std::to_string(config.control_interval)); << config.ctrl_interval << std::endl;
} else { } else {
writeCheckpoint(global_iteration, out_dir); writeCheckpoint(global_iter, out_dir);
} }
} }
bool poet::ControlModule::shouldBcastFlags() const { bool poet::ControlModule::needsFlagBcast() const {
if (global_iteration == 1 || if (rb_count > config.rb_limit) {
global_iteration % config.control_interval == 1) { return false;
}
if (global_iter == 1 || global_iter % config.ctrl_interval == 1) {
return true; return true;
} }
return false; return false;

View File

@ -14,21 +14,22 @@ namespace poet {
class ChemistryModule; class ChemistryModule;
struct ControlConfig { struct ControlConfig {
uint32_t control_interval = 0; uint32_t ctrl_interval = 0;
uint32_t checkpoint_interval = 0; uint32_t chkpt_interval = 0;
uint32_t rb_limit = 0;
double zero_abs = 0.0; double zero_abs = 0.0;
std::vector<double> mape_threshold; std::vector<double> mape_threshold;
}; };
struct SpeciesErrorMetrics { struct SpeciesMetrics {
std::vector<double> mape; std::vector<double> mape;
std::vector<double> rrmse; std::vector<double> rrmse;
uint32_t iteration = 0; uint32_t iteration = 0;
uint32_t rollback_count = 0; uint32_t rb_count = 0;
SpeciesErrorMetrics(uint32_t n_species, uint32_t iter, uint32_t rb_count) SpeciesMetrics(uint32_t n_species, uint32_t iter, uint32_t count)
: mape(n_species, 0.0), rrmse(n_species, 0.0), iteration(iter), : mape(n_species, 0.0), rrmse(n_species, 0.0), iteration(iter),
rollback_count(rb_count) {} rb_count(count) {}
}; };
class ControlModule { class ControlModule {
@ -40,12 +41,12 @@ public:
void writeCheckpoint(uint32_t &iter, const std::string &out_dir); void writeCheckpoint(uint32_t &iter, const std::string &out_dir);
void writeErrorMetrics(const std::string &out_dir, void writeMetrics(const std::string &out_dir,
const std::vector<std::string> &species); const std::vector<std::string> &species);
std::optional<uint32_t> getRollbackTarget(); std::optional<uint32_t> findRbTarget();
void computeErrorMetrics(const std::vector<double> &reference_values, void computeMetrics(const std::vector<double> &reference_values,
const std::vector<double> &surrogate_values, const std::vector<double> &surrogate_values,
const uint32_t size_per_prop, const uint32_t size_per_prop,
const std::vector<std::string> &species); const std::vector<std::string> &species);
@ -55,11 +56,11 @@ public:
const std::vector<std::string> &species); const std::vector<std::string> &species);
std::optional<uint32_t> std::optional<uint32_t>
getRollbackTarget(const std::vector<std::string> &species); findRbTarget(const std::vector<std::string> &species);
bool shouldBcastFlags() const; bool needsFlagBcast() const;
bool getControlIntervalEnabled() const { bool isCtrlIntervalActive() const {
return this->control_interval_enabled; return this->ctrl_active;
} }
bool getFlushRequest() const { return flush_request; } bool getFlushRequest() const { return flush_request; }
@ -67,34 +68,32 @@ public:
/* Profiling getters */ /* Profiling getters */
double getUpdateCtrlLogicTime() const { return prep_t; } double getCtrlLogicTime() const { return prep_t; }
double getWriteCheckpointTime() const { return w_check_t; } double getChkptWriteTime() const { return w_check_t; }
double getReadCheckpointTime() const { return r_check_t; } double getChkptReadTime() const { return r_check_t; }
double getWriteMetricsTime() const { return stats_t; } double getMetricsWriteTime() const { return stats_t; }
private: private:
void updateStabilizationPhase(bool dht_enabled, bool interp_enabled); void updateSurrState(bool dht_enabled, bool interp_enabled);
void readCheckpoint(uint32_t &current_iter, void readCheckpoint(uint32_t &current_iter,
uint32_t rollback_iter, const std::string &out_dir); uint32_t rollback_iter, const std::string &out_dir);
uint32_t getRollbackIter(); uint32_t calcRbIter();
ControlConfig config; ControlConfig config;
ChemistryModule *chem = nullptr; ChemistryModule *chem = nullptr;
std::uint32_t global_iteration = 0; std::uint32_t global_iter = 0;
std::uint32_t rollback_count = 0; std::uint32_t rb_count = 0;
std::uint32_t disable_surr_counter = 0; std::uint32_t stab_countdown = 0;
std::uint32_t last_checkpoint_written = 0; std::uint32_t last_chkpt_written = 0;
bool rollback_enabled = false; bool rb_enabled = false;
bool control_interval_enabled = false; bool ctrl_active = false;
bool flush_request = false; bool flush_request = false;
bool bcast_flags = false; std::vector<SpeciesMetrics> metrics_history;
std::vector<SpeciesErrorMetrics> metrics_history;
double prep_t = 0.; double prep_t = 0.;
double r_check_t = 0.; double r_check_t = 0.;

View File

@ -7,7 +7,7 @@
namespace poet namespace poet
{ {
void writeStatsToCSV(const std::vector<SpeciesErrorMetrics> &all_stats, void writeStatsToCSV(const std::vector<SpeciesMetrics> &all_stats,
const std::vector<std::string> &species_names, const std::vector<std::string> &species_names,
const std::string &out_dir, const std::string &out_dir,
const std::string &filename) const std::string &filename)
@ -37,7 +37,7 @@ namespace poet
{ {
out << std::left out << std::left
<< std::setw(15) << all_stats[i].iteration << std::setw(15) << all_stats[i].iteration
<< std::setw(15) << all_stats[i].rollback_count << std::setw(15) << all_stats[i].rb_count
<< std::setw(15) << species_names[j] << std::setw(15) << species_names[j]
<< std::setw(15) << all_stats[i].mape[j] << std::setw(15) << all_stats[i].mape[j]
<< std::setw(15) << all_stats[i].rrmse[j] << std::setw(15) << all_stats[i].rrmse[j]

View File

@ -1,10 +1,7 @@
#include <string>
#include "Control/ControlModule.hpp" #include "Control/ControlModule.hpp"
#include <string>
namespace poet namespace poet {
{ void writeStatsToCSV(const std::vector<SpeciesMetrics> &all_stats, const std::vector<std::string> &species_names,
void writeStatsToCSV(const std::vector<SpeciesErrorMetrics> &all_stats, const std::string &out_dir, const std::string &filename);
const std::vector<std::string> &species_names,
const std::string &out_dir,
const std::string &filename);
} // namespace poet } // namespace poet

View File

@ -249,10 +249,12 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
params.timesteps = params.timesteps =
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("timesteps")); Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("timesteps"));
params.control_interval = params.ctrl_interval =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("control_interval")); Rcpp::as<uint32_t>(global_rt_setup->operator[]("ctrl_interval"));
params.checkpoint_interval = params.chkpt_interval =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval")); Rcpp::as<uint32_t>(global_rt_setup->operator[]("chkpt_interval"));
params.rb_limit =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("rb_limit"));
params.mape_threshold = Rcpp::as<std::vector<double>>( params.mape_threshold = Rcpp::as<std::vector<double>>(
global_rt_setup->operator[]("mape_threshold")); global_rt_setup->operator[]("mape_threshold"));
params.zero_abs = Rcpp::as<double>(global_rt_setup->operator[]("zero_abs")); params.zero_abs = Rcpp::as<double>(global_rt_setup->operator[]("zero_abs"));
@ -411,9 +413,10 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" + MSG("End of *coupling* iteration " + std::to_string(iter) + "/" +
std::to_string(maxiter)); std::to_string(maxiter));
if (control.getControlIntervalEnabled()) { if (control.isCtrlIntervalActive()) {
control.processCheckpoint(iter, params.out_dir, chem.getField().GetProps()); control.processCheckpoint(iter, params.out_dir,
control.writeErrorMetrics(params.out_dir, chem.getField().GetProps()); chem.getField().GetProps());
control.writeMetrics(params.out_dir, chem.getField().GetProps());
} }
// MSG(); // MSG();
} // END SIMULATION LOOP } // END SIMULATION LOOP
@ -434,10 +437,10 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
Rcpp::List ctrl_profiling; Rcpp::List ctrl_profiling;
ctrl_profiling["compute_metrics_master"] = chem.GetMasterCtrlMetricsTime(); ctrl_profiling["compute_metrics_master"] = chem.GetMasterCtrlMetricsTime();
ctrl_profiling["unshuffle_field_master"] = chem.GetMasterUnshuffleTime(); ctrl_profiling["unshuffle_field_master"] = chem.GetMasterUnshuffleTime();
ctrl_profiling["w_checkpoint_master"] = control.getWriteCheckpointTime(); ctrl_profiling["w_checkpoint_master"] = control.getChkptWriteTime();
ctrl_profiling["r_checkpoint_master"] = control.getReadCheckpointTime(); ctrl_profiling["r_checkpoint_master"] = control.getChkptReadTime();
ctrl_profiling["write_stats"] = control.getWriteMetricsTime(); ctrl_profiling["write_stats"] = control.getMetricsWriteTime();
ctrl_profiling["ctrl_logic_master"] = control.getUpdateCtrlLogicTime(); ctrl_profiling["ctrl_logic_master"] = control.getCtrlLogicTime();
ctrl_profiling["recv_data_master"] = chem.GetMasterRecvCtrlDataTime(); ctrl_profiling["recv_data_master"] = chem.GetMasterRecvCtrlDataTime();
ctrl_profiling["worker"] = Rcpp::wrap(chem.GetWorkerControlTimings()); ctrl_profiling["worker"] = Rcpp::wrap(chem.GetWorkerControlTimings());
@ -629,6 +632,14 @@ int main(int argc, char *argv[]) {
chemistry.masterEnableSurrogates(surr_setup); chemistry.masterEnableSurrogates(surr_setup);
ControlConfig config(run_params.ctrl_interval, run_params.chkpt_interval,
run_params.rb_limit, run_params.zero_abs,
run_params.mape_threshold);
ControlModule control(config, &chemistry);
chemistry.SetControlModule(&control);
if (MY_RANK > 0) { if (MY_RANK > 0) {
chemistry.WorkerLoop(); chemistry.WorkerLoop();
} else { } else {
@ -672,14 +683,6 @@ int main(int argc, char *argv[]) {
chemistry.masterSetField(init_list.getInitialGrid()); chemistry.masterSetField(init_list.getInitialGrid());
ControlConfig config(run_params.control_interval,
run_params.checkpoint_interval, run_params.zero_abs,
run_params.mape_threshold);
ControlModule control(config, &chemistry);
chemistry.SetControlModule(&control);
Rcpp::List profiling = Rcpp::List profiling =
RunMasterLoop(R, run_params, diffusion, chemistry, control); RunMasterLoop(R, run_params, diffusion, chemistry, control);

View File

@ -51,8 +51,9 @@ struct RuntimeParameters {
bool print_progress = false; bool print_progress = false;
std::uint32_t checkpoint_interval = 0; std::uint32_t chkpt_interval = 0;
std::uint32_t control_interval = 0; std::uint32_t ctrl_interval = 0;
std::uint32_t rb_limit = 0;
std::vector<double> mape_threshold; std::vector<double> mape_threshold;
double zero_abs = 0.0; double zero_abs = 0.0;