added rollback_limit and fixed rollback logic

This commit is contained in:
rastogi 2025-11-28 14:18:06 +01:00
parent ac6bff6b97
commit 1de30ad0db
4 changed files with 94 additions and 68 deletions

View File

@ -1,5 +1,5 @@
#!/bin/bash #!/bin/bash
#SBATCH --job-name=proto2_eps0035_no_rb_v2 #SBATCH --job-name=proto2_eps0035
#SBATCH --output=proto2_eps0035_no_rb_v2_%j.out #SBATCH --output=proto2_eps0035_no_rb_v2_%j.out
#SBATCH --error=proto2_eps0035_no_rb_v2%j.err #SBATCH --error=proto2_eps0035_no_rb_v2%j.err
#SBATCH --partition=long #SBATCH --partition=long
@ -15,5 +15,5 @@ module purge
module load cmake gcc openmpi module load cmake gcc openmpi
#mpirun -n 144 ./poet dolo_fgcs_3.R dolo_fgcs_3.qs2 dolo_only_pqc #mpirun -n 144 ./poet dolo_fgcs_3.R dolo_fgcs_3.qs2 dolo_only_pqc
mpirun -n 144 ./poet --interp --rds dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 proto2_eps0035_no_rb_v2 mpirun -n 144 ./poet --interp --rds dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 proto2_eps0035
#mpirun -n 144 ./poet --interp barite_fgcs_4_new/barite_fgcs_4_new_rt.R barite_fgcs_4_new/barite_fgcs_4_new.qs2 barite #mpirun -n 144 ./poet --interp barite_fgcs_4_new/barite_fgcs_4_new_rt.R barite_fgcs_4_new/barite_fgcs_4_new.qs2 barite

Binary file not shown.

View File

