added rb_limit, corrected updateSurrState logic, added early return after rb_limit is reached

This commit is contained in:
rastogi 2025-11-28 12:58:26 +01:00
parent 9087393f61
commit 97076cb7cd
12 changed files with 225 additions and 240 deletions

View File

@ -1,21 +1,24 @@
iterations <- 15000
iterations <- 10000
dt <- 200
checkpoint_interval <- 100
control_interval <- 100
chkpt_interval <- 100
ctrl_interval <- 100
mape_threshold <- rep(0.0035, 13)
zero_abs <- 0.0
#mape_threshold[5] <- 1 #Charge
mape_threshold[5] <- 1 #Charge
zero_abs <- 1e-13
rb_limit <- 3
#ctrl_cell_ids <- seq(0, (400*400)/2 - 1, by = 401)
#out_save <- seq(500, iterations, by = 500)
#out_save = c(seq(1, 10), seq(10, 100, by= 10), seq(200, iterations, by=100))
out_save = c(seq(1, 10), seq(10, 100, by= 10), seq(200, iterations, by=100))
list(
timesteps = rep(dt, iterations),
store_result = FALSE,
#out_save = out_save,
checkpoint_interval = checkpoint_interval,
control_interval = control_interval,
store_result = TRUE,
out_save = out_save,
chkpt_interval = chkpt_interval,
ctrl_interval = ctrl_interval,
mape_threshold = mape_threshold,
zero_abs = zero_abs
zero_abs = zero_abs,
rb_limit = rb_limit
)

View File

@ -58,7 +58,7 @@ all_data <- lapply(args, function(stats_file) {
})
combined_data <- bind_rows(all_data) %>%
filter(Iteration <= 3000) %>%
filter(Iteration >= 3000 & Iteration <= 8000) %>%
filter(is.finite(MedianMAPE) & MedianMAPE > 0) %>%
filter(is.finite(MaxMAPE) & MaxMAPE > 0)

View File

@ -1,7 +1,7 @@
#!/bin/bash
#SBATCH --job-name=proto1_only_interp_zeroabs
#SBATCH --output=proto1_only_interp_zeroabs_%j.out
#SBATCH --error=proto1_only_interp_zeroabs_%j.err
#SBATCH --job-name=p1_eps0035_v2
#SBATCH --output=p1_eps0035_v2_%j.out
#SBATCH --error=p1_eps0035_v2_%j.err
#SBATCH --partition=long
#SBATCH --nodes=6
#SBATCH --ntasks-per-node=24
@ -15,5 +15,5 @@ module purge
module load cmake gcc openmpi
#mpirun -n 144 ./poet dolo_fgcs_3.R dolo_fgcs_3.qs2 dolo_only_pqc
mpirun -n 144 ./poet --interp dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 proto1_only_interp_zeroabs
mpirun -n 144 ./poet --interp dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 p1_eps0035_v2
#mpirun -n 144 ./poet --interp barite_fgcs_4_new/barite_fgcs_4_new_rt.R barite_fgcs_4_new/barite_fgcs_4_new.qs2 barite

View File

@ -102,8 +102,6 @@ public:
this->base_totals = setup.base_totals;
this->ctr_file_out_dir = setup.dht_out_dir;
if (this->dht_enabled || this->interp_enabled) {
this->initializeDHT(setup.dht_size_mb, this->params.dht_species,
setup.has_het_ids);
@ -269,6 +267,8 @@ public:
void SetStabEnabled(bool enabled) { stab_enabled = enabled; }
inline uint32_t buildCtrlFlags(bool dht, bool interp, bool stab);
protected:
void initializeDHT(uint32_t size_mb,
const NamedVector<std::uint32_t> &key_species,
@ -417,20 +417,6 @@ protected:
ChemBCast(&type, 1, MPI_INT);
}
std::string ctr_file_out_dir;
inline int buildControlPacket(bool dht, bool interp, bool stab) {
int flags = 0;
if (dht)
flags |= DHT_ENABLE;
if (interp)
flags |= IP_ENABLE;
if (stab)
flags |= STAB_ENABLE;
return flags;
}
inline bool hasFlag(int flags, int type) { return (flags & type) != 0; }
double simtime = 0.;
@ -464,7 +450,7 @@ protected:
std::vector<double> mpi_surr_buffer;
bool control_enabled{false};
bool ctrl_enabled{false};
bool stab_enabled{false};
};
} // namespace poet

