Store per-cell MAPE/RRMSE and add stabilization interval

This commit is contained in:
rastogi 2025-11-12 10:20:28 +01:00
parent 6fbeaed12d
commit 28bd24f9f6
10 changed files with 130 additions and 80 deletions

4
.gitignore vendored
View File

@ -152,13 +152,13 @@ lib/
include/ include/
# But keep these specific files # But keep these specific files
!bin/compare_qs2.R !bin/plot/
!bin/dolo/
!bin/barite_fgcs_3.pqi !bin/barite_fgcs_3.pqi
!bin/barite_fgcs_4_rt.R !bin/barite_fgcs_4_rt.R
!bin/barite_fgcs_4.R !bin/barite_fgcs_4.R
!bin/barite_fgcs_4.qs2 !bin/barite_fgcs_4.qs2
!bin/db_barite.dat !bin/db_barite.dat
!bin/dol.pqi
!bin/dolo_fgcs_3.qs2 !bin/dolo_fgcs_3.qs2
!bin/dolo_fgcs_3.R !bin/dolo_fgcs_3.R
!bin/dolo_fgcs_3.pqi !bin/dolo_fgcs_3.pqi

Binary file not shown.

View File

@ -537,7 +537,7 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
metrics_a = MPI_Wtime(); metrics_a = MPI_Wtime();
control_module->computeSpeciesErrorMetrics(this->control_batch, control_module->computeSpeciesErrorMetrics(this->control_batch,
surrogate_batch, 1); surrogate_batch);
metrics_b = MPI_Wtime(); metrics_b = MPI_Wtime();
this->metrics_t += metrics_b - metrics_a; this->metrics_t += metrics_b - metrics_a;

View File

@ -200,8 +200,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
this->warmup_enabled = (flags & 4) != 0; this->warmup_enabled = (flags & 4) != 0;
this->control_enabled = (flags & 8) != 0; this->control_enabled = (flags & 8) != 0;
/*std::cout << "warmup_enabled is " << warmup_enabled << ", control_enabled is /*std::cout << "warmup_enabled is " << warmup_enabled << ", control_enabled is "
"
<< control_enabled << ", dht_enabled is " << control_enabled << ", dht_enabled is "
<< dht_enabled << ", interp_enabled is " << interp_enabled << dht_enabled << ", interp_enabled is " << interp_enabled
<< std::endl;*/ << std::endl;*/

View File

