Compare commits

...

2 Commits

Author SHA1 Message Date
rastogi
e479f666af Enabled time measurement 2025-10-30 16:38:06 +01:00
rastogi
4ac8914175 Added time measurements 2025-10-30 15:43:34 +01:00
14 changed files with 189 additions and 322 deletions

View File

@ -105,10 +105,10 @@ setup <- list(
) )
iterations <- 250 iterations <- 100
dt <- 200 dt <- 200
checkpoint_interval <- 50 checkpoint_interval <- 20
control_interval <- 50 control_interval <- 20
mape_threshold <- rep(3.5e-3, 13) mape_threshold <- rep(3.5e-3, 13)
#out_save <- seq(50, iterations, by = 50) #out_save <- seq(50, iterations, by = 50)

View File

@ -115,18 +115,18 @@ setup <- list(
Chemistry = chemistry_setup # Parameters related to the chemistry process Chemistry = chemistry_setup # Parameters related to the chemistry process
) )
iterations <- 100 iterations <- 5000
dt <- 200 dt <- 200
checkpoint_interval <- 20 checkpoint_interval <- 100
control_interval <- 20 control_interval <- 100
mape_threshold <- rep(3.5e-3, 13) mape_threshold <- rep(3.5e-3, 13)
#out_save <- seq(50, iterations, by = 50) out_save <- seq(1000, iterations, by = 1000)
list( list(
timesteps = rep(dt, iterations), timesteps = rep(dt, iterations),
store_result = FALSE, store_result = TRUE,
#out_save = out_save, out_save = out_save,
checkpoint_interval = checkpoint_interval, checkpoint_interval = checkpoint_interval,
control_interval = control_interval, control_interval = control_interval,
mape_threshold = mape_threshold mape_threshold = mape_threshold

View File

@ -1,10 +1,10 @@
#!/bin/bash #!/bin/bash
#SBATCH --job-name=dolo_warmup_debug #SBATCH --job-name=dolo_5000
#SBATCH --output=dolo_warmup_debug%j.out #SBATCH --output=dolo_5000_%j.out
#SBATCH --error=dolo_warmup_debug%j.err #SBATCH --error=dolo_5000_%j.err
#SBATCH --partition=long #SBATCH --partition=long
#SBATCH --nodes=4 #SBATCH --nodes=6
#SBATCH --ntasks=96 #SBATCH --ntasks=144
#SBATCH --ntasks-per-node=24 #SBATCH --ntasks-per-node=24
#SBATCH --exclusive #SBATCH --exclusive
#SBATCH --time=12:00:00 #SBATCH --time=12:00:00
@ -14,5 +14,5 @@ source /etc/profile.d/modules.sh
module purge module purge
module load cmake gcc openmpi module load cmake gcc openmpi
mpirun -n 96 ./poet --interp dolo_fgcs_3.R dolo_fgcs_3.qs2 dolo_warmup_debug mpirun -n 144 ./poet --interp dolo_fgcs_3.R dolo_fgcs_3.qs2 dolo_5000
#mpirun -n 96 ./poet --interp barite_fgcs_2.R barite_fgcs_2.qs2 bar_fgcs_500_eps #mpirun -n 96 ./poet --interp barite_fgcs_2.R barite_fgcs_2.qs2 bar_warmup

Binary file not shown.

View File

@ -174,6 +174,12 @@ public:
*/ */
auto GetMasterLoopTime() const { return this->send_recv_t; } auto GetMasterLoopTime() const { return this->send_recv_t; }
auto GetMasterRecvCtrlDataTime() const { return this->recv_ctrl_t; }
auto GetMasterUnshuffleTime() const { return this->shuf_t; }
auto GetMasterCtrlMetricsTime() const { return this->metrics_t; }
/** /**
* **Master only** Collect and return all accumulated timings recorded by * **Master only** Collect and return all accumulated timings recorded by
* workers to run Phreeqc simulation. * workers to run Phreeqc simulation.
@ -404,13 +410,15 @@ protected:
ChemBCast(&type, 1, MPI_INT); ChemBCast(&type, 1, MPI_INT);
} }
void PropagateControlLogic(int type, int flag);
double simtime = 0.; double simtime = 0.;
double idle_t = 0.; double idle_t = 0.;
double seq_t = 0.; double seq_t = 0.;
double send_recv_t = 0.; double send_recv_t = 0.;
double recv_ctrl_t = 0.;
double shuf_t = 0.;
double metrics_t = 0.;
std::array<double, 2> base_totals{0}; std::array<double, 2> base_totals{0};
bool print_progessbar{false}; bool print_progessbar{false};

View File

@ -232,37 +232,6 @@ inline void printProgressbar(int count_pkgs, int n_wp, int barWidth = 70) {
/* end visual progress */ /* end visual progress */
} }
void poet::ChemistryModule::PropagateControlLogic(int type, int flag) {
/*
PropagateFunctionType(type);
static int master_bcast_seq = 0;
int tmp = flag ? 1 : 0;
std::cerr << "[MASTER BCAST " << master_bcast_seq << "] ftype=" << type
<< " flag=" << tmp << std::endl
<< std::flush;
master_bcast_seq++;
ChemBCast(&tmp, 1, MPI_INT);
switch (type) {
case CHEM_CTRL_ENABLE:
this->control_enabled = (tmp == 1);
break;
case CHEM_WARMUP_PHASE:
this->warmup_enabled = (tmp == 1);
break;
case CHEM_DHT_ENABLE:
this->dht_enabled = (tmp == 1);
break;
case CHEM_IP_ENABLE:
this->interp_enabled = (tmp == 1);
break;
default:
break;
}
*/
}
inline void poet::ChemistryModule::MasterSendPkgs( inline void poet::ChemistryModule::MasterSendPkgs(
worker_list_t &w_list, workpointer_t &work_pointer, worker_list_t &w_list, workpointer_t &work_pointer,
workpointer_t &sur_pointer, int &pkg_to_send, int &count_pkgs, workpointer_t &sur_pointer, int &pkg_to_send, int &count_pkgs,
@ -335,6 +304,7 @@ inline void poet::ChemistryModule::MasterRecvPkgs(worker_list_t &w_list,
/* declare most of the variables here */ /* declare most of the variables here */
int need_to_receive = 1; int need_to_receive = 1;
double idle_a, idle_b; double idle_a, idle_b;
double recv_ctrl_a, recv_ctrl_b;
int p, size; int p, size;
std::vector<double> recv_buffer; std::vector<double> recv_buffer;
@ -373,6 +343,7 @@ inline void poet::ChemistryModule::MasterRecvPkgs(worker_list_t &w_list,
break; break;
} }
case LOOP_CTRL: { case LOOP_CTRL: {
recv_ctrl_a = MPI_Wtime();
/* layout of buffer is [phreeqc][surrogate] */ /* layout of buffer is [phreeqc][surrogate] */
MPI_Get_count(&probe_status, MPI_DOUBLE, &size); MPI_Get_count(&probe_status, MPI_DOUBLE, &size);
recv_buffer.resize(size); recv_buffer.resize(size);
@ -385,6 +356,8 @@ inline void poet::ChemistryModule::MasterRecvPkgs(worker_list_t &w_list,
std::copy(recv_buffer.begin() + (size / 2), recv_buffer.begin() + size, std::copy(recv_buffer.begin() + (size / 2), recv_buffer.begin() + size,
w_list[p - 1].surrogate_addr); w_list[p - 1].surrogate_addr);
recv_ctrl_b = MPI_Wtime();
recv_ctrl_t += recv_ctrl_b - recv_ctrl_a;
handled = true; handled = true;
break; break;
@ -450,6 +423,7 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
int free_workers; int free_workers;
int i_pkgs; int i_pkgs;
int ftype; int ftype;
double shuf_a, shuf_b, metrics_a, metrics_b;
const std::vector<uint32_t> wp_sizes_vector = const std::vector<uint32_t> wp_sizes_vector =
CalculateWPSizesVector(this->n_cells, this->wp_size); CalculateWPSizesVector(this->n_cells, this->wp_size);
@ -464,35 +438,6 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
MPI_INT); MPI_INT);
} }
// ftype = CHEM_IP_ENABLE;
// ftype = CHEM_WARMUP_PHASE;
/*
PropagateFunctionType(ftype);
int warmup_flag = this->warmup_enabled ? 1 : 0;
if (warmup_flag) {
this->interp_enabled = false;
int interp_flag = 0;
ChemBCast(&interp_flag, 1, MPI_INT);
// PropagateControlLogic(CHEM_WARMUP_PHASE, 1);
// PropagateControlLogic(CHEM_DHT_ENABLE, 0);
// PropagateControlLogic(CHEM_IP_ENABLE, 0);
} else {
this->interp_enabled = true;
int interp_flag = 1;
ChemBCast(&interp_flag, 1, MPI_INT);
// PropagateControlLogic(CHEM_WARMUP_PHASE, 0);
// PropagateControlLogic(CHEM_DHT_ENABLE, 1);
// PropagateControlLogic(CHEM_IP_ENABLE, 1);
}
int control_flag = this->control_module->GetControlIntervalEnabled() ? 1 : 0;
if (control_flag) {
PropagateControlLogic(CHEM_CTRL_ENABLE, control_flag);
}
*/
ftype = CHEM_WORK_LOOP; ftype = CHEM_WORK_LOOP;
PropagateFunctionType(ftype); PropagateFunctionType(ftype);
@ -509,9 +454,14 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count, shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count,
wp_sizes_vector.size()); wp_sizes_vector.size());
control_enabled = this->control_module->GetControlIntervalEnabled() ? 1 : 0; control_enabled = this->control_module->getControlIntervalEnabled() ? 1 : 0;
std::vector<double> mpi_surr_buffer{mpi_buffer}; std::vector<double> mpi_surr_buffer{mpi_buffer};
std::cout << "control_enabled is " << control_enabled << ", "
<< "warmup_enabled is " << warmup_enabled << ", "
<< "dht_enabled is " << dht_enabled << ", "
<< "interp_enabled is " << interp_enabled << std::endl;
/* setup local variables */ /* setup local variables */
pkg_to_send = wp_sizes_vector.size(); pkg_to_send = wp_sizes_vector.size();
pkg_to_recv = wp_sizes_vector.size(); pkg_to_recv = wp_sizes_vector.size();
@ -569,45 +519,24 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
std::cout << "[Master] Control logic enabled for this iteration." std::cout << "[Master] Control logic enabled for this iteration."
<< std::endl; << std::endl;
std::vector<double> sur_unshuffled{mpi_surr_buffer}; std::vector<double> sur_unshuffled{mpi_surr_buffer};
shuf_a = MPI_Wtime();
unshuffleField(mpi_surr_buffer, this->n_cells, this->prop_count, unshuffleField(mpi_surr_buffer, this->n_cells, this->prop_count,
wp_sizes_vector.size(), sur_unshuffled); wp_sizes_vector.size(), sur_unshuffled);
shuf_b = MPI_Wtime();
this->shuf_t += shuf_b - shuf_a;
// Quick debug: compare out_vec vs sur_unshuffled
size_t N = out_vec.size(); size_t N = out_vec.size();
if (N != sur_unshuffled.size()) { if (N != sur_unshuffled.size()) {
std::cerr << "[MASTER DBG] size mismatch out_vec=" << N std::cerr << "[MASTER DBG] size mismatch out_vec=" << N
<< " sur_unshuffled=" << sur_unshuffled.size() << std::endl; << " sur_unshuffled=" << sur_unshuffled.size() << std::endl;
} /*else {
double max_abs = 0.0;
double max_rel = 0.0;
size_t worst_i = 0;
for (size_t i = 0; i < N; i) {
double a = out_vec[i];
double b = sur_unshuffled[i];
double absd = std::fabs(a - b);
if (absd > max_abs) {
max_abs = absd;
worst_i = i;
} }
double rel = (std::fabs(a) > 1e-12) ? absd / std::fabs(a) : (absd > 0 ?
1e12 : 0.0); if (rel > max_rel) max_rel = rel;
}
std::cerr << "[MASTER DBG] control compare N=" << N
<< " max_abs=" << max_abs << " max_rel=" << max_rel
<< " worst_idx=" << worst_i
<< " out_vec[worst]=" << out_vec[worst_i]
<< " sur[worst]=" << sur_unshuffled[worst_i] << std::endl;
// optionally print first 8 entries
std::cerr << "[MASTER DBG] out[0..7]: ";
for (size_t i = 0; i < std::min<size_t>(8, N); i) std::cerr << out_vec[i]
<< " "; std::cerr << "\n[MASTER DBG] sur[0..7]: "; for (size_t i = 0; i <
std::min<size_t>(8, N); +i) std::cerr << sur_unshuffled[i] << " "; std::cerr
<< std::endl;
}
*/
control_module->ComputeSpeciesErrorMetrics(out_vec, sur_unshuffled, metrics_a = MPI_Wtime();
control_module->computeSpeciesErrorMetrics(out_vec, sur_unshuffled,
this->n_cells); this->n_cells);
metrics_b = MPI_Wtime();
this->metrics_t += metrics_b - metrics_a;
} }
/* start time measurement of master chemistry */ /* start time measurement of master chemistry */