View File

@ -232,6 +232,17 @@ inline void printProgressbar(int count_pkgs, int n_wp, int barWidth = 70) {
/* end visual progress */
}
/**
 * Packs the three surrogate switches into one bitmask.
 *
 * @param dht    when true, the DHT_ENABLE bit is set
 * @param interp when true, the IP_ENABLE bit is set
 * @param stab   when true, the STAB_ENABLE bit is set
 * @return bitwise OR of the bits whose switch is true; 0 if none is
 */
inline uint32_t poet::ChemistryModule::buildCtrlFlags(bool dht, bool interp, bool stab) {
  const uint32_t dht_bit = dht ? static_cast<uint32_t>(DHT_ENABLE) : 0u;
  const uint32_t interp_bit = interp ? static_cast<uint32_t>(IP_ENABLE) : 0u;
  const uint32_t stab_bit = stab ? static_cast<uint32_t>(STAB_ENABLE) : 0u;
  return dht_bit | interp_bit | stab_bit;
}
inline void poet::ChemistryModule::MasterSendPkgs(
worker_list_t &w_list, workpointer_t &work_pointer,
workpointer_t &sur_pointer, int &pkg_to_send, int &count_pkgs,
@ -257,7 +268,7 @@ inline void poet::ChemistryModule::MasterSendPkgs(
/* note current processed work package in workerlist */
w_list[p].send_addr = work_pointer.base();
w_list[p].surrogate_addr = sur_pointer.base();
// this->control_enabled ? sur_pointer.base() : w_list[p].surrogate_addr =
// this->ctrl_enabled ? sur_pointer.base() : w_list[p].surrogate_addr =
// nullptr;
/* push work pointer to next work package */
@ -282,7 +293,7 @@ inline void poet::ChemistryModule::MasterSendPkgs(
// control flags (bitmask)
/* int flags = (this->interp_enabled ? 1 : 0) | (this->dht_enabled ? 2 :
0) | (this->warmup_enabled ? 4 : 0) | (this->control_enabled ? 8 : 0);
0) | (this->warmup_enabled ? 4 : 0) | (this->ctrl_enabled ? 8 : 0);
send_buffer[end_of_wp + 5] = static_cast<double>(flags);
*/
@ -449,19 +460,17 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
/* broadcast control state once every iteration */
ftype = CHEM_CTRL_ENABLE;
PropagateFunctionType(ftype);
int ctrl =
(this->control_enabled = this->control->getControlIntervalEnabled()) ? 1
: 0;
ChemBCast(&ctrl, 1, MPI_INT);
uint32_t ctrl = (ctrl_enabled = control->isCtrlIntervalActive()) ? 1 : 0;
ChemBCast(&ctrl, 1, MPI_UINT32_T);
if (control->shouldBcastFlags()) {
if (control->needsFlagBcast()) {
int ftype = CHEM_CTRL_FLAGS;
PropagateFunctionType(ftype);
uint32_t ctrl_flags = buildControlPacket(
this->dht_enabled, this->interp_enabled, this->stab_enabled);
ChemBCast(&ctrl_flags, 1, MPI_INT);
uint32_t ctrl_flags =
buildCtrlFlags(dht_enabled, interp_enabled, stab_enabled);
ChemBCast(&ctrl_flags, 1, MPI_UINT32_T);
this->mpi_surr_buffer.assign(this->n_cells * this->prop_count, 0.0);
mpi_surr_buffer.assign(n_cells * prop_count, 0.0);
}
ftype = CHEM_WORK_LOOP;
@ -481,8 +490,8 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
wp_sizes_vector.size());
// Only resize surrogate buffer if control is enabled
if (this->control_enabled) {
this->mpi_surr_buffer.resize(mpi_buffer.size());
if (ctrl_enabled) {
mpi_surr_buffer.resize(mpi_buffer.size());
}
/* setup local variables */
@ -490,9 +499,8 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
pkg_to_recv = wp_sizes_vector.size();
workpointer_t work_pointer = mpi_buffer.begin();
workpointer_t sur_pointer = this->mpi_surr_buffer.begin();
//(this->control_enabled ? this->mpi_surr_buffer.begin()
// : mpi_buffer.end());
workpointer_t sur_pointer = mpi_surr_buffer.begin();
worker_list_t worker_list(this->comm_size - 1);
free_workers = this->comm_size - 1;
@ -540,14 +548,12 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
chem_field = out_vec;
/* do master stuff */
if (this->control_enabled) {
std::cout << "[Master] Control logic enabled for this iteration."
<< std::endl;
if (ctrl_enabled) {
std::vector<double> sur_unshuffled{mpi_surr_buffer};
shuf_a = MPI_Wtime();
unshuffleField(this->mpi_surr_buffer, this->n_cells, this->prop_count,
wp_sizes_vector.size(), sur_unshuffled);
unshuffleField(mpi_surr_buffer, n_cells, prop_count, wp_sizes_vector.size(),
sur_unshuffled);
shuf_b = MPI_Wtime();
this->shuf_t += shuf_b - shuf_a;
@ -558,8 +564,7 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
}
metrics_a = MPI_Wtime();
control->computeErrorMetrics(out_vec, sur_unshuffled, this->n_cells,
prop_names);
control->computeMetrics(out_vec, sur_unshuffled, n_cells, prop_names);
metrics_b = MPI_Wtime();
this->metrics_t += metrics_b - metrics_a;
@ -575,10 +580,10 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
/* end time measurement of whole chemistry simulation */
std::optional<uint32_t> target = std::nullopt;
if (this->control_enabled) {
target = control->getRollbackTarget(prop_names);
if (ctrl_enabled) {
target = control->findRbTarget(prop_names);
}
int flush = (this->control_enabled && target.has_value()) ? 1 : 0;
int flush = (ctrl_enabled && target.has_value()) ? 1 : 0;
/* advise workers to end chemistry iteration */
for (int i = 1; i < this->comm_size; i++) {

View File

@ -60,17 +60,17 @@ void poet::ChemistryModule::WorkerLoop() {
break;
}
case CHEM_CTRL_ENABLE: {
int ctrl = 0;
ChemBCast(&ctrl, 1, MPI_INT);
this->control_enabled = (ctrl == 1);
uint32_t ctrl = 0;
ChemBCast(&ctrl, 1, MPI_UINT32_T);
ctrl_enabled = (ctrl == 1);
break;
}
case CHEM_CTRL_FLAGS: {
int flags = 0;
ChemBCast(&flags, 1, MPI_INT);
this->dht_enabled = hasFlag(flags, DHT_ENABLE);
this->interp_enabled = hasFlag(flags, IP_ENABLE);
this->stab_enabled = hasFlag(flags, STAB_ENABLE);
uint32_t flags = 0;
ChemBCast(&flags, 1, MPI_UINT32_T);
dht_enabled = hasFlag(flags, DHT_ENABLE);
interp_enabled = hasFlag(flags, IP_ENABLE);
stab_enabled = hasFlag(flags, STAB_ENABLE);
break;
}
case CHEM_WORK_LOOP: {
@ -204,21 +204,10 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
// current work package start location in field
wp_start_index = mpi_buffer[count + 4];
// read packed control flags
/*
flags = static_cast<int>(mpi_buffer[count + 5]);
this->interp_enabled = (flags & 1) != 0;
this->dht_enabled = (flags & 2) != 0;
this->warmup_enabled = (flags & 4) != 0;
this->control_enabled = (flags & 8) != 0;
*/
/*
std::cout << "warmup_enabled is " << stab_enabled << ", control_enabled is "
<< control_enabled << ", dht_enabled is " << dht_enabled
std::cout << "warmup_enabled is " << stab_enabled << ", ctrl_enabled is "
<< ctrl_enabled << ", dht_enabled is " << dht_enabled
<< ", interp_enabled is " << interp_enabled << std::endl;
*/
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
@ -258,7 +247,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
poet::WorkPackage s_curr_wp_control = s_curr_wp;
if (control_enabled) {
if (ctrl_enabled) {
ctrl_cp_start = MPI_Wtime();
for (std::size_t wp_i = 0; wp_i < s_curr_wp_control.size; wp_i++) {
s_curr_wp_control.output[wp_i] =
@ -271,12 +260,12 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
phreeqc_time_start = MPI_Wtime();
WorkerRunWorkPackage(control_enabled ? s_curr_wp_control : s_curr_wp,
WorkerRunWorkPackage(ctrl_enabled ? s_curr_wp_control : s_curr_wp,
current_sim_time, dt);
phreeqc_time_end = MPI_Wtime();
if (control_enabled) {
if (ctrl_enabled) {
ctrl_start = MPI_Wtime();
copyCtrlPkgs(s_curr_wp_control, s_curr_wp, mpi_buffer, count);
ctrl_end = MPI_Wtime();
@ -288,14 +277,14 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
/* send results to master */
MPI_Request send_req;
int mpi_tag = control_enabled ? LOOP_CTRL : LOOP_WORK;
int mpi_tag = ctrl_enabled ? LOOP_CTRL : LOOP_WORK;
MPI_Isend(mpi_buffer.data(), count, MPI_DOUBLE, 0, mpi_tag, MPI_COMM_WORLD,
&send_req);
if (dht_enabled || interp_enabled || stab_enabled) {
/* write results to DHT */
dht_fill_start = MPI_Wtime();
dht->fillDHT(control_enabled ? s_curr_wp_control : s_curr_wp);
dht->fillDHT(ctrl_enabled ? s_curr_wp_control : s_curr_wp);
dht_fill_end = MPI_Wtime();
if (interp_enabled || stab_enabled) {

View File

@ -4,60 +4,67 @@
#include "IO/StatsIO.hpp"
#include <cmath>
poet::ControlModule::ControlModule(const ControlConfig &config_,
ChemistryModule *chem_)
poet::ControlModule::ControlModule(const ControlConfig &config_, ChemistryModule *chem_)
: config(config_), chem(chem_) {
assert(chem && "ChemistryModule pointer must not be null");
}
void poet::ControlModule::beginIteration(const uint32_t &iter,
const bool &dht_enabled,
void poet::ControlModule::beginIteration(const uint32_t &iter, const bool &dht_enabled,
const bool &interp_enabled) {
/* dht_enabled and interp_enabled are user settings set before starting the
* simulation */
double prep_a, prep_b;
prep_a = MPI_Wtime();
if (config.control_interval == 0) {
control_interval_enabled = false;
if (config.ctrl_interval == 0) {
ctrl_active = false;
return;
}
global_iteration = iter;
global_iter = iter;
updateStabilizationPhase(dht_enabled, interp_enabled);
updateSurrState(dht_enabled, interp_enabled);
control_interval_enabled =
(config.control_interval > 0 && (iter % config.control_interval == 0));
ctrl_active = (config.ctrl_interval > 0 && (iter % config.ctrl_interval == 0));
prep_b = MPI_Wtime();
this->prep_t += prep_b - prep_a;
}
void poet::ControlModule::updateStabilizationPhase(bool dht_enabled,
bool interp_enabled) {
if (rollback_enabled) {
if (disable_surr_counter > 0) {
--disable_surr_counter;
MSG("Rollback counter: " + std::to_string(disable_surr_counter));
} else {
rollback_enabled = false;
/* Manages the overall surrogate state by enabling or disabling the surrogates
* based on the warmup logic and the rollback conditions. */
void poet::ControlModule::updateSurrState(bool dht_enabled, bool interp_enabled) {
bool in_warmup = (global_iter <= config.ctrl_interval);
bool rb_limit_reached = (rb_count >= config.rb_limit);
if (rb_enabled && stab_countdown > 0 && !rb_limit_reached) {
--stab_countdown;
std::cout << "Rollback counter: " << stab_countdown << std::endl;
if (stab_countdown == 0) {
rb_enabled = false;
}
flush_request = false;
}
// user requested DHT/INTEP? keep them disabled but enable warmup-phase so
if (global_iteration <= config.control_interval || rollback_enabled) {
chem->SetStabEnabled(true);
/* disable surrogates during warmup, active rollback or after limit */
if (in_warmup || rb_enabled || rb_limit_reached) {
chem->SetStabEnabled(!rb_limit_reached);
chem->SetDhtEnabled(false);
chem->SetInterpEnabled(false);
if (rb_limit_reached) {
std::cout << "Interpolation completly disabled." << std::endl;
} else {
std::cout << "In stabilization phase." << std::endl;
}
return;
}
/* enable user-requested surrogates */
chem->SetStabEnabled(false);
chem->SetDhtEnabled(dht_enabled);
chem->SetInterpEnabled(interp_enabled);
std::cout << "Interpolating." << std::endl;
}
void poet::ControlModule::writeCheckpoint(uint32_t &iter,
const std::string &out_dir) {
void poet::ControlModule::writeCheckpoint(uint32_t &iter, const std::string &out_dir) {
double w_check_a, w_check_b;
w_check_a = MPI_Wtime();
write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5",
@ -65,25 +72,25 @@ void poet::ControlModule::writeCheckpoint(uint32_t &iter,
w_check_b = MPI_Wtime();
this->w_check_t += w_check_b - w_check_a;
last_checkpoint_written = iter;
last_chkpt_written = iter;
}
void poet::ControlModule::readCheckpoint(uint32_t &current_iter,
uint32_t rollback_iter,
void poet::ControlModule::readCheckpoint(uint32_t &current_iter, uint32_t rollback_iter,
const std::string &out_dir) {
double r_check_a, r_check_b;
r_check_a = MPI_Wtime();
Checkpoint_s checkpoint_read{.field = chem->getField()};
read_checkpoint(out_dir,
"checkpoint" + std::to_string(rollback_iter) + ".hdf5",
checkpoint_read);
read_checkpoint(out_dir, "checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read);
current_iter = checkpoint_read.iteration;
r_check_b = MPI_Wtime();
r_check_t += r_check_b - r_check_a;
}
void poet::ControlModule::writeErrorMetrics(
const std::string &out_dir, const std::vector<std::string> &species) {
void poet::ControlModule::writeMetrics(const std::string &out_dir,
const std::vector<std::string> &species) {
if (rb_count > config.rb_limit) {
return;
}
double stats_a, stats_b;
stats_a = MPI_Wtime();
@ -93,63 +100,68 @@ void poet::ControlModule::writeErrorMetrics(
this->stats_t += stats_b - stats_a;
}
uint32_t poet::ControlModule::getRollbackIter() {
uint32_t poet::ControlModule::calcRbIter() {
uint32_t last_iter = ((global_iteration - 1) / config.checkpoint_interval) *
config.checkpoint_interval;
uint32_t last_iter = ((global_iter - 1) / config.chkpt_interval) * config.chkpt_interval;
uint32_t rollback_iter = (last_iter <= last_checkpoint_written)
? last_iter
: last_checkpoint_written;
return rollback_iter;
uint32_t rb_iter = (last_iter <= last_chkpt_written) ? last_iter : last_chkpt_written;
return rb_iter;
}
std::optional<uint32_t> poet::ControlModule::getRollbackTarget(
const std::vector<std::string> &species) {
double r_check_a, r_check_b;
std::optional<uint32_t> poet::ControlModule::findRbTarget(const std::vector<std::string> &species) {
if (metrics_history.empty()) {
MSG("No error history yet, skipping rollback check.");
std::cout << "No error history yet, skipping rollback check." << std::endl;
flush_request = false;
return std::nullopt;
}
if (rollback_count > 3) {
MSG("Rollback limit reached, skipping rollback.");
if (rb_count > config.rb_limit) {
std::cout << "Rollback limit reached, skipping control logic." << std::endl;
flush_request = false;
return std::nullopt;
}
std::cout << "findRbTarget called at iter " << global_iter << ", rb_count=" << rb_count
<< ", rb_limit=" << config.rb_limit << std::endl;
double r_check_a, r_check_b;
const auto &mape = metrics_history.back().mape;
for (uint32_t i = 0; i < species.size(); ++i) {
// skip Charge
if (i == 4 || mape[i] == 0) {
if (mape[i] == 0) {
continue;
}
if (mape[i] > config.mape_threshold[i]) {
if (last_checkpoint_written == 0) {
MSG(" Threshold exceeded but no checkpoint exists yet.");
std::cout << "Species " << species[i] << " MAPE=" << mape[i]
<< " threshold=" << config.mape_threshold[i] << std::endl;
if (last_chkpt_written == 0) {
std::cout << " Threshold exceeded but no checkpoint exists yet." << std::endl;
return std::nullopt;
}
// rb_enabled = true;
flush_request = true;
MSG("T hreshold exceeded " + species[i] +
" has MAPE = " + std::to_string(mape[i]) +
" exceeding threshold = " + std::to_string(config.mape_threshold[i]));
return getRollbackIter();
std::cout << "Threshold exceeded " << species[i] << " has MAPE = " << std::to_string(mape[i])
<< " exceeding threshold = " << std::to_string(config.mape_threshold[i])
<< std::endl;
return calcRbIter();
}
}
MSG("All species are within their MAPE thresholds.");
// std::cout << "All species are within their MAPE thresholds." << std::endl;
flush_request = false;
return std::nullopt;
}
void poet::ControlModule::computeErrorMetrics(
const std::vector<double> &reference_values,
const std::vector<double> &surrogate_values, const uint32_t size_per_prop,
void poet::ControlModule::computeMetrics(const std::vector<double> &reference_values,
const std::vector<double> &surrogate_values,
const uint32_t size_per_prop,
const std::vector<std::string> &species) {
SpeciesErrorMetrics metrics(species.size(), global_iteration, rollback_count);
if (rb_count > config.rb_limit) {
return;
}
SpeciesMetrics metrics(species.size(), global_iter, rb_count);
for (uint32_t i = 0; i < species.size(); ++i) {
double err_sum = 0.0;
@ -164,60 +176,50 @@ void poet::ControlModule::computeErrorMetrics(
if (std::isnan(ref_value) || std::isnan(sur_value)) {
continue;
}
if (!std::isfinite(ref_value) || !std::isfinite(sur_value)) {
continue;
}
if (std::abs(ref_value) == ZERO_ABS) {
if (std::abs(sur_value) != ZERO_ABS) {
if (std::abs(ref_value) < ZERO_ABS) {
if (std::abs(sur_value) >= ZERO_ABS) {
err_sum += 1.0;
sqr_err_sum += 1.0;
}
}
// Both zero: skip
else {
} else {
double alpha = 1.0 - (sur_value / ref_value);
if (!std::isfinite(alpha)) {
continue; // protects against inf/NaN due to extreme values
}
err_sum += std::abs(alpha);
sqr_err_sum += alpha * alpha;
}
}
metrics.mape[i] = 100.0 * (err_sum / static_cast<double>(size_per_prop));
metrics.rrmse[i] =
std::sqrt(sqr_err_sum / static_cast<double>(size_per_prop));
metrics.mape[i] = 100.0 * (err_sum / size_per_prop);
metrics.rrmse[i] = std::sqrt(sqr_err_sum / size_per_prop);
}
metrics_history.push_back(metrics);
}
void poet::ControlModule::processCheckpoint(
uint32_t &current_iter, const std::string &out_dir,
void poet::ControlModule::processCheckpoint(uint32_t &current_iter, const std::string &out_dir,
const std::vector<std::string> &species) {
if (!control_interval_enabled)
if (!ctrl_active || rb_count > config.rb_limit) {
return;
}
if (flush_request && rollback_count < 3) {
uint32_t target = getRollbackIter();
if (flush_request) {
uint32_t target = calcRbIter();
readCheckpoint(current_iter, target, out_dir);
rollback_enabled = true;
rollback_count++;
disable_surr_counter = config.control_interval;
rb_enabled = true;
rb_count++;
stab_countdown = config.ctrl_interval;
MSG("Restored checkpoint " + std::to_string(target) +
", surrogates disabled for " + std::to_string(config.control_interval));
std::cout << "Restored checkpoint " << std::to_string(target) << ", surrogates disabled for "
<< config.ctrl_interval << std::endl;
} else {
writeCheckpoint(global_iteration, out_dir);
writeCheckpoint(global_iter, out_dir);
}
}
bool poet::ControlModule::shouldBcastFlags() const {
if (global_iteration == 1 ||
global_iteration % config.control_interval == 1) {
bool poet::ControlModule::needsFlagBcast() const {
if (rb_count > config.rb_limit) {
return false;
}
if (global_iter == 1 || global_iter % config.ctrl_interval == 1) {
return true;
}
return false;

View File

@ -14,21 +14,22 @@ namespace poet {
class ChemistryModule;
struct ControlConfig {
uint32_t control_interval = 0;
uint32_t checkpoint_interval = 0;
uint32_t ctrl_interval = 0;
uint32_t chkpt_interval = 0;
uint32_t rb_limit = 0;
double zero_abs = 0.0;
std::vector<double> mape_threshold;
};
struct SpeciesErrorMetrics {
struct SpeciesMetrics {
std::vector<double> mape;
std::vector<double> rrmse;
uint32_t iteration = 0;
uint32_t rollback_count = 0;
uint32_t rb_count = 0;
SpeciesErrorMetrics(uint32_t n_species, uint32_t iter, uint32_t rb_count)
SpeciesMetrics(uint32_t n_species, uint32_t iter, uint32_t count)
: mape(n_species, 0.0), rrmse(n_species, 0.0), iteration(iter),
rollback_count(rb_count) {}
rb_count(count) {}
};
class ControlModule {
@ -40,12 +41,12 @@ public:
void writeCheckpoint(uint32_t &iter, const std::string &out_dir);
void writeErrorMetrics(const std::string &out_dir,
void writeMetrics(const std::string &out_dir,
const std::vector<std::string> &species);
std::optional<uint32_t> getRollbackTarget();
std::optional<uint32_t> findRbTarget();
void computeErrorMetrics(const std::vector<double> &reference_values,
void computeMetrics(const std::vector<double> &reference_values,
const std::vector<double> &surrogate_values,
const uint32_t size_per_prop,
const std::vector<std::string> &species);
@ -55,11 +56,11 @@ public:
const std::vector<std::string> &species);
std::optional<uint32_t>
getRollbackTarget(const std::vector<std::string> &species);
findRbTarget(const std::vector<std::string> &species);
bool shouldBcastFlags() const;
bool getControlIntervalEnabled() const {
return this->control_interval_enabled;
bool needsFlagBcast() const;
bool isCtrlIntervalActive() const {
return this->ctrl_active;
}
bool getFlushRequest() const { return flush_request; }
@ -67,34 +68,32 @@ public:
/* Profiling getters */
double getUpdateCtrlLogicTime() const { return prep_t; }
double getWriteCheckpointTime() const { return w_check_t; }
double getReadCheckpointTime() const { return r_check_t; }
double getWriteMetricsTime() const { return stats_t; }
double getCtrlLogicTime() const { return prep_t; }
double getChkptWriteTime() const { return w_check_t; }
double getChkptReadTime() const { return r_check_t; }
double getMetricsWriteTime() const { return stats_t; }
private:
void updateStabilizationPhase(bool dht_enabled, bool interp_enabled);
void updateSurrState(bool dht_enabled, bool interp_enabled);
void readCheckpoint(uint32_t &current_iter,
uint32_t rollback_iter, const std::string &out_dir);
uint32_t getRollbackIter();
uint32_t calcRbIter();
ControlConfig config;
ChemistryModule *chem = nullptr;
std::uint32_t global_iteration = 0;
std::uint32_t rollback_count = 0;
std::uint32_t disable_surr_counter = 0;
std::uint32_t last_checkpoint_written = 0;
std::uint32_t global_iter = 0;
std::uint32_t rb_count = 0;
std::uint32_t stab_countdown = 0;
std::uint32_t last_chkpt_written = 0;
bool rollback_enabled = false;
bool control_interval_enabled = false;
bool rb_enabled = false;
bool ctrl_active = false;
bool flush_request = false;
bool bcast_flags = false;
std::vector<SpeciesErrorMetrics> metrics_history;
std::vector<SpeciesMetrics> metrics_history;
double prep_t = 0.;
double r_check_t = 0.;

View File

@ -7,7 +7,7 @@
namespace poet
{
void writeStatsToCSV(const std::vector<SpeciesErrorMetrics> &all_stats,
void writeStatsToCSV(const std::vector<SpeciesMetrics> &all_stats,
const std::vector<std::string> &species_names,
const std::string &out_dir,
const std::string &filename)
@ -37,7 +37,7 @@ namespace poet
{
out << std::left
<< std::setw(15) << all_stats[i].iteration
<< std::setw(15) << all_stats[i].rollback_count
<< std::setw(15) << all_stats[i].rb_count
<< std::setw(15) << species_names[j]
<< std::setw(15) << all_stats[i].mape[j]
<< std::setw(15) << all_stats[i].rrmse[j]

View File

@ -1,10 +1,7 @@
#include <string>
#include "Control/ControlModule.hpp"
#include <string>
namespace poet
{
void writeStatsToCSV(const std::vector<SpeciesErrorMetrics> &all_stats,
const std::vector<std::string> &species_names,
const std::string &out_dir,
const std::string &filename);
namespace poet {
void writeStatsToCSV(const std::vector<SpeciesMetrics> &all_stats, const std::vector<std::string> &species_names,
const std::string &out_dir, const std::string &filename);
} // namespace poet

View File

@ -249,10 +249,12 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
params.timesteps =
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("timesteps"));
params.control_interval =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("control_interval"));
params.checkpoint_interval =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval"));
params.ctrl_interval =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("ctrl_interval"));
params.chkpt_interval =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("chkpt_interval"));
params.rb_limit =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("rb_limit"));
params.mape_threshold = Rcpp::as<std::vector<double>>(
global_rt_setup->operator[]("mape_threshold"));
params.zero_abs = Rcpp::as<double>(global_rt_setup->operator[]("zero_abs"));
@ -411,9 +413,10 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" +
std::to_string(maxiter));
if (control.getControlIntervalEnabled()) {
control.processCheckpoint(iter, params.out_dir, chem.getField().GetProps());
control.writeErrorMetrics(params.out_dir, chem.getField().GetProps());
if (control.isCtrlIntervalActive()) {
control.processCheckpoint(iter, params.out_dir,
chem.getField().GetProps());
control.writeMetrics(params.out_dir, chem.getField().GetProps());
}
// MSG();
} // END SIMULATION LOOP
@ -434,10 +437,10 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
Rcpp::List ctrl_profiling;
ctrl_profiling["compute_metrics_master"] = chem.GetMasterCtrlMetricsTime();
ctrl_profiling["unshuffle_field_master"] = chem.GetMasterUnshuffleTime();
ctrl_profiling["w_checkpoint_master"] = control.getWriteCheckpointTime();
ctrl_profiling["r_checkpoint_master"] = control.getReadCheckpointTime();
ctrl_profiling["write_stats"] = control.getWriteMetricsTime();
ctrl_profiling["ctrl_logic_master"] = control.getUpdateCtrlLogicTime();
ctrl_profiling["w_checkpoint_master"] = control.getChkptWriteTime();
ctrl_profiling["r_checkpoint_master"] = control.getChkptReadTime();
ctrl_profiling["write_stats"] = control.getMetricsWriteTime();
ctrl_profiling["ctrl_logic_master"] = control.getCtrlLogicTime();
ctrl_profiling["recv_data_master"] = chem.GetMasterRecvCtrlDataTime();
ctrl_profiling["worker"] = Rcpp::wrap(chem.GetWorkerControlTimings());
@ -629,6 +632,14 @@ int main(int argc, char *argv[]) {
chemistry.masterEnableSurrogates(surr_setup);
ControlConfig config(run_params.ctrl_interval, run_params.chkpt_interval,
run_params.rb_limit, run_params.zero_abs,
run_params.mape_threshold);
ControlModule control(config, &chemistry);
chemistry.SetControlModule(&control);
if (MY_RANK > 0) {
chemistry.WorkerLoop();
} else {
@ -672,14 +683,6 @@ int main(int argc, char *argv[]) {
chemistry.masterSetField(init_list.getInitialGrid());
ControlConfig config(run_params.control_interval,
run_params.checkpoint_interval, run_params.zero_abs,
run_params.mape_threshold);
ControlModule control(config, &chemistry);
chemistry.SetControlModule(&control);
Rcpp::List profiling =
RunMasterLoop(R, run_params, diffusion, chemistry, control);

View File

@ -51,8 +51,9 @@ struct RuntimeParameters {
bool print_progress = false;
std::uint32_t checkpoint_interval = 0;
std::uint32_t control_interval = 0;
std::uint32_t chkpt_interval = 0;
std::uint32_t ctrl_interval = 0;
std::uint32_t rb_limit = 0;
std::vector<double> mape_threshold;
double zero_abs = 0.0;