@ -14,9 +14,25 @@ void poet::ControlModule::updateControlIteration(const uint32_t &iter,
prep_a = MPI_Wtime(); prep_a = MPI_Wtime();
/*
if (control_interval == 0) {
control_interval_enabled = false;
return;
}
*/
global_iteration = iter; global_iteration = iter;
initiateWarmupPhase(dht_enabled, interp_enabled); initiateWarmupPhase(dht_enabled, interp_enabled);
/*
control_interval_enabled =
(control_interval > 0 && iter % control_interval == 0);
if (control_interval_enabled) {
MSG("[Control] Control interval enabled at iteration " +
std::to_string(iter));
}
*/
prep_b = MPI_Wtime(); prep_b = MPI_Wtime();
this->prep_t += prep_b - prep_a; this->prep_t += prep_b - prep_a;
} }
@ -25,13 +41,13 @@ void poet::ControlModule::initiateWarmupPhase(bool dht_enabled,
bool interp_enabled) { bool interp_enabled) {
// user requested DHT/INTEP? keep them disabled but enable warmup-phase so // user requested DHT/INTEP? keep them disabled but enable warmup-phase so
if (rollback_enabled) { if (global_iteration < stabilization_interval || rollback_enabled) {
chem->SetWarmupEnabled(true); chem->SetWarmupEnabled(true);
chem->SetDhtEnabled(false); chem->SetDhtEnabled(false);
chem->SetInterpEnabled(false); chem->SetInterpEnabled(false);
MSG("Warmup enabled until next control interval at iteration " + MSG("Stabilization enabled until next control interval at iteration " +
std::to_string(penalty_interval) + "."); std::to_string(stabilization_interval) + ".");
if (sur_disabled_counter > 0) { if (sur_disabled_counter > 0) {
--sur_disabled_counter; --sur_disabled_counter;
@ -39,6 +55,7 @@ void poet::ControlModule::initiateWarmupPhase(bool dht_enabled,
} else { } else {
rollback_enabled = false; rollback_enabled = false;
} }
return; return;
} }
@ -50,15 +67,20 @@ void poet::ControlModule::initiateWarmupPhase(bool dht_enabled,
void poet::ControlModule::applyControlLogic(DiffusionModule &diffusion, void poet::ControlModule::applyControlLogic(DiffusionModule &diffusion,
uint32_t &iter) { uint32_t &iter) {
/*
if (!control_interval_enabled) {
return;
}
*/
writeCheckpointAndMetrics(diffusion, iter); writeCheckpointAndMetrics(diffusion, iter);
if (checkAndRollback(diffusion, iter) && rollback_count < 3) { if (checkAndRollback(diffusion, iter)) {
rollback_enabled = true; rollback_enabled = true;
rollback_count++; rollback_count++;
sur_disabled_counter = penalty_interval; sur_disabled_counter = stabilization_interval;
MSG("Interpolation disabled for the next " + MSG("Interpolation disabled for the next " +
std::to_string(penalty_interval) + "."); std::to_string(stabilization_interval) + ".");
} }
} }
@ -85,6 +107,10 @@ bool poet::ControlModule::checkAndRollback(DiffusionModule &diffusion,
uint32_t &iter) { uint32_t &iter) {
double r_check_a, r_check_b; double r_check_a, r_check_b;
if (global_iteration < stabilization_interval) {
return false;
}
if (metricsHistory.empty()) { if (metricsHistory.empty()) {
MSG("No error history yet; skipping rollback check."); MSG("No error history yet; skipping rollback check.");
return false; return false;
@ -92,29 +118,31 @@ bool poet::ControlModule::checkAndRollback(DiffusionModule &diffusion,
const auto &mape = metricsHistory.back().mape; const auto &mape = metricsHistory.back().mape;
for (uint32_t i = 0; i < species_names.size(); ++i) { for (size_t row = 0; row < mape.size(); row++) {
if (mape[i] == 0) { for (size_t col = 0; col < species_names.size() && col < mape[row].size(); col++) {
continue; if (mape[row][col] == 0) {
} continue;
}
if (mape[i] > mape_threshold[i]) { if (mape[row][col] > mape_threshold[col]) {
uint32_t rollback_iter = uint32_t rollback_iter =
((iter - 1) / checkpoint_interval) * checkpoint_interval; ((iter - 1) / checkpoint_interval) * checkpoint_interval;
MSG("[THRESHOLD EXCEEDED] " + species_names[i] + MSG("[THRESHOLD EXCEEDED] " + species_names[col] +
" has MAPE = " + std::to_string(mape[i]) + " has MAPE = " + std::to_string(mape[row][col]) +
" exceeding threshold = " + std::to_string(mape_threshold[i]) + " exceeding threshold = " + std::to_string(mape_threshold[col]) +
" rolling back to iteration " + std::to_string(rollback_iter)); ", rolling back to iteration " + std::to_string(rollback_iter));
r_check_a = MPI_Wtime(); r_check_a = MPI_Wtime();
Checkpoint_s checkpoint_read{.field = diffusion.getField()}; Checkpoint_s checkpoint_read{.field = diffusion.getField()};
read_checkpoint(out_dir, read_checkpoint(out_dir,
"checkpoint" + std::to_string(rollback_iter) + ".hdf5", "checkpoint" + std::to_string(rollback_iter) + ".hdf5",
checkpoint_read); checkpoint_read);
iter = checkpoint_read.iteration; iter = checkpoint_read.iteration;
r_check_b = MPI_Wtime(); r_check_b = MPI_Wtime();
r_check_t += r_check_b - r_check_a; r_check_t += r_check_b - r_check_a;
return true; return true;
}
} }
} }
MSG("All species are within their MAPE thresholds."); MSG("All species are within their MAPE thresholds.");
@ -124,10 +152,15 @@ bool poet::ControlModule::checkAndRollback(DiffusionModule &diffusion,
void poet::ControlModule::computeSpeciesErrorMetrics( void poet::ControlModule::computeSpeciesErrorMetrics(
std::vector<std::vector<double>> &reference_values, std::vector<std::vector<double>> &reference_values,
std::vector<std::vector<double>> &surrogate_values, std::vector<std::vector<double>> &surrogate_values) {
const uint32_t size_per_prop) {
SpeciesErrorMetrics metrics(this->species_names.size(), global_iteration, const uint32_t num_cells = reference_values.size();
const uint32_t species_count = this->species_names.size();
std::cout << "[DEBUG] computeSpeciesErrorMetrics: num_cells=" << num_cells
<< ", species_count=" << species_count << std::endl;
SpeciesErrorMetrics metrics(num_cells, species_count, global_iteration,
rollback_count); rollback_count);
if (reference_values.size() != surrogate_values.size()) { if (reference_values.size() != surrogate_values.size()) {
@ -137,42 +170,38 @@ void poet::ControlModule::computeSpeciesErrorMetrics(
return; return;
} }
for (size_t row = 0; row < reference_values.size(); row++) { for (size_t cell_i = 0; cell_i < num_cells; cell_i++) {
double err_sum = 0.0;
double sqr_err_sum = 0.0;
uint32_t count = 0;
for (size_t col = 0; col < this->species_names.size(); col++) { metrics.id.push_back(reference_values[cell_i][0]);
const double ref_value = reference_values[row][col];
const double sur_value = surrogate_values[row][col]; for (size_t sp_i = 0; sp_i < reference_values[cell_i].size(); sp_i++) {
const double ref_value = reference_values[cell_i][sp_i];
const double sur_value = surrogate_values[cell_i][sp_i];
const double ZERO_ABS = 1e-13; const double ZERO_ABS = 1e-13;
if (std::isnan(ref_value) || std::isnan(sur_value)) { if (std::isnan(ref_value) || std::isnan(sur_value)) {
metrics.mape[cell_i][sp_i] = 0.0;
metrics.rrmse[cell_i][sp_i] = 0.0;
continue; continue;
} }
if (std::abs(ref_value) < ZERO_ABS) { if (std::abs(ref_value) < ZERO_ABS) {
if (std::abs(sur_value) >= ZERO_ABS) { if (std::abs(sur_value) >= ZERO_ABS) {
err_sum += 1.0; metrics.mape[cell_i][sp_i] = 1.0;
sqr_err_sum += 1.0; metrics.rrmse[cell_i][sp_i] = 1.0;
count++; } else {
metrics.mape[cell_i][sp_i] = 0.0;
metrics.rrmse[cell_i][sp_i] = 0.0;
} }
// Both zero: skip
} else { } else {
double alpha = 1.0 - (sur_value / ref_value); double alpha = 1.0 - (sur_value / ref_value);
err_sum += std::abs(alpha); metrics.mape[cell_i][sp_i] = 100.0 * std::abs(alpha);
sqr_err_sum += alpha * alpha; metrics.rrmse[cell_i][sp_i] = alpha * alpha;
count++;
}
// Store metrics for this species after processing all cells
if (count > 0) {
metrics.mape[col] = 100.0 * (err_sum / size_per_prop);
metrics.rrmse[col] = std::sqrt(sqr_err_sum / size_per_prop);
} else {
metrics.mape[col] = 0.0;
metrics.rrmse[col] = 0.0;
} }
} }
metricsHistory.push_back(metrics);
} }
std::cout << "[DEBUG] metrics.id.size()=" << metrics.id.size() << std::endl;
metricsHistory.push_back(metrics);
std::cout << "[DEBUG] metricsHistory.size()=" << metricsHistory.size() << std::endl;
} }

View File

@ -32,26 +32,30 @@ public:
bool checkAndRollback(DiffusionModule &diffusion, uint32_t &iter); bool checkAndRollback(DiffusionModule &diffusion, uint32_t &iter);
struct SpeciesErrorMetrics { struct SpeciesErrorMetrics {
std::vector<double> mape; std::vector<std::uint32_t> id;
std::vector<double> rrmse; std::vector<std::vector<double>> mape;
std::vector<std::vector<double>> rrmse;
uint32_t iteration; // iterations in simulation after rollbacks uint32_t iteration; // iterations in simulation after rollbacks
uint32_t rollback_count; uint32_t rollback_count;
SpeciesErrorMetrics(uint32_t species_count, uint32_t iter, uint32_t counter) SpeciesErrorMetrics(uint32_t num_cells, uint32_t species_count,
: mape(species_count, 0.0), rrmse(species_count, 0.0), iteration(iter), uint32_t iter, uint32_t counter)
rollback_count(counter) {} : mape(num_cells, std::vector<double>(species_count, 0.0)),
rrmse(num_cells, std::vector<double>(species_count, 0.0)),
iteration(iter), rollback_count(counter) {}
}; };
void computeSpeciesErrorMetrics( void computeSpeciesErrorMetrics(
std::vector<std::vector<double>> &reference_values, std::vector<std::vector<double>> &reference_values,
std::vector<std::vector<double>> &surrogate_values, std::vector<std::vector<double>> &surrogate_values);
const uint32_t size_per_prop);
std::vector<SpeciesErrorMetrics> metricsHistory; std::vector<SpeciesErrorMetrics> metricsHistory;
struct ControlSetup { struct ControlSetup {
std::string out_dir; std::string out_dir;
std::uint32_t checkpoint_interval; std::uint32_t checkpoint_interval;
std::uint32_t penalty_interval;
std::uint32_t stabilization_interval;
std::vector<std::string> species_names; std::vector<std::string> species_names;
std::vector<double> mape_threshold; std::vector<double> mape_threshold;
std::vector<uint32_t> ctrl_cell_ids; std::vector<uint32_t> ctrl_cell_ids;
@ -60,6 +64,7 @@ public:
void enableControlLogic(const ControlSetup &setup) { void enableControlLogic(const ControlSetup &setup) {
this->out_dir = setup.out_dir; this->out_dir = setup.out_dir;
this->checkpoint_interval = setup.checkpoint_interval; this->checkpoint_interval = setup.checkpoint_interval;
this->stabilization_interval = setup.stabilization_interval;
this->species_names = setup.species_names; this->species_names = setup.species_names;
this->mape_threshold = setup.mape_threshold; this->mape_threshold = setup.mape_threshold;
this->ctrl_cell_ids = setup.ctrl_cell_ids; this->ctrl_cell_ids = setup.ctrl_cell_ids;
@ -92,7 +97,8 @@ private:
poet::ChemistryModule *chem = nullptr; poet::ChemistryModule *chem = nullptr;
std::uint32_t penalty_interval = 50; std::uint32_t stabilization_interval = 0;
std::uint32_t penalty_interval = 0;
std::uint32_t checkpoint_interval = 0; std::uint32_t checkpoint_interval = 0;
std::uint32_t global_iteration = 0; std::uint32_t global_iteration = 0;
std::uint32_t rollback_count = 0; std::uint32_t rollback_count = 0;

View File

@ -21,27 +21,36 @@ namespace poet
return; return;
} }
// header // header: CellID, Iteration, Rollback, Species, MAPE, RRMSE
out << std::left << std::setw(15) << "Iteration" out << std::left << std::setw(15) << "CellID"
<< std::setw(15) << "Iteration"
<< std::setw(15) << "Rollback" << std::setw(15) << "Rollback"
<< std::setw(15) << "Species" << std::setw(15) << "Species"
<< std::setw(15) << "MAPE" << std::setw(15) << "MAPE"
<< std::setw(15) << "RRSME" << "\n"; << std::setw(15) << "RRMSE" << "\n";
out << std::string(75, '-') << "\n"; out << std::string(90, '-') << "\n";
// data rows // data rows: iterate over iterations
for (size_t i = 0; i < all_stats.size(); ++i) for (size_t iter_idx = 0; iter_idx < all_stats.size(); ++iter_idx)
{ {
for (size_t j = 0; j < species_names.size(); ++j) const auto &metrics = all_stats[iter_idx];
// Iterate over cells
for (size_t cell_idx = 0; cell_idx < metrics.id.size(); ++cell_idx)
{ {
out << std::left // Iterate over species for this cell
<< std::setw(15) << all_stats[i].iteration for (size_t species_idx = 0; species_idx < species_names.size(); ++species_idx)
<< std::setw(15) << all_stats[i].rollback_count {
<< std::setw(15) << species_names[j] out << std::left
<< std::setw(15) << all_stats[i].mape[j] << std::setw(15) << metrics.id[cell_idx]
<< std::setw(15) << all_stats[i].rrmse[j] << std::setw(15) << metrics.iteration
<< "\n"; << std::setw(15) << metrics.rollback_count
<< std::setw(15) << species_names[species_idx]
<< std::setw(15) << metrics.mape[cell_idx][species_idx]
<< std::setw(15) << metrics.rrmse[cell_idx][species_idx]
<< "\n";
}
} }
out << "\n"; out << "\n";
} }

View File

@ -252,6 +252,10 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("timesteps")); Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("timesteps"));
params.checkpoint_interval = params.checkpoint_interval =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval")); Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval"));
params.stabilization_interval =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("stabilization_interval"));
params.penalty_interval =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("penalty_interval"));
params.mape_threshold = Rcpp::as<std::vector<double>>( params.mape_threshold = Rcpp::as<std::vector<double>>(
global_rt_setup->operator[]("mape_threshold")); global_rt_setup->operator[]("mape_threshold"));
params.ctrl_cell_ids = Rcpp::as<std::vector<uint32_t>>( params.ctrl_cell_ids = Rcpp::as<std::vector<uint32_t>>(
@ -645,6 +649,8 @@ int main(int argc, char *argv[]) {
const ControlModule::ControlSetup ctrl_setup = { const ControlModule::ControlSetup ctrl_setup = {
run_params.out_dir, // added run_params.out_dir, // added
run_params.checkpoint_interval, run_params.checkpoint_interval,
run_params.penalty_interval,
run_params.stabilization_interval,
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
run_params.mape_threshold}; run_params.mape_threshold};

View File

@ -50,7 +50,8 @@ struct RuntimeParameters {
std::string out_ext; std::string out_ext;
bool print_progress = false; bool print_progress = false;
std::uint32_t penalty_interval = 0;
std::uint32_t stabilization_interval = 0;
std::uint32_t checkpoint_interval = 0; std::uint32_t checkpoint_interval = 0;
std::uint32_t control_interval = 0; std::uint32_t control_interval = 0;
std::vector<double> mape_threshold; std::vector<double> mape_threshold;