View File

@ -155,6 +155,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
double dht_get_start, dht_get_end; double dht_get_start, dht_get_end;
double phreeqc_time_start, phreeqc_time_end; double phreeqc_time_start, phreeqc_time_end;
double dht_fill_start, dht_fill_end; double dht_fill_start, dht_fill_end;
double ctrl_cp_start, ctrl_cp_end, ctrl_start, ctrl_end;
uint32_t iteration; uint32_t iteration;
double dt; double dt;
@ -239,11 +240,14 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
poet::WorkPackage s_curr_wp_control = s_curr_wp; poet::WorkPackage s_curr_wp_control = s_curr_wp;
if (control_enabled) { if (control_enabled) {
ctrl_cp_start = MPI_Wtime();
for (std::size_t wp_i = 0; wp_i < s_curr_wp_control.size; wp_i++) { for (std::size_t wp_i = 0; wp_i < s_curr_wp_control.size; wp_i++) {
s_curr_wp_control.output[wp_i] = s_curr_wp_control.output[wp_i] =
std::vector<double>(this->prop_count, 0.0); std::vector<double>(this->prop_count, 0.0);
s_curr_wp_control.mapping[wp_i] = CHEM_PQC; s_curr_wp_control.mapping[wp_i] = CHEM_PQC;
} }
ctrl_cp_end = MPI_Wtime();
timings.ctrl_t += ctrl_cp_end - ctrl_cp_start;
} }
phreeqc_time_start = MPI_Wtime(); phreeqc_time_start = MPI_Wtime();
@ -254,7 +258,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
phreeqc_time_end = MPI_Wtime(); phreeqc_time_end = MPI_Wtime();
if (control_enabled) { if (control_enabled) {
ctrl_start = MPI_Wtime();
std::size_t sur_wp_offset = s_curr_wp.size * this->prop_count; std::size_t sur_wp_offset = s_curr_wp.size * this->prop_count;
mpi_buffer.resize(count + sur_wp_offset); mpi_buffer.resize(count + sur_wp_offset);
@ -281,7 +285,8 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
} }
} }
count += sur_wp_offset; count += sur_wp_offset;
ctrl_end = MPI_Wtime();
timings.ctrl_t += ctrl_end - ctrl_start;
} else { } else {
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) { for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(), std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(),
@ -302,18 +307,8 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
dht->fillDHT(control_enabled ? s_curr_wp_control : s_curr_wp); dht->fillDHT(control_enabled ? s_curr_wp_control : s_curr_wp);
dht_fill_end = MPI_Wtime(); dht_fill_end = MPI_Wtime();
int filled_count = std::count(dht->getDHTResults().filledDHT.begin(),
dht->getDHTResults().filledDHT.end(), true);
std::cout << "[Worker " << std::to_string(this->comm_rank)
<< "] DHT filled entries=" << std::to_string(filled_count)
<< std::endl;
if (interp_enabled || warmup_enabled) { if (interp_enabled || warmup_enabled) {
interp->writePairs(); interp->writePairs();
std::cout << "[Worker " << std::to_string(this->comm_rank) << "] "
<< "Writing pairs to PHT after iteration "
<< std::to_string(iteration) << std::endl;
} }
timings.dht_fill += dht_fill_end - dht_fill_start; timings.dht_fill += dht_fill_end - dht_fill_start;
} }