@ -3,13 +3,12 @@
#include "IO/HDF5Functions.hpp" #include "IO/HDF5Functions.hpp"
#include "IO/StatsIO.hpp" #include "IO/StatsIO.hpp"
#include <cmath> #include <cmath>
#include <numeric>
poet::ControlModule::ControlModule(const ControlConfig &config_) poet::ControlModule::ControlModule(const ControlConfig &config_) : config(config_) {}
: config(config_) {}
void poet::ControlModule::beginIteration(ChemistryModule &chem, uint32_t &iter, void poet::ControlModule::beginIteration(ChemistryModule &chem, uint32_t &iter,
const bool &dht_enabled, const bool &dht_enabled, const bool &interp_enabled) {
const bool &interp_enabled) {
global_iteration = iter; global_iteration = iter;
double prep_a, prep_b; double prep_a, prep_b;
@ -22,35 +21,42 @@ void poet::ControlModule::beginIteration(ChemistryModule &chem, uint32_t &iter,
} }
/* Disables dht and/or interp during stabilization phase */ /* Disables dht and/or interp during stabilization phase */
void poet::ControlModule::updateStabilizationPhase(ChemistryModule &chem, void poet::ControlModule::updateStabilizationPhase(ChemistryModule &chem, bool dht_enabled,
bool dht_enabled,
bool interp_enabled) { bool interp_enabled) {
if (rollback_enabled) { bool in_warmup = (global_iteration <= config.stab_interval);
if (disable_surr_counter > 0) { bool rb_limit_reached = (rollback_count >= 3);
--disable_surr_counter;
flush_request = false; if (rollback_enabled && disable_surr_counter > 0) {
std::cout << "Rollback counter: " << std::to_string(disable_surr_counter) --disable_surr_counter;
<< std::endl; std::cout << "Rollback counter: " << std::to_string(disable_surr_counter) << std::endl;
} else { if (disable_surr_counter == 0) {
rollback_enabled = false; rollback_enabled = false;
std::cout << "Rollback stabilization complete, re-enabling surrogate."
<< std::endl;
} }
flush_request = false;
} }
if (global_iteration <= config.stab_interval || rollback_enabled) { /* disable surrogates during warmup, active rollback or after limit */
chem.SetStabEnabled(true); if (in_warmup || rollback_enabled || rb_limit_reached) {
chem.SetStabEnabled(!rb_limit_reached);
chem.SetDhtEnabled(false); chem.SetDhtEnabled(false);
chem.SetInterpEnabled(false); chem.SetInterpEnabled(false);
} else {
chem.SetStabEnabled(false); if (rb_limit_reached) {
chem.SetDhtEnabled(dht_enabled); std::cout << "Interpolation completly disabled." << std::endl;
chem.SetInterpEnabled(interp_enabled); } else {
std::cout << "In stabilization phase." << std::endl;
}
return;
} }
/* enable user-requested surrogates */
chem.SetStabEnabled(false);
chem.SetDhtEnabled(dht_enabled);
chem.SetInterpEnabled(interp_enabled);
std::cout << "Interpolating." << std::endl;
} }
void poet::ControlModule::writeCheckpoint(DiffusionModule &diffusion, void poet::ControlModule::writeCheckpoint(DiffusionModule &diffusion, uint32_t &iter,
uint32_t &iter,
const std::string &out_dir) { const std::string &out_dir) {
if (global_iteration % config.checkpoint_interval == 0) { if (global_iteration % config.checkpoint_interval == 0) {
double w_check_a = MPI_Wtime(); double w_check_a = MPI_Wtime();
@ -62,44 +68,41 @@ void poet::ControlModule::writeCheckpoint(DiffusionModule &diffusion,
} }
} }
void poet::ControlModule::readCheckpoint(DiffusionModule &diffusion, void poet::ControlModule::readCheckpoint(DiffusionModule &diffusion, uint32_t &current_iter,
uint32_t &current_iter, uint32_t rollback_iter, const std::string &out_dir) {
uint32_t rollback_iter,
const std::string &out_dir) {
double r_check_a, r_check_b; double r_check_a, r_check_b;
r_check_a = MPI_Wtime(); r_check_a = MPI_Wtime();
Checkpoint_s checkpoint_read{.field = diffusion.getField()}; Checkpoint_s checkpoint_read{.field = diffusion.getField()};
read_checkpoint(out_dir, read_checkpoint(out_dir, "checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read);
"checkpoint" + std::to_string(rollback_iter) + ".hdf5",
checkpoint_read);
current_iter = checkpoint_read.iteration; current_iter = checkpoint_read.iteration;
r_check_b = MPI_Wtime(); r_check_b = MPI_Wtime();
r_check_t += r_check_b - r_check_a; r_check_t += r_check_b - r_check_a;
} }
void poet::ControlModule::writeErrorMetrics( void poet::ControlModule::writeErrorMetrics(uint32_t &iter, const std::string &out_dir,
uint32_t &iter, const std::string &out_dir, const std::vector<std::string> &species) {
const std::vector<std::string> &species) {
if (rollback_count >= 3) {
return;
}
double stats_a = MPI_Wtime(); double stats_a = MPI_Wtime();
writeSpeciesStatsToCSV(metrics_history, species, out_dir, writeSpeciesStatsToCSV(metrics_history, species, out_dir, "species_overview.csv");
"species_overview.csv"); write_metrics(cell_metrics_history, species, out_dir, "metrics_overview.hdf5");
write_metrics(cell_metrics_history, species, out_dir,
"metrics_overview.hdf5");
double stats_b = MPI_Wtime(); double stats_b = MPI_Wtime();
this->stats_t += stats_b - stats_a; this->stats_t += stats_b - stats_a;
} }
uint32_t poet::ControlModule::getRollbackIter() { uint32_t poet::ControlModule::getRollbackIter() {
uint32_t last_iter = ((global_iteration - 1) / config.checkpoint_interval) * uint32_t last_iter =
config.checkpoint_interval; ((global_iteration - 1) / config.checkpoint_interval) * config.checkpoint_interval;
return last_iter; return last_iter;
} }
std::optional<uint32_t> poet::ControlModule::getRollbackTarget( std::optional<uint32_t>
const std::vector<std::string> &species) { poet::ControlModule::getRollbackTarget(const std::vector<std::string> &species) {
double r_check_a, r_check_b; double r_check_a, r_check_b;
/* Skip threshold checking if already in stabilization phase*/ /* Skip threshold checking if already in stabilization phase*/
@ -114,16 +117,16 @@ std::optional<uint32_t> poet::ControlModule::getRollbackTarget(
/* check bounds of threshold vector*/ /* check bounds of threshold vector*/
if (sp_i >= config.mape_threshold.size()) { if (sp_i >= config.mape_threshold.size()) {
std::cerr << "Warning: No threshold defined for species " << species[sp_i] std::cerr << "Warning: No threshold defined for species " << species[sp_i] << " at index "
<< " at index " << std::to_string(sp_i) << std::endl; << std::to_string(sp_i) << std::endl;
continue; continue;
} }
if (s_hist.mape[sp_i] > config.mape_threshold[sp_i]) { if (s_hist.mape[sp_i] > config.mape_threshold[sp_i]) {
const auto &c_hist = cell_metrics_history.back(); const auto &c_hist = cell_metrics_history.back();
auto max_it = std::max_element( auto max_it =
c_hist.mape.begin(), c_hist.mape.end(), std::max_element(c_hist.mape.begin(), c_hist.mape.end(),
[sp_i](const auto &a, const auto &b) { return a[sp_i] < b[sp_i]; }); [sp_i](const auto &a, const auto &b) { return a[sp_i] < b[sp_i]; });
size_t max_idx = std::distance(c_hist.mape.begin(), max_it); size_t max_idx = std::distance(c_hist.mape.begin(), max_it);
uint32_t cell_id = c_hist.id[max_idx]; uint32_t cell_id = c_hist.id[max_idx];
@ -134,10 +137,8 @@ std::optional<uint32_t> poet::ControlModule::getRollbackTarget(
std::cout << "Threshold exceeded for " << species[sp_i] std::cout << "Threshold exceeded for " << species[sp_i]
<< " with species-level MAPE = " << std::to_string(s_hist.mape[sp_i]) << " with species-level MAPE = " << std::to_string(s_hist.mape[sp_i])
<< " exceeding threshold = " << " exceeding threshold = " << std::to_string(config.mape_threshold[sp_i])
<< std::to_string(config.mape_threshold[sp_i])
<< ". Worst cell: ID=" << std::to_string(cell_id) << ". Worst cell: ID=" << std::to_string(cell_id)
<< ", species= " << species[sp_i]
<< " with MAPE=" << std::to_string(cell_mape) << std::endl; << " with MAPE=" << std::to_string(cell_mape) << std::endl;
return getRollbackIter(); return getRollbackIter();
@ -146,17 +147,20 @@ std::optional<uint32_t> poet::ControlModule::getRollbackTarget(
return std::nullopt; return std::nullopt;
} }
void poet::ControlModule::computeErrorMetrics( void poet::ControlModule::computeErrorMetrics(std::vector<std::vector<double>> &reference_values,
std::vector<std::vector<double>> &reference_values, std::vector<std::vector<double>> &surrogate_values,
std::vector<std::vector<double>> &surrogate_values, const std::vector<std::string> &species,
const std::vector<std::string> &species, const uint32_t size_per_prop) { const uint32_t size_per_prop) {
if (rollback_count >= 3) {
return;
}
const uint32_t n_cells = reference_values.size(); const uint32_t n_cells = reference_values.size();
const uint32_t n_species = species.size(); const uint32_t n_species = species.size();
const double ZERO_ABS = config.zero_abs; const double ZERO_ABS = config.zero_abs;
CellErrorMetrics c_metrics(n_cells, n_species, global_iteration, CellErrorMetrics c_metrics(n_cells, n_species, global_iteration, rollback_count);
rollback_count);
SpeciesErrorMetrics s_metrics(n_species, global_iteration, rollback_count); SpeciesErrorMetrics s_metrics(n_species, global_iteration, rollback_count);
std::vector<double> species_err_sum(n_species, 0.0); std::vector<double> species_err_sum(n_species, 0.0);
@ -194,11 +198,10 @@ void poet::ControlModule::computeErrorMetrics(
c_metrics.rrmse[cell_i][sp_i] = alpha * alpha; c_metrics.rrmse[cell_i][sp_i] = alpha * alpha;
// Log extreme MAPE values for debugging // Log extreme MAPE values for debugging
if (c_metrics.mape[cell_i][sp_i] > 100.0) { if (c_metrics.mape[cell_i][sp_i] > 100.0) {
std::cout << "WARNING: High MAPE detected - Cell=" std::cout << "WARNING: High MAPE detected - Cell=" << c_metrics.id[cell_i]
<< c_metrics.id[cell_i] << ", Species=" << species[sp_i] << ", Species=" << species[sp_i] << ", MAPE=" << c_metrics.mape[cell_i][sp_i]
<< ", MAPE=" << c_metrics.mape[cell_i][sp_i] << "%, Ref=" << ref_value << ", Sur=" << sur_value << ", Alpha=" << alpha
<< "%, Ref=" << ref_value << ", Sur=" << sur_value << std::endl;
<< ", Alpha=" << alpha << std::endl;
} }
} }
} }
@ -208,19 +211,43 @@ void poet::ControlModule::computeErrorMetrics(
s_metrics.rrmse[sp_i] = std::sqrt(species_sqr_sum[sp_i] / size_per_prop); s_metrics.rrmse[sp_i] = std::sqrt(species_sqr_sum[sp_i] / size_per_prop);
} }
// Sort cell metrics by ID
std::vector<size_t> indices(n_cells);
std::iota(indices.begin(), indices.end(), 0);
std::sort(indices.begin(), indices.end(),
[&c_metrics](size_t a, size_t b) { return c_metrics.id[a] < c_metrics.id[b]; });
// Reorder cell metrics based on sorted indices
std::vector<uint32_t> sorted_ids(n_cells);
std::vector<std::vector<double>> sorted_mape(n_cells, std::vector<double>(n_species));
std::vector<std::vector<double>> sorted_rrmse(n_cells, std::vector<double>(n_species));
for (size_t i = 0; i < n_cells; i++) {
sorted_ids[i] = c_metrics.id[indices[i]];
sorted_mape[i] = c_metrics.mape[indices[i]];
sorted_rrmse[i] = c_metrics.rrmse[indices[i]];
}
c_metrics.id = std::move(sorted_ids);
c_metrics.mape = std::move(sorted_mape);
c_metrics.rrmse = std::move(sorted_rrmse);
metrics_history.push_back(s_metrics); metrics_history.push_back(s_metrics);
cell_metrics_history.push_back(c_metrics); cell_metrics_history.push_back(c_metrics);
} }
void poet::ControlModule::processCheckpoint( void poet::ControlModule::processCheckpoint(DiffusionModule &diffusion, uint32_t &current_iter,
DiffusionModule &diffusion, uint32_t &current_iter, const std::string &out_dir,
const std::string &out_dir, const std::vector<std::string> &species) { const std::vector<std::string> &species) {
// Use max_rollbacks from config, default to 3 if not set // Use max_rollbacks from config, default to 3 if not set
// uint32_t max_rollbacks = // uint32_t max_rollbacks =
// (config.max_rollbacks > 0) ? config.max_rollbacks : 3; // (config.max_rollbacks > 0) ? config.max_rollbacks : 3;
if (flush_request /* && rollback_count < 3 */) { if (rollback_count >= 3) {
return;
}
if (flush_request && rollback_count < 3) {
uint32_t target = getRollbackIter(); uint32_t target = getRollbackIter();
readCheckpoint(diffusion, current_iter, target, out_dir); readCheckpoint(diffusion, current_iter, target, out_dir);
@ -228,8 +255,7 @@ void poet::ControlModule::processCheckpoint(
rollback_count++; rollback_count++;
disable_surr_counter = config.stab_interval; disable_surr_counter = config.stab_interval;
std::cout << "Restored checkpoint " << std::to_string(target) std::cout << "Restored checkpoint " << std::to_string(target) << ", surrogate disabled for "
<< ", surrogate disabled for "
<< std::to_string(config.stab_interval) << std::endl; << std::to_string(config.stab_interval) << std::endl;
} else { } else {
writeCheckpoint(diffusion, global_iteration, out_dir); writeCheckpoint(diffusion, global_iteration, out_dir);