mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-15 12:28:22 +01:00
Store per-cell MAPE/RRMSE and add stabilization interval
This commit is contained in:
parent
6fbeaed12d
commit
28bd24f9f6
4
.gitignore
vendored
4
.gitignore
vendored
@ -152,13 +152,13 @@ lib/
|
||||
include/
|
||||
|
||||
# But keep these specific files
|
||||
!bin/compare_qs2.R
|
||||
!bin/plot/
|
||||
!bin/dolo/
|
||||
!bin/barite_fgcs_3.pqi
|
||||
!bin/barite_fgcs_4_rt.R
|
||||
!bin/barite_fgcs_4.R
|
||||
!bin/barite_fgcs_4.qs2
|
||||
!bin/db_barite.dat
|
||||
!bin/dol.pqi
|
||||
!bin/dolo_fgcs_3.qs2
|
||||
!bin/dolo_fgcs_3.R
|
||||
!bin/dolo_fgcs_3.pqi
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@ -537,7 +537,7 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
||||
|
||||
metrics_a = MPI_Wtime();
|
||||
control_module->computeSpeciesErrorMetrics(this->control_batch,
|
||||
surrogate_batch, 1);
|
||||
surrogate_batch);
|
||||
metrics_b = MPI_Wtime();
|
||||
this->metrics_t += metrics_b - metrics_a;
|
||||
|
||||
|
||||
@ -200,8 +200,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||
this->warmup_enabled = (flags & 4) != 0;
|
||||
this->control_enabled = (flags & 8) != 0;
|
||||
|
||||
/*std::cout << "warmup_enabled is " << warmup_enabled << ", control_enabled is
|
||||
"
|
||||
/*std::cout << "warmup_enabled is " << warmup_enabled << ", control_enabled is "
|
||||
<< control_enabled << ", dht_enabled is "
|
||||
<< dht_enabled << ", interp_enabled is " << interp_enabled
|
||||
<< std::endl;*/
|
||||
|
||||
@ -14,9 +14,25 @@ void poet::ControlModule::updateControlIteration(const uint32_t &iter,
|
||||
|
||||
prep_a = MPI_Wtime();
|
||||
|
||||
/*
|
||||
if (control_interval == 0) {
|
||||
control_interval_enabled = false;
|
||||
return;
|
||||
}
|
||||
*/
|
||||
global_iteration = iter;
|
||||
initiateWarmupPhase(dht_enabled, interp_enabled);
|
||||
|
||||
/*
|
||||
control_interval_enabled =
|
||||
(control_interval > 0 && iter % control_interval == 0);
|
||||
|
||||
|
||||
if (control_interval_enabled) {
|
||||
MSG("[Control] Control interval enabled at iteration " +
|
||||
std::to_string(iter));
|
||||
}
|
||||
*/
|
||||
prep_b = MPI_Wtime();
|
||||
this->prep_t += prep_b - prep_a;
|
||||
}
|
||||
@ -25,13 +41,13 @@ void poet::ControlModule::initiateWarmupPhase(bool dht_enabled,
|
||||
bool interp_enabled) {
|
||||
|
||||
// user requested DHT/INTERP? keep them disabled but enable warmup-phase so
|
||||
if (rollback_enabled) {
|
||||
if (global_iteration < stabilization_interval || rollback_enabled) {
|
||||
chem->SetWarmupEnabled(true);
|
||||
chem->SetDhtEnabled(false);
|
||||
chem->SetInterpEnabled(false);
|
||||
|
||||
MSG("Warmup enabled until next control interval at iteration " +
|
||||
std::to_string(penalty_interval) + ".");
|
||||
MSG("Stabilization enabled until next control interval at iteration " +
|
||||
std::to_string(stabilization_interval) + ".");
|
||||
|
||||
if (sur_disabled_counter > 0) {
|
||||
--sur_disabled_counter;
|
||||
@ -39,6 +55,7 @@ void poet::ControlModule::initiateWarmupPhase(bool dht_enabled,
|
||||
} else {
|
||||
rollback_enabled = false;
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
@ -50,15 +67,20 @@ void poet::ControlModule::initiateWarmupPhase(bool dht_enabled,
|
||||
void poet::ControlModule::applyControlLogic(DiffusionModule &diffusion,
|
||||
uint32_t &iter) {
|
||||
|
||||
/*
|
||||
if (!control_interval_enabled) {
|
||||
return;
|
||||
}
|
||||
*/
|
||||
writeCheckpointAndMetrics(diffusion, iter);
|
||||
|
||||
if (checkAndRollback(diffusion, iter) && rollback_count < 3) {
|
||||
if (checkAndRollback(diffusion, iter)) {
|
||||
rollback_enabled = true;
|
||||
rollback_count++;
|
||||
sur_disabled_counter = penalty_interval;
|
||||
sur_disabled_counter = stabilization_interval;
|
||||
|
||||
MSG("Interpolation disabled for the next " +
|
||||
std::to_string(penalty_interval) + ".");
|
||||
std::to_string(stabilization_interval) + ".");
|
||||
}
|
||||
}
|
||||
|
||||
@ -85,6 +107,10 @@ bool poet::ControlModule::checkAndRollback(DiffusionModule &diffusion,
|
||||
uint32_t &iter) {
|
||||
double r_check_a, r_check_b;
|
||||
|
||||
if (global_iteration < stabilization_interval) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (metricsHistory.empty()) {
|
||||
MSG("No error history yet; skipping rollback check.");
|
||||
return false;
|
||||
@ -92,29 +118,31 @@ bool poet::ControlModule::checkAndRollback(DiffusionModule &diffusion,
|
||||
|
||||
const auto &mape = metricsHistory.back().mape;
|
||||
|
||||
for (uint32_t i = 0; i < species_names.size(); ++i) {
|
||||
if (mape[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
for (size_t row = 0; row < mape.size(); row++) {
|
||||
for (size_t col = 0; col < species_names.size() && col < mape[row].size(); col++) {
|
||||
if (mape[row][col] == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (mape[i] > mape_threshold[i]) {
|
||||
uint32_t rollback_iter =
|
||||
((iter - 1) / checkpoint_interval) * checkpoint_interval;
|
||||
if (mape[row][col] > mape_threshold[col]) {
|
||||
uint32_t rollback_iter =
|
||||
((iter - 1) / checkpoint_interval) * checkpoint_interval;
|
||||
|
||||
MSG("[THRESHOLD EXCEEDED] " + species_names[i] +
|
||||
" has MAPE = " + std::to_string(mape[i]) +
|
||||
" exceeding threshold = " + std::to_string(mape_threshold[i]) +
|
||||
" → rolling back to iteration " + std::to_string(rollback_iter));
|
||||
MSG("[THRESHOLD EXCEEDED] " + species_names[col] +
|
||||
" has MAPE = " + std::to_string(mape[row][col]) +
|
||||
" exceeding threshold = " + std::to_string(mape_threshold[col]) +
|
||||
", rolling back to iteration " + std::to_string(rollback_iter));
|
||||
|
||||
r_check_a = MPI_Wtime();
|
||||
Checkpoint_s checkpoint_read{.field = diffusion.getField()};
|
||||
read_checkpoint(out_dir,
|
||||
"checkpoint" + std::to_string(rollback_iter) + ".hdf5",
|
||||
checkpoint_read);
|
||||
iter = checkpoint_read.iteration;
|
||||
r_check_b = MPI_Wtime();
|
||||
r_check_t += r_check_b - r_check_a;
|
||||
return true;
|
||||
r_check_a = MPI_Wtime();
|
||||
Checkpoint_s checkpoint_read{.field = diffusion.getField()};
|
||||
read_checkpoint(out_dir,
|
||||
"checkpoint" + std::to_string(rollback_iter) + ".hdf5",
|
||||
checkpoint_read);
|
||||
iter = checkpoint_read.iteration;
|
||||
r_check_b = MPI_Wtime();
|
||||
r_check_t += r_check_b - r_check_a;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
MSG("All species are within their MAPE thresholds.");
|
||||
@ -124,10 +152,15 @@ bool poet::ControlModule::checkAndRollback(DiffusionModule &diffusion,
|
||||
|
||||
void poet::ControlModule::computeSpeciesErrorMetrics(
|
||||
std::vector<std::vector<double>> &reference_values,
|
||||
std::vector<std::vector<double>> &surrogate_values,
|
||||
const uint32_t size_per_prop) {
|
||||
std::vector<std::vector<double>> &surrogate_values) {
|
||||
|
||||
SpeciesErrorMetrics metrics(this->species_names.size(), global_iteration,
|
||||
const uint32_t num_cells = reference_values.size();
|
||||
const uint32_t species_count = this->species_names.size();
|
||||
|
||||
std::cout << "[DEBUG] computeSpeciesErrorMetrics: num_cells=" << num_cells
|
||||
<< ", species_count=" << species_count << std::endl;
|
||||
|
||||
SpeciesErrorMetrics metrics(num_cells, species_count, global_iteration,
|
||||
rollback_count);
|
||||
|
||||
if (reference_values.size() != surrogate_values.size()) {
|
||||
@ -137,42 +170,38 @@ void poet::ControlModule::computeSpeciesErrorMetrics(
|
||||
return;
|
||||
}
|
||||
|
||||
for (size_t row = 0; row < reference_values.size(); row++) {
|
||||
double err_sum = 0.0;
|
||||
double sqr_err_sum = 0.0;
|
||||
uint32_t count = 0;
|
||||
for (size_t cell_i = 0; cell_i < num_cells; cell_i++) {
|
||||
|
||||
for (size_t col = 0; col < this->species_names.size(); col++) {
|
||||
const double ref_value = reference_values[row][col];
|
||||
const double sur_value = surrogate_values[row][col];
|
||||
metrics.id.push_back(reference_values[cell_i][0]);
|
||||
|
||||
for (size_t sp_i = 0; sp_i < reference_values[cell_i].size(); sp_i++) {
|
||||
const double ref_value = reference_values[cell_i][sp_i];
|
||||
const double sur_value = surrogate_values[cell_i][sp_i];
|
||||
const double ZERO_ABS = 1e-13;
|
||||
|
||||
if (std::isnan(ref_value) || std::isnan(sur_value)) {
|
||||
metrics.mape[cell_i][sp_i] = 0.0;
|
||||
metrics.rrmse[cell_i][sp_i] = 0.0;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (std::abs(ref_value) < ZERO_ABS) {
|
||||
if (std::abs(sur_value) >= ZERO_ABS) {
|
||||
err_sum += 1.0;
|
||||
sqr_err_sum += 1.0;
|
||||
count++;
|
||||
metrics.mape[cell_i][sp_i] = 1.0;
|
||||
metrics.rrmse[cell_i][sp_i] = 1.0;
|
||||
} else {
|
||||
metrics.mape[cell_i][sp_i] = 0.0;
|
||||
metrics.rrmse[cell_i][sp_i] = 0.0;
|
||||
}
|
||||
// Both zero: skip
|
||||
} else {
|
||||
double alpha = 1.0 - (sur_value / ref_value);
|
||||
err_sum += std::abs(alpha);
|
||||
sqr_err_sum += alpha * alpha;
|
||||
count++;
|
||||
}
|
||||
// Store metrics for this species after processing all cells
|
||||
if (count > 0) {
|
||||
metrics.mape[col] = 100.0 * (err_sum / size_per_prop);
|
||||
metrics.rrmse[col] = std::sqrt(sqr_err_sum / size_per_prop);
|
||||
} else {
|
||||
metrics.mape[col] = 0.0;
|
||||
metrics.rrmse[col] = 0.0;
|
||||
metrics.mape[cell_i][sp_i] = 100.0 * std::abs(alpha);
|
||||
metrics.rrmse[cell_i][sp_i] = alpha * alpha;
|
||||
}
|
||||
}
|
||||
metricsHistory.push_back(metrics);
|
||||
}
|
||||
|
||||
std::cout << "[DEBUG] metrics.id.size()=" << metrics.id.size() << std::endl;
|
||||
metricsHistory.push_back(metrics);
|
||||
std::cout << "[DEBUG] metricsHistory.size()=" << metricsHistory.size() << std::endl;
|
||||
}
|
||||
|
||||
@ -32,26 +32,30 @@ public:
|
||||
bool checkAndRollback(DiffusionModule &diffusion, uint32_t &iter);
|
||||
|
||||
struct SpeciesErrorMetrics {
|
||||
std::vector<double> mape;
|
||||
std::vector<double> rrmse;
|
||||
std::vector<std::uint32_t> id;
|
||||
std::vector<std::vector<double>> mape;
|
||||
std::vector<std::vector<double>> rrmse;
|
||||
uint32_t iteration; // iterations in simulation after rollbacks
|
||||
uint32_t rollback_count;
|
||||
|
||||
SpeciesErrorMetrics(uint32_t species_count, uint32_t iter, uint32_t counter)
|
||||
: mape(species_count, 0.0), rrmse(species_count, 0.0), iteration(iter),
|
||||
rollback_count(counter) {}
|
||||
SpeciesErrorMetrics(uint32_t num_cells, uint32_t species_count,
|
||||
uint32_t iter, uint32_t counter)
|
||||
: mape(num_cells, std::vector<double>(species_count, 0.0)),
|
||||
rrmse(num_cells, std::vector<double>(species_count, 0.0)),
|
||||
iteration(iter), rollback_count(counter) {}
|
||||
};
|
||||
|
||||
void computeSpeciesErrorMetrics(
|
||||
std::vector<std::vector<double>> &reference_values,
|
||||
std::vector<std::vector<double>> &surrogate_values,
|
||||
const uint32_t size_per_prop);
|
||||
std::vector<std::vector<double>> &surrogate_values);
|
||||
|
||||
std::vector<SpeciesErrorMetrics> metricsHistory;
|
||||
|
||||
struct ControlSetup {
|
||||
std::string out_dir;
|
||||
std::uint32_t checkpoint_interval;
|
||||
std::uint32_t penalty_interval;
|
||||
std::uint32_t stabilization_interval;
|
||||
std::vector<std::string> species_names;
|
||||
std::vector<double> mape_threshold;
|
||||
std::vector<uint32_t> ctrl_cell_ids;
|
||||
@ -60,6 +64,7 @@ public:
|
||||
void enableControlLogic(const ControlSetup &setup) {
|
||||
this->out_dir = setup.out_dir;
|
||||
this->checkpoint_interval = setup.checkpoint_interval;
|
||||
this->stabilization_interval = setup.stabilization_interval;
|
||||
this->species_names = setup.species_names;
|
||||
this->mape_threshold = setup.mape_threshold;
|
||||
this->ctrl_cell_ids = setup.ctrl_cell_ids;
|
||||
@ -92,7 +97,8 @@ private:
|
||||
|
||||
poet::ChemistryModule *chem = nullptr;
|
||||
|
||||
std::uint32_t penalty_interval = 50;
|
||||
std::uint32_t stabilization_interval = 0;
|
||||
std::uint32_t penalty_interval = 0;
|
||||
std::uint32_t checkpoint_interval = 0;
|
||||
std::uint32_t global_iteration = 0;
|
||||
std::uint32_t rollback_count = 0;
|
||||
|
||||
@ -21,27 +21,36 @@ namespace poet
|
||||
return;
|
||||
}
|
||||
|
||||
// header
|
||||
out << std::left << std::setw(15) << "Iteration"
|
||||
// header: CellID, Iteration, Rollback, Species, MAPE, RRMSE
|
||||
out << std::left << std::setw(15) << "CellID"
|
||||
<< std::setw(15) << "Iteration"
|
||||
<< std::setw(15) << "Rollback"
|
||||
<< std::setw(15) << "Species"
|
||||
<< std::setw(15) << "MAPE"
|
||||
<< std::setw(15) << "RRSME" << "\n";
|
||||
<< std::setw(15) << "RRMSE" << "\n";
|
||||
|
||||
out << std::string(75, '-') << "\n";
|
||||
out << std::string(90, '-') << "\n";
|
||||
|
||||
// data rows
|
||||
for (size_t i = 0; i < all_stats.size(); ++i)
|
||||
// data rows: iterate over iterations
|
||||
for (size_t iter_idx = 0; iter_idx < all_stats.size(); ++iter_idx)
|
||||
{
|
||||
for (size_t j = 0; j < species_names.size(); ++j)
|
||||
const auto &metrics = all_stats[iter_idx];
|
||||
|
||||
// Iterate over cells
|
||||
for (size_t cell_idx = 0; cell_idx < metrics.id.size(); ++cell_idx)
|
||||
{
|
||||
out << std::left
|
||||
<< std::setw(15) << all_stats[i].iteration
|
||||
<< std::setw(15) << all_stats[i].rollback_count
|
||||
<< std::setw(15) << species_names[j]
|
||||
<< std::setw(15) << all_stats[i].mape[j]
|
||||
<< std::setw(15) << all_stats[i].rrmse[j]
|
||||
<< "\n";
|
||||
// Iterate over species for this cell
|
||||
for (size_t species_idx = 0; species_idx < species_names.size(); ++species_idx)
|
||||
{
|
||||
out << std::left
|
||||
<< std::setw(15) << metrics.id[cell_idx]
|
||||
<< std::setw(15) << metrics.iteration
|
||||
<< std::setw(15) << metrics.rollback_count
|
||||
<< std::setw(15) << species_names[species_idx]
|
||||
<< std::setw(15) << metrics.mape[cell_idx][species_idx]
|
||||
<< std::setw(15) << metrics.rrmse[cell_idx][species_idx]
|
||||
<< "\n";
|
||||
}
|
||||
}
|
||||
out << "\n";
|
||||
}
|
||||
|
||||
@ -252,6 +252,10 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
|
||||
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("timesteps"));
|
||||
params.checkpoint_interval =
|
||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval"));
|
||||
params.stabilization_interval =
|
||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("stabilization_interval"));
|
||||
params.penalty_interval =
|
||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("penalty_interval"));
|
||||
params.mape_threshold = Rcpp::as<std::vector<double>>(
|
||||
global_rt_setup->operator[]("mape_threshold"));
|
||||
params.ctrl_cell_ids = Rcpp::as<std::vector<uint32_t>>(
|
||||
@ -645,6 +649,8 @@ int main(int argc, char *argv[]) {
|
||||
const ControlModule::ControlSetup ctrl_setup = {
|
||||
run_params.out_dir, // added
|
||||
run_params.checkpoint_interval,
|
||||
run_params.penalty_interval,
|
||||
run_params.stabilization_interval,
|
||||
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
||||
run_params.mape_threshold};
|
||||
|
||||
|
||||
@ -50,7 +50,8 @@ struct RuntimeParameters {
|
||||
std::string out_ext;
|
||||
|
||||
bool print_progress = false;
|
||||
|
||||
std::uint32_t penalty_interval = 0;
|
||||
std::uint32_t stabilization_interval = 0;
|
||||
std::uint32_t checkpoint_interval = 0;
|
||||
std::uint32_t control_interval = 0;
|
||||
std::vector<double> mape_threshold;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user