View File

@ -4,31 +4,21 @@
#include "IO/StatsIO.hpp" #include "IO/StatsIO.hpp"
#include <cmath> #include <cmath>
void poet::ControlModule::UpdateControlIteration(const uint32_t &iter, void poet::ControlModule::updateControlIteration(const uint32_t &iter,
const bool &dht_enabled, const bool &dht_enabled,
const bool &interp_enabled) { const bool &interp_enabled) {
/* dht_enabled and inter_enabled are user settings set before startig the /* dht_enabled and inter_enabled are user settings set before startig the
* simulation*/ * simulation*/
double prep_a, prep_b;
prep_a = MPI_Wtime();
if (control_interval == 0) { if (control_interval == 0) {
control_interval_enabled = false; control_interval_enabled = false;
return; return;
} }
// InitiateWarmupPhase(dht_enabled, interp_enabled);
global_iteration = iter; global_iteration = iter;
initiateWarmupPhase(dht_enabled, interp_enabled);
if (global_iteration <= control_interval) {
chem->SetWarmupEnabled(true);
chem->SetDhtEnabled(false);
chem->SetInterpEnabled(false);
MSG("Warmup enabled until first control interval at iteration " +
std::to_string(control_interval) + ".");
} else {
chem->SetWarmupEnabled(false);
chem->SetDhtEnabled(true);
chem->SetInterpEnabled(true);
}
control_interval_enabled = control_interval_enabled =
(control_interval > 0 && iter % control_interval == 0); (control_interval > 0 && iter % control_interval == 0);
@ -37,92 +27,82 @@ void poet::ControlModule::UpdateControlIteration(const uint32_t &iter,
MSG("[Control] Control interval enabled at iteration " + MSG("[Control] Control interval enabled at iteration " +
std::to_string(iter)); std::to_string(iter));
} }
prep_b = MPI_Wtime();
this->prep_t += prep_b - prep_a;
} }
void poet::ControlModule::InitiateWarmupPhase(bool dht_enabled, void poet::ControlModule::initiateWarmupPhase(bool dht_enabled,
bool interp_enabled) { bool interp_enabled) {
// user requested DHT/INTEP? keep them disabled but enable warmup-phase so // user requested DHT/INTEP? keep them disabled but enable warmup-phase so
// workers do prepareKeys/fillDHT/writePairs as required. if (global_iteration <= control_interval || rollback_enabled) {
if (global_iteration < control_interval) {
/* warmup phase: keep dht and interp disabled,
workers do prepareKeys/fillDHT/writePairs*/
chem->SetWarmupEnabled(true); chem->SetWarmupEnabled(true);
// chem->SetDhtEnabled(false); chem->SetDhtEnabled(false);
// chem->SetInterpEnabled(false); chem->SetInterpEnabled(false);
MSG("Warmup enabled until first control interval at iteration " + MSG("Warmup enabled until next control interval at iteration " +
std::to_string(control_interval) + "."); std::to_string(control_interval) + ".");
} else {
/* after warmup phase: restore according to user's request*/
chem->SetWarmupEnabled(false);
// chem->SetDhtEnabled(dht_enabled);
// chem->SetInterpEnabled(interp_enabled);
}
}
/*
void poet::ControlModule::beginIteration() {
if (rollback_enabled) { if (rollback_enabled) {
if (sur_disabled_counter > 0) { if (sur_disabled_counter > 0) {
sur_disabled_counter--; --sur_disabled_counter;
MSG("Rollback counter: " + std::to_string(sur_disabled_counter)); MSG("Rollback counter: " + std::to_string(sur_disabled_counter));
} else { } else {
rollback_enabled = false; rollback_enabled = false;
} }
} }
return;
} }
*/
void poet::ControlModule::EndIteration(const uint32_t iter) { chem->SetWarmupEnabled(false);
chem->SetDhtEnabled(dht_enabled);
chem->SetInterpEnabled(interp_enabled);
}
void poet::ControlModule::applyControlLogic(ChemistryModule &chem,
uint32_t &iter) {
if (!control_interval_enabled) { if (!control_interval_enabled) {
return; return;
} }
/* Writing a checkpointing */ writeCheckpointAndMetrics(chem, iter);
/* Control Logic*/
if (!chem) {
MSG("chem pointer is null — skipping checkpoint/stats write");
} else {
MSG("Writing checkpoint of iteration " + std::to_string(iter));
write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5",
{.field = chem->getField(), .iteration = iter});
writeStatsToCSV(error_history, species_names, out_dir, "stats_overview");
// if() if (checkAndRollback(chem, iter) && rollback_count < 4) {
/*
if (triggerRollbackIfExceeded(*chem, *params, iter)) {
rollback_enabled = true; rollback_enabled = true;
rollback_counter++; rollback_count++;
sur_disabled_counter = control_interval; sur_disabled_counter = control_interval;
MSG("Interpolation disabled for the next " + MSG("Interpolation disabled for the next " +
std::to_string(control_interval) + "."); std::to_string(control_interval) + ".");
} }
*/
}
} }
/* void poet::ControlModule::writeCheckpointAndMetrics(ChemistryModule &chem,
void poet::ControlModule::BCastControlFlags() { uint32_t iter) {
int interp_flag = rollback_enabled ? 0 : 1;
int dht_fill_flag = rollback_enabled ? 1 : 0; double w_check_a, w_check_b, stats_a, stats_b;
chem->ChemBCast(&interp_flag, 1, MPI_INT); MSG("Writing checkpoint of iteration " + std::to_string(iter));
chem->ChemBCast(&dht_fill_flag, 1, MPI_INT);
w_check_a = MPI_Wtime();
write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5",
{.field = chem.getField(), .iteration = iter});
w_check_b = MPI_Wtime();
this->w_check_t += w_check_b - w_check_a;
stats_a = MPI_Wtime();
writeStatsToCSV(metricsHistory, species_names, out_dir, "stats_overview");
stats_b = MPI_Wtime();
this->stats_t += stats_b - stats_a;
} }
*/ bool poet::ControlModule::checkAndRollback(ChemistryModule &chem,
uint32_t &iter) {
double r_check_a, r_check_b;
if (metricsHistory.empty()) {
bool poet::ControlModule::RollbackIfThresholdExceeded(ChemistryModule &chem) {
/**
if (error_history.empty()) {
MSG("No error history yet; skipping rollback check."); MSG("No error history yet; skipping rollback check.");
return false; return false;
} }
const auto &mape = error_history.back().mape; const auto &mape = metricsHistory.back().mape;
for (uint32_t i = 0; i < species_names.size(); ++i) { for (uint32_t i = 0; i < species_names.size(); ++i) {
if (mape[i] == 0) { if (mape[i] == 0) {
@ -130,34 +110,36 @@ bool poet::ControlModule::RollbackIfThresholdExceeded(ChemistryModule &chem) {
} }
if (mape[i] > mape_threshold[i]) { if (mape[i] > mape_threshold[i]) {
uint32_t rollback_iter = ((global_iteration - 1) / checkpoint_interval) * uint32_t rollback_iter =
checkpoint_interval; ((iter - 1) / checkpoint_interval) * checkpoint_interval;
MSG("[THRESHOLD EXCEEDED] " + species_names[i] + MSG("[THRESHOLD EXCEEDED] " + species_names[i] +
" has MAPE = " + std::to_string(mape[i]) + " has MAPE = " + std::to_string(mape[i]) +
" exceeding threshold = " + std::to_string(mape_threshold[i]) " exceeding threshold = " + std::to_string(mape_threshold[i]) +
+ " → rolling back to iteration " + std::to_string(rollback_iter)); " → rolling back to iteration " + std::to_string(rollback_iter));
r_check_a = MPI_Wtime();
Checkpoint_s checkpoint_read{.field = chem.getField()}; Checkpoint_s checkpoint_read{.field = chem.getField()};
read_checkpoint(out_dir, read_checkpoint(out_dir,
"checkpoint" + std::to_string(rollback_iter) + ".hdf5", "checkpoint" + std::to_string(rollback_iter) + ".hdf5",
checkpoint_read); checkpoint_read);
global_iteration = checkpoint_read.iteration; iter = checkpoint_read.iteration;
r_check_b = MPI_Wtime();
r_check_t += r_check_b - r_check_a;
return true; return true;
} }
} }
MSG("All species are within their MAPE thresholds."); MSG("All species are within their MAPE thresholds.");
return false; return false;
*/
} }
void poet::ControlModule::ComputeSpeciesErrorMetrics( void poet::ControlModule::computeSpeciesErrorMetrics(
const std::vector<double> &reference_values, const std::vector<double> &reference_values,
const std::vector<double> &surrogate_values, const uint32_t size_per_prop) { const std::vector<double> &surrogate_values, const uint32_t size_per_prop) {
SimulationErrorStats species_error_stats(this->species_names.size(), SpeciesErrorMetrics metrics(this->species_names.size(), global_iteration,
global_iteration, rollback_count);
/*rollback_counter*/ 0);
if (reference_values.size() != surrogate_values.size()) { if (reference_values.size() != surrogate_values.size()) {
MSG(" Reference and surrogate vectors differ in size: " + MSG(" Reference and surrogate vectors differ in size: " +
@ -179,8 +161,8 @@ void poet::ControlModule::ComputeSpeciesErrorMetrics(
double err_sum = 0.0; double err_sum = 0.0;
double sqr_err_sum = 0.0; double sqr_err_sum = 0.0;
uint32_t base_idx = i * size_per_prop; uint32_t base_idx = i * size_per_prop;
uint32_t nan_count = 0;
uint32_t valid_count = 0; int count = 0;
for (uint32_t j = 0; j < size_per_prop; ++j) { for (uint32_t j = 0; j < size_per_prop; ++j) {
const double ref_value = reference_values[base_idx + j]; const double ref_value = reference_values[base_idx + j];
@ -188,10 +170,8 @@ void poet::ControlModule::ComputeSpeciesErrorMetrics(
const double ZERO_ABS = 1e-13; const double ZERO_ABS = 1e-13;
if (std::isnan(ref_value) || std::isnan(sur_value)) { if (std::isnan(ref_value) || std::isnan(sur_value)) {
nan_count++;
continue; continue;
} }
valid_count++;
if (std::abs(ref_value) < ZERO_ABS) { if (std::abs(ref_value) < ZERO_ABS) {
if (std::abs(sur_value) >= ZERO_ABS) { if (std::abs(sur_value) >= ZERO_ABS) {
@ -206,15 +186,8 @@ void poet::ControlModule::ComputeSpeciesErrorMetrics(
sqr_err_sum += alpha * alpha; sqr_err_sum += alpha * alpha;
} }
} }
if (valid_count > 0) { metrics.mape[i] = 100.0 * (err_sum / size_per_prop);
species_error_stats.mape[i] = 100.0 * (err_sum / valid_count); metrics.rrmse[i] = std::sqrt(sqr_err_sum / size_per_prop);
species_error_stats.rrmse[i] = std::sqrt(sqr_err_sum / valid_count);
} else {
species_error_stats.mape[i] = 0.0;
species_error_stats.rrmse[i] = 0.0;
std::cerr << "[CTRL WARN] no valid samples for species " << i << " ("
<< this->species_names[i] << "), setting errors to 0\n";
} }
} metricsHistory.push_back(metrics);
error_history.push_back(species_error_stats);
} }

View File

@ -22,36 +22,29 @@ public:
// std::uint32_t sur_disabled_counter = 0; // std::uint32_t sur_disabled_counter = 0;
// std::uint32_t rollback_counter = 0; // std::uint32_t rollback_counter = 0;
void UpdateControlIteration(const uint32_t &iter, const bool &dht_enabled, void updateControlIteration(const uint32_t &iter, const bool &dht_enabled,
const bool &interp_enaled); const bool &interp_enaled);
void InitiateWarmupPhase(bool dht_enabled, bool interp_enabled); void initiateWarmupPhase(bool dht_enabled, bool interp_enabled);
auto GetGlobalIteration() const noexcept { return global_iteration; } bool checkAndRollback(ChemistryModule &chem, uint32_t &iter);
// void beginIteration(); struct SpeciesErrorMetrics {
// void BCastControlFlags();
bool RollbackIfThresholdExceeded(ChemistryModule &chem);
struct SimulationErrorStats {
std::vector<double> mape; std::vector<double> mape;
std::vector<double> rrmse; std::vector<double> rrmse;
uint32_t iteration; // iterations in simulation after rollbacks uint32_t iteration; // iterations in simulation after rollbacks
uint32_t rollback_count; uint32_t rollback_count;
SimulationErrorStats(uint32_t species_count, uint32_t iter, SpeciesErrorMetrics(uint32_t species_count, uint32_t iter, uint32_t counter)
uint32_t counter)
: mape(species_count, 0.0), rrmse(species_count, 0.0), iteration(iter), : mape(species_count, 0.0), rrmse(species_count, 0.0), iteration(iter),
rollback_count(counter) {} rollback_count(counter) {}
}; };
void ComputeSpeciesErrorMetrics(const std::vector<double> &reference_values, void computeSpeciesErrorMetrics(const std::vector<double> &reference_values,
const std::vector<double> &surrogate_values, const std::vector<double> &surrogate_values,
const uint32_t size_per_prop); const uint32_t size_per_prop);
std::vector<SimulationErrorStats> error_history; std::vector<SpeciesErrorMetrics> metricsHistory;
struct ControlSetup { struct ControlSetup {
std::string out_dir; std::string out_dir;
@ -61,7 +54,7 @@ public:
std::vector<double> mape_threshold; std::vector<double> mape_threshold;
}; };
void EnableControlLogic(const ControlSetup &setup) { void enableControlLogic(const ControlSetup &setup) {
this->out_dir = setup.out_dir; this->out_dir = setup.out_dir;
this->checkpoint_interval = setup.checkpoint_interval; this->checkpoint_interval = setup.checkpoint_interval;
this->control_interval = setup.control_interval; this->control_interval = setup.control_interval;
@ -69,24 +62,31 @@ public:
this->mape_threshold = setup.mape_threshold; this->mape_threshold = setup.mape_threshold;
} }
bool GetControlIntervalEnabled() const { bool getControlIntervalEnabled() const {
return this->control_interval_enabled; return this->control_interval_enabled;
} }
void EndIteration(const uint32_t iter); void applyControlLogic(ChemistryModule &chem, uint32_t &iter);
void SetChemistryModule(poet::ChemistryModule *c) { chem = c; } void writeCheckpointAndMetrics(ChemistryModule &chem, uint32_t iter);
auto GetControlInterval() const { return this->control_interval; } auto getGlobalIteration() const noexcept { return global_iteration; }
std::vector<double> GetMapeThreshold() const { return this->mape_threshold; } void setChemistryModule(poet::ChemistryModule *c) { chem = c; }
auto getControlInterval() const { return this->control_interval; }
std::vector<double> getMapeThreshold() const { return this->mape_threshold; }
/* Profiling getters */ /* Profiling getters */
auto GetMasterCtrlLogicTime() const { return this->ctrl_time; }
auto GetMasterCtrlBcastTime() const { return this->bcast_ctrl_time; } auto getUpdateCtrlLogicTime() const { return this->prep_t; }
auto GetMasterRecvCtrlLogicTime() const { return this->recv_ctrl_time; } auto getWriteCheckpointTime() const { return this->w_check_t; }
auto getReadCheckpointTime() const { return this->r_check_t; }
auto getWriteMetricsTime() const { return this->stats_t; }
private: private:
bool rollback_enabled = false; bool rollback_enabled = false;
@ -97,14 +97,17 @@ private:
std::uint32_t checkpoint_interval = 0; std::uint32_t checkpoint_interval = 0;
std::uint32_t control_interval = 0; std::uint32_t control_interval = 0;
std::uint32_t global_iteration = 0; std::uint32_t global_iteration = 0;
std::uint32_t rollback_count = 0;
std::uint32_t sur_disabled_counter = 0;
std::vector<double> mape_threshold; std::vector<double> mape_threshold;
std::vector<std::string> species_names; std::vector<std::string> species_names;
std::string out_dir; std::string out_dir;
double ctrl_time = 0.0; double prep_t = 0.;
double bcast_ctrl_time = 0.0; double r_check_t = 0.;
double recv_ctrl_time = 0.0; double w_check_t = 0.;
double stats_t = 0.;
/* Buffer for shuffled surrogate data */ /* Buffer for shuffled surrogate data */
std::vector<double> sur_shuffled; std::vector<double> sur_shuffled;

View File

@ -7,7 +7,7 @@
namespace poet namespace poet
{ {
void writeStatsToCSV(const std::vector<ControlModule::SimulationErrorStats> &all_stats, void writeStatsToCSV(const std::vector<ControlModule::SpeciesErrorMetrics> &all_stats,
const std::vector<std::string> &species_names, const std::vector<std::string> &species_names,
const std::string &out_dir, const std::string &out_dir,
const std::string &filename) const std::string &filename)
@ -47,7 +47,7 @@ namespace poet
} }
out.close(); out.close();
std::cout << "Stats written to " << filename << "\n"; std::cout << "Error metrics written to " << out_dir << "/" << filename << "\n";
} }
} }
// namespace poet // namespace poet

View File

@ -3,7 +3,7 @@
namespace poet namespace poet
{ {
void writeStatsToCSV(const std::vector<ControlModule::SimulationErrorStats> &all_stats, void writeStatsToCSV(const std::vector<ControlModule::SpeciesErrorMetrics> &all_stats,
const std::vector<std::string> &species_names, const std::vector<std::string> &species_names,
const std::string &out_dir, const std::string &out_dir,
const std::string &filename); const std::string &filename);

View File

@ -300,23 +300,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
double dSimTime{0}; double dSimTime{0};
for (uint32_t iter = 1; iter < maxiter + 1; iter++) { for (uint32_t iter = 1; iter < maxiter + 1; iter++) {
// Rollback countdowm control.updateControlIteration(iter, params.use_dht, params.use_interp);
/*
if (params.rollback_enabled) {
if (params.sur_disabled_counter > 0) {
--params.sur_disabled_counter;
MSG("Rollback counter: " + std::to_string(params.sur_disabled_counter));
} else {
params.rollback_enabled = false;
}
}
*/
//control.beginIteration(iter);
// params.global_iter = iter;
control.UpdateControlIteration(iter, params.use_dht, params.use_interp);
// params.control_interval_enabled = (iter % params.control_interval == 0);
double start_t = MPI_Wtime(); double start_t = MPI_Wtime();
@ -333,8 +317,6 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
/* run transport */ /* run transport */
diffusion.simulate(dt); diffusion.simulate(dt);
// chem.runtime_params = &params;
chem.getField().update(diffusion.getField()); chem.getField().update(diffusion.getField());
// MSG("Chemistry start"); // MSG("Chemistry start");
@ -428,33 +410,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" + MSG("End of *coupling* iteration " + std::to_string(iter) + "/" +
std::to_string(maxiter)); std::to_string(maxiter));
control.EndIteration(iter); control.applyControlLogic(chem, iter);
/*
if (iter % params.checkpoint_interval == 0) {
MSG("Writing checkpoint of iteration " + std::to_string(iter));
write_checkpoint(params.out_dir,
"checkpoint" + std::to_string(iter) + ".hdf5",
{.field = chem.getField(), .iteration = iter});
}
if (params.control_interval_enabled && !params.rollback_enabled) {
writeStatsToCSV(chem.error_history, chem.getField().GetProps(),
params.out_dir, "stats_overview");
if (triggerRollbackIfExceeded(chem, params, iter)) {
params.rollback_enabled = true;
params.rollback_counter++;
params.sur_disabled_counter = params.control_interval;
MSG("Interpolation disabled for the next " +
std::to_string(params.control_interval) + ".");
}
}
*/
// MSG(); // MSG();
} // END SIMULATION LOOP } // END SIMULATION LOOP
@ -471,14 +427,17 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
Rcpp::List diffusion_profiling; Rcpp::List diffusion_profiling;
diffusion_profiling["simtime"] = diffusion.getTransportTime(); diffusion_profiling["simtime"] = diffusion.getTransportTime();
/*Rcpp::List ctrl_profiling; Rcpp::List ctrl_profiling;
ctrl_profiling["checkpointing_time"] = chkTime; ctrl_profiling["compute_metrics_master"] = chem.GetMasterCtrlMetricsTime();
ctrl_profiling["ctrl_logic_master"] = chem.GetMasterCtrlLogicTime(); ctrl_profiling["unshuffle_field_master"] = chem.GetMasterUnshuffleTime();
ctrl_profiling["bcast_ctrl_logic_master"] = chem.GetMasterCtrlBcastTime(); ctrl_profiling["w_checkpoint_master"] = control.getWriteCheckpointTime();
ctrl_profiling["recv_ctrl_logic_maser"] = chem.GetMasterRecvCtrlLogicTime(); ctrl_profiling["r_checkpoint_master"] = control.getReadCheckpointTime();
ctrl_profiling["ctrl_logic_worker"] = ctrl_profiling["write_stats"] = control.getWriteMetricsTime();
ctrl_profiling["ctrl_logic_master"] = control.getUpdateCtrlLogicTime();
ctrl_profiling["recv_data_master"] = chem.GetMasterRecvCtrlDataTime();
ctrl_profiling["worker"] =
Rcpp::wrap(chem.GetWorkerControlTimings()); Rcpp::wrap(chem.GetWorkerControlTimings());
*/
if (params.use_dht) { if (params.use_dht) {
chem_profiling["dht_hits"] = Rcpp::wrap(chem.GetWorkerDHTHits()); chem_profiling["dht_hits"] = Rcpp::wrap(chem.GetWorkerDHTHits());
@ -506,7 +465,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
profiling["simtime"] = dSimTime; profiling["simtime"] = dSimTime;
profiling["chemistry"] = chem_profiling; profiling["chemistry"] = chem_profiling;
profiling["diffusion"] = diffusion_profiling; profiling["diffusion"] = diffusion_profiling;
//profiling["ctrl_logic"] = ctrl_profiling; profiling["control_loop"] = ctrl_profiling;
chem.MasterLoopBreak(); chem.MasterLoopBreak();
@ -651,7 +610,7 @@ int main(int argc, char *argv[]) {
ControlModule control; ControlModule control;
chemistry.SetControlModule(&control); chemistry.SetControlModule(&control);
control.SetChemistryModule(&chemistry); control.setChemistryModule(&chemistry);
const ChemistryModule::SurrogateSetup surr_setup = { const ChemistryModule::SurrogateSetup surr_setup = {
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
@ -676,7 +635,7 @@ int main(int argc, char *argv[]) {
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
run_params.mape_threshold}; run_params.mape_threshold};
control.EnableControlLogic(ctrl_setup); control.enableControlLogic(ctrl_setup);
if (MY_RANK > 0) { if (MY_RANK > 0) {
chemistry.WorkerLoop(); chemistry.WorkerLoop();

View File

@ -3,9 +3,9 @@
#include <doctest/doctest.h> #include <doctest/doctest.h>
#include <vector> #include <vector>
#include <Chemistry/ChemistryModule.hpp> #include <Control/ControlModule.hpp>
TEST_CASE("Stats calculation") TEST_CASE("Metrics calculation")
{ {
std::vector<double> real = std::vector<double> real =
{ {
@ -27,15 +27,15 @@ TEST_CASE("Stats calculation")
2.8, 0.02, 0.7, 0.5 2.8, 0.02, 0.7, 0.5
}; };
poet::ChemistryModule::error_stats stats(6, 5); poet::ControlModule::metricsHistory metrics(6, 5);
poet::ChemistryModule::computeStats(real, pred, /*size_per_prop*/ 4, /*species_count*/ 6, stats); poet::ControlModule::computeSpeciesErrorMetrics(real, pred, /*size_per_prop*/ 4);
SUBCASE("Non-zero values") SUBCASE("Non-zero values")
{ {
// species 1 is ID, should stay 0 // species 1 is ID, should stay 0
CHECK_EQ(stats.mape[0], 0); CHECK_EQ(metrics.mape[0], 0);
CHECK_EQ(stats.rrsme[0], 0); CHECK_EQ(metrics.rrsme[0], 0);
/* /*
mape species 2 mape species 2
@ -49,8 +49,8 @@ TEST_CASE("Stats calculation")
rrsme = sqrt(1.02040816/4) = 0.50507627 rrsme = sqrt(1.02040816/4) = 0.50507627
*/ */
CHECK_EQ(stats.mape[1], doctest::Approx(28.5714286).epsilon(1e-6)); CHECK_EQ(metrics.mape[1], doctest::Approx(28.5714286).epsilon(1e-6));
CHECK_EQ(stats.rrsme[1], doctest::Approx(0.50507627).epsilon(1e-6)); CHECK_EQ(metrics.rrsme[1], doctest::Approx(0.50507627).epsilon(1e-6));
} }
SUBCASE("Zero-denominator case") SUBCASE("Zero-denominator case")
@ -65,8 +65,8 @@ TEST_CASE("Stats calculation")
rrsme = 1 rrsme = 1
*/ */
CHECK_EQ(stats.mape[2], 100.0); CHECK_EQ(metrics.mape[2], 100.0);
CHECK_EQ(stats.rrsme[2], 1.0); CHECK_EQ(metrics.rrsme[2], 1.0);
} }
SUBCASE("True and predicted values are zero") SUBCASE("True and predicted values are zero")
@ -81,8 +81,8 @@ TEST_CASE("Stats calculation")
rrsme = 0.0 rrsme = 0.0
*/ */
CHECK_EQ(stats.mape[3], 0.0); CHECK_EQ(metrics.mape[3], 0.0);
CHECK_EQ(stats.rrsme[3], 0.0); CHECK_EQ(metrics.rrsme[3], 0.0);
} }
SUBCASE("Negative values") SUBCASE("Negative values")
@ -97,8 +97,8 @@ TEST_CASE("Stats calculation")
rrsme = sqrt(13.6989796 / 4) = 1.85060663 rrsme = sqrt(13.6989796 / 4) = 1.85060663
*/ */
CHECK_EQ(stats.mape[4], doctest::Approx(183.92857143).epsilon(1e-6)); CHECK_EQ(metrics.mape[4], doctest::Approx(183.92857143).epsilon(1e-6));
CHECK_EQ(stats.rrsme[4], doctest::Approx(1.85060663).epsilon(1e-6)); CHECK_EQ(metrics.rrsme[4], doctest::Approx(1.85060663).epsilon(1e-6));
} }
SUBCASE("Large differences") SUBCASE("Large differences")
@ -113,7 +113,7 @@ TEST_CASE("Stats calculation")
rrsme = sqrt(2.12262382 / 4) = 0.72846136 rrsme = sqrt(2.12262382 / 4) = 0.72846136
*/ */
CHECK_EQ(stats.mape[5], doctest::Approx(62.102492).epsilon(1e-6)); CHECK_EQ(metrics.mape[5], doctest::Approx(62.102492).epsilon(1e-6));
CHECK_EQ(stats.rrsme[5], doctest::Approx(0.72846136).epsilon(1e-6)); CHECK_EQ(metrics.rrsme[5], doctest::Approx(0.72846136).epsilon(1e-6));
} }
} }