Store per-cell MAPE/RRMSE and add stabilization interval

This commit is contained in:
rastogi 2025-11-12 10:20:28 +01:00
parent 6fbeaed12d
commit 28bd24f9f6
10 changed files with 130 additions and 80 deletions

4
.gitignore vendored
View File

@ -152,13 +152,13 @@ lib/
include/
# But keep these specific files
!bin/compare_qs2.R
!bin/plot/
!bin/dolo/
!bin/barite_fgcs_3.pqi
!bin/barite_fgcs_4_rt.R
!bin/barite_fgcs_4.R
!bin/barite_fgcs_4.qs2
!bin/db_barite.dat
!bin/dol.pqi
!bin/dolo_fgcs_3.qs2
!bin/dolo_fgcs_3.R
!bin/dolo_fgcs_3.pqi

Binary file not shown.

View File

@ -537,7 +537,7 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
metrics_a = MPI_Wtime();
control_module->computeSpeciesErrorMetrics(this->control_batch,
surrogate_batch, 1);
surrogate_batch);
metrics_b = MPI_Wtime();
this->metrics_t += metrics_b - metrics_a;

View File

@ -200,8 +200,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
this->warmup_enabled = (flags & 4) != 0;
this->control_enabled = (flags & 8) != 0;
/*std::cout << "warmup_enabled is " << warmup_enabled << ", control_enabled is
"
/*std::cout << "warmup_enabled is " << warmup_enabled << ", control_enabled is "
<< control_enabled << ", dht_enabled is "
<< dht_enabled << ", interp_enabled is " << interp_enabled
<< std::endl;*/

View File

@ -14,9 +14,25 @@ void poet::ControlModule::updateControlIteration(const uint32_t &iter,
prep_a = MPI_Wtime();
/*
if (control_interval == 0) {
control_interval_enabled = false;
return;
}
*/
global_iteration = iter;
initiateWarmupPhase(dht_enabled, interp_enabled);
/*
control_interval_enabled =
(control_interval > 0 && iter % control_interval == 0);
if (control_interval_enabled) {
MSG("[Control] Control interval enabled at iteration " +
std::to_string(iter));
}
*/
prep_b = MPI_Wtime();
this->prep_t += prep_b - prep_a;
}
@ -25,13 +41,13 @@ void poet::ControlModule::initiateWarmupPhase(bool dht_enabled,
bool interp_enabled) {
// user requested DHT/INTEP? keep them disabled but enable warmup-phase so
if (rollback_enabled) {
if (global_iteration < stabilization_interval || rollback_enabled) {
chem->SetWarmupEnabled(true);
chem->SetDhtEnabled(false);
chem->SetInterpEnabled(false);
MSG("Warmup enabled until next control interval at iteration " +
std::to_string(penalty_interval) + ".");
MSG("Stabilization enabled until next control interval at iteration " +
std::to_string(stabilization_interval) + ".");
if (sur_disabled_counter > 0) {
--sur_disabled_counter;
@ -39,6 +55,7 @@ void poet::ControlModule::initiateWarmupPhase(bool dht_enabled,
} else {
rollback_enabled = false;
}
return;
}
@ -50,15 +67,20 @@ void poet::ControlModule::initiateWarmupPhase(bool dht_enabled,
void poet::ControlModule::applyControlLogic(DiffusionModule &diffusion,
uint32_t &iter) {
/*
if (!control_interval_enabled) {
return;
}
*/
writeCheckpointAndMetrics(diffusion, iter);
if (checkAndRollback(diffusion, iter) && rollback_count < 3) {
if (checkAndRollback(diffusion, iter)) {
rollback_enabled = true;
rollback_count++;
sur_disabled_counter = penalty_interval;
sur_disabled_counter = stabilization_interval;
MSG("Interpolation disabled for the next " +
std::to_string(penalty_interval) + ".");
std::to_string(stabilization_interval) + ".");
}
}
@ -85,6 +107,10 @@ bool poet::ControlModule::checkAndRollback(DiffusionModule &diffusion,
uint32_t &iter) {
double r_check_a, r_check_b;
if (global_iteration < stabilization_interval) {
return false;
}
if (metricsHistory.empty()) {
MSG("No error history yet; skipping rollback check.");
return false;
@ -92,29 +118,31 @@ bool poet::ControlModule::checkAndRollback(DiffusionModule &diffusion,
const auto &mape = metricsHistory.back().mape;
for (uint32_t i = 0; i < species_names.size(); ++i) {
if (mape[i] == 0) {
continue;
}
for (size_t row = 0; row < mape.size(); row++) {
for (size_t col = 0; col < species_names.size() && col < mape[row].size(); col++) {
if (mape[row][col] == 0) {
continue;
}
if (mape[i] > mape_threshold[i]) {
uint32_t rollback_iter =
((iter - 1) / checkpoint_interval) * checkpoint_interval;
if (mape[row][col] > mape_threshold[col]) {
uint32_t rollback_iter =
((iter - 1) / checkpoint_interval) * checkpoint_interval;
MSG("[THRESHOLD EXCEEDED] " + species_names[i] +
" has MAPE = " + std::to_string(mape[i]) +
" exceeding threshold = " + std::to_string(mape_threshold[i]) +
" rolling back to iteration " + std::to_string(rollback_iter));
MSG("[THRESHOLD EXCEEDED] " + species_names[col] +
" has MAPE = " + std::to_string(mape[row][col]) +
" exceeding threshold = " + std::to_string(mape_threshold[col]) +
", rolling back to iteration " + std::to_string(rollback_iter));
r_check_a = MPI_Wtime();
Checkpoint_s checkpoint_read{.field = diffusion.getField()};
read_checkpoint(out_dir,
"checkpoint" + std::to_string(rollback_iter) + ".hdf5",
checkpoint_read);
iter = checkpoint_read.iteration;
r_check_b = MPI_Wtime();
r_check_t += r_check_b - r_check_a;
return true;
r_check_a = MPI_Wtime();
Checkpoint_s checkpoint_read{.field = diffusion.getField()};
read_checkpoint(out_dir,
"checkpoint" + std::to_string(rollback_iter) + ".hdf5",
checkpoint_read);
iter = checkpoint_read.iteration;
r_check_b = MPI_Wtime();
r_check_t += r_check_b - r_check_a;
return true;
}
}
}
MSG("All species are within their MAPE thresholds.");
@ -124,10 +152,15 @@ bool poet::ControlModule::checkAndRollback(DiffusionModule &diffusion,
void poet::ControlModule::computeSpeciesErrorMetrics(
std::vector<std::vector<double>> &reference_values,
std::vector<std::vector<double>> &surrogate_values,
const uint32_t size_per_prop) {
std::vector<std::vector<double>> &surrogate_values) {
SpeciesErrorMetrics metrics(this->species_names.size(), global_iteration,
const uint32_t num_cells = reference_values.size();
const uint32_t species_count = this->species_names.size();
std::cout << "[DEBUG] computeSpeciesErrorMetrics: num_cells=" << num_cells
<< ", species_count=" << species_count << std::endl;
SpeciesErrorMetrics metrics(num_cells, species_count, global_iteration,
rollback_count);
if (reference_values.size() != surrogate_values.size()) {
@ -137,42 +170,38 @@ void poet::ControlModule::computeSpeciesErrorMetrics(
return;
}
for (size_t row = 0; row < reference_values.size(); row++) {
double err_sum = 0.0;
double sqr_err_sum = 0.0;
uint32_t count = 0;
for (size_t cell_i = 0; cell_i < num_cells; cell_i++) {
for (size_t col = 0; col < this->species_names.size(); col++) {
const double ref_value = reference_values[row][col];
const double sur_value = surrogate_values[row][col];
metrics.id.push_back(reference_values[cell_i][0]);
for (size_t sp_i = 0; sp_i < reference_values[cell_i].size(); sp_i++) {
const double ref_value = reference_values[cell_i][sp_i];
const double sur_value = surrogate_values[cell_i][sp_i];
const double ZERO_ABS = 1e-13;
if (std::isnan(ref_value) || std::isnan(sur_value)) {
metrics.mape[cell_i][sp_i] = 0.0;
metrics.rrmse[cell_i][sp_i] = 0.0;
continue;
}
if (std::abs(ref_value) < ZERO_ABS) {
if (std::abs(sur_value) >= ZERO_ABS) {
err_sum += 1.0;
sqr_err_sum += 1.0;
count++;
metrics.mape[cell_i][sp_i] = 1.0;
metrics.rrmse[cell_i][sp_i] = 1.0;
} else {
metrics.mape[cell_i][sp_i] = 0.0;
metrics.rrmse[cell_i][sp_i] = 0.0;
}
// Both zero: skip
} else {
double alpha = 1.0 - (sur_value / ref_value);
err_sum += std::abs(alpha);
sqr_err_sum += alpha * alpha;
count++;
}
// Store metrics for this species after processing all cells
if (count > 0) {
metrics.mape[col] = 100.0 * (err_sum / size_per_prop);
metrics.rrmse[col] = std::sqrt(sqr_err_sum / size_per_prop);
} else {
metrics.mape[col] = 0.0;
metrics.rrmse[col] = 0.0;
metrics.mape[cell_i][sp_i] = 100.0 * std::abs(alpha);
metrics.rrmse[cell_i][sp_i] = alpha * alpha;
}
}
metricsHistory.push_back(metrics);
}
std::cout << "[DEBUG] metrics.id.size()=" << metrics.id.size() << std::endl;
metricsHistory.push_back(metrics);
std::cout << "[DEBUG] metricsHistory.size()=" << metricsHistory.size() << std::endl;
}

View File

@ -32,26 +32,30 @@ public:
bool checkAndRollback(DiffusionModule &diffusion, uint32_t &iter);
struct SpeciesErrorMetrics {
std::vector<double> mape;
std::vector<double> rrmse;
std::vector<std::uint32_t> id;
std::vector<std::vector<double>> mape;
std::vector<std::vector<double>> rrmse;
uint32_t iteration; // iterations in simulation after rollbacks
uint32_t rollback_count;
SpeciesErrorMetrics(uint32_t species_count, uint32_t iter, uint32_t counter)
: mape(species_count, 0.0), rrmse(species_count, 0.0), iteration(iter),
rollback_count(counter) {}
SpeciesErrorMetrics(uint32_t num_cells, uint32_t species_count,
uint32_t iter, uint32_t counter)
: mape(num_cells, std::vector<double>(species_count, 0.0)),
rrmse(num_cells, std::vector<double>(species_count, 0.0)),
iteration(iter), rollback_count(counter) {}
};
void computeSpeciesErrorMetrics(
std::vector<std::vector<double>> &reference_values,
std::vector<std::vector<double>> &surrogate_values,
const uint32_t size_per_prop);
std::vector<std::vector<double>> &surrogate_values);
std::vector<SpeciesErrorMetrics> metricsHistory;
struct ControlSetup {
std::string out_dir;
std::uint32_t checkpoint_interval;
std::uint32_t penalty_interval;
std::uint32_t stabilization_interval;
std::vector<std::string> species_names;
std::vector<double> mape_threshold;
std::vector<uint32_t> ctrl_cell_ids;
@ -60,6 +64,7 @@ public:
void enableControlLogic(const ControlSetup &setup) {
this->out_dir = setup.out_dir;
this->checkpoint_interval = setup.checkpoint_interval;
this->stabilization_interval = setup.stabilization_interval;
this->species_names = setup.species_names;
this->mape_threshold = setup.mape_threshold;
this->ctrl_cell_ids = setup.ctrl_cell_ids;
@ -92,7 +97,8 @@ private:
poet::ChemistryModule *chem = nullptr;
std::uint32_t penalty_interval = 50;
std::uint32_t stabilization_interval = 0;
std::uint32_t penalty_interval = 0;
std::uint32_t checkpoint_interval = 0;
std::uint32_t global_iteration = 0;
std::uint32_t rollback_count = 0;

View File

@ -21,27 +21,36 @@ namespace poet
return;
}
// header
out << std::left << std::setw(15) << "Iteration"
// header: CellID, Iteration, Rollback, Species, MAPE, RRMSE
out << std::left << std::setw(15) << "CellID"
<< std::setw(15) << "Iteration"
<< std::setw(15) << "Rollback"
<< std::setw(15) << "Species"
<< std::setw(15) << "MAPE"
<< std::setw(15) << "RRSME" << "\n";
<< std::setw(15) << "RRMSE" << "\n";
out << std::string(75, '-') << "\n";
out << std::string(90, '-') << "\n";
// data rows
for (size_t i = 0; i < all_stats.size(); ++i)
// data rows: iterate over iterations
for (size_t iter_idx = 0; iter_idx < all_stats.size(); ++iter_idx)
{
for (size_t j = 0; j < species_names.size(); ++j)
const auto &metrics = all_stats[iter_idx];
// Iterate over cells
for (size_t cell_idx = 0; cell_idx < metrics.id.size(); ++cell_idx)
{
out << std::left
<< std::setw(15) << all_stats[i].iteration
<< std::setw(15) << all_stats[i].rollback_count
<< std::setw(15) << species_names[j]
<< std::setw(15) << all_stats[i].mape[j]
<< std::setw(15) << all_stats[i].rrmse[j]
<< "\n";
// Iterate over species for this cell
for (size_t species_idx = 0; species_idx < species_names.size(); ++species_idx)
{
out << std::left
<< std::setw(15) << metrics.id[cell_idx]
<< std::setw(15) << metrics.iteration
<< std::setw(15) << metrics.rollback_count
<< std::setw(15) << species_names[species_idx]
<< std::setw(15) << metrics.mape[cell_idx][species_idx]
<< std::setw(15) << metrics.rrmse[cell_idx][species_idx]
<< "\n";
}
}
out << "\n";
}

View File

@ -252,6 +252,10 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("timesteps"));
params.checkpoint_interval =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval"));
params.stabilization_interval =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("stabilization_interval"));
params.penalty_interval =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("penalty_interval"));
params.mape_threshold = Rcpp::as<std::vector<double>>(
global_rt_setup->operator[]("mape_threshold"));
params.ctrl_cell_ids = Rcpp::as<std::vector<uint32_t>>(
@ -645,6 +649,8 @@ int main(int argc, char *argv[]) {
const ControlModule::ControlSetup ctrl_setup = {
run_params.out_dir, // added
run_params.checkpoint_interval,
run_params.penalty_interval,
run_params.stabilization_interval,
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
run_params.mape_threshold};

View File

@ -50,7 +50,8 @@ struct RuntimeParameters {
std::string out_ext;
bool print_progress = false;
std::uint32_t penalty_interval = 0;
std::uint32_t stabilization_interval = 0;
std::uint32_t checkpoint_interval = 0;
std::uint32_t control_interval = 0;
std::vector<double> mape_threshold;