Add HDF5 metrics tracking and fix rollback stabilization control

This commit is contained in:
rastogi 2025-11-20 11:31:23 +01:00
parent 39458561ff
commit 40ece6cba3
9 changed files with 185 additions and 134 deletions

Binary file not shown.

View File

@ -443,13 +443,13 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
MPI_INT);
}
if (control->shouldBcastFlags()) {
int ftype = CHEM_CTRL_FLAGS;
PropagateFunctionType(ftype);
uint32_t ctrl_flags = buildCtrlFlags(
this->dht_enabled, this->interp_enabled, this->stab_enabled);
ChemBCast(&ctrl_flags, 1, MPI_INT);
}
// if (control->shouldBcastFlags()) {
ftype = CHEM_CTRL_FLAGS;
PropagateFunctionType(ftype);
uint32_t ctrl_flags = buildCtrlFlags(this->dht_enabled, this->interp_enabled,
this->stab_enabled);
ChemBCast(&ctrl_flags, 1, MPI_INT);
//}
ftype = CHEM_WORK_LOOP;
PropagateFunctionType(ftype);
@ -522,8 +522,8 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
chem_field = out_vec;
/* do master stuff */
std::cout << "[DEBUG] control_batch.size() = "
<< this->control_batch.size() << std::endl;
std::cout << "[DEBUG] control_batch.size() = " << this->control_batch.size()
<< std::endl;
if (!this->control_batch.empty()) {
std::cout << "[Master] Processing " << this->control_batch.size()
@ -551,7 +551,6 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
metrics_a = MPI_Wtime();
control->computeErrorMetrics(this->control_batch, surrogate_batch,
prop_names);
control->writeErrorMetrics(ctrl_file_out_dir, prop_names);
metrics_b = MPI_Wtime();

View File

@ -34,25 +34,25 @@ void poet::ControlModule::updateStabilizationPhase(ChemistryModule &chem,
MSG("Rollback counter: " + std::to_string(disable_surr_counter));
} else {
rollback_enabled = false;
MSG("Rollback stabilization complete, re-enabling surrogates");
}
}
bool prev_stab_state = chem.GetStabEnabled();
// user requested DHT/INTERP? keep them disabled but enable warmup-phase so
if (global_iteration <= config.stab_interval || rollback_enabled) {
chem.SetStabEnabled(true);
chem.SetDhtEnabled(false);
chem.SetInterpEnabled(false);
return;
} else {
chem.SetStabEnabled(false);
chem.SetDhtEnabled(dht_enabled);
chem.SetInterpEnabled(interp_enabled);
}
chem.SetStabEnabled(false);
chem.SetDhtEnabled(dht_enabled);
chem.SetInterpEnabled(interp_enabled);
// Mark that we need to broadcast flags if stab phase just ended
if (prev_stab_state && !chem.GetStabEnabled()) {
// Mark that we need to broadcast flags if stab state changed
if (prev_stab_state != chem.GetStabEnabled()) {
stab_phase_ended = true;
}
}
@ -63,8 +63,10 @@ void poet::ControlModule::writeCheckpoint(DiffusionModule &diffusion,
double w_check_a, w_check_b;
w_check_a = MPI_Wtime();
write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5",
{.field = diffusion.getField(), .iteration = iter});
if (global_iteration % config.checkpoint_interval == 0) {
write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5",
{.field = diffusion.getField(), .iteration = iter});
}
w_check_b = MPI_Wtime();
this->w_check_t += w_check_b - w_check_a;
@ -92,7 +94,8 @@ void poet::ControlModule::writeErrorMetrics(
double stats_a, stats_b;
stats_a = MPI_Wtime();
writeStatsToCSV(metrics_history, species, out_dir, "metrics_overview");
writeStatsToCSV(metrics_history, species, out_dir, "overview");
write_metrics(metrics_history, species, out_dir, "metrics_overview");
stats_b = MPI_Wtime();
this->stats_t += stats_b - stats_a;
@ -106,7 +109,13 @@ uint32_t poet::ControlModule::getRollbackIter() {
uint32_t rollback_iter = (last_iter <= last_checkpoint_written)
? last_iter
: last_checkpoint_written;
return rollback_iter;
MSG("getRollbackIter: global_iteration=" + std::to_string(global_iteration) +
", checkpoint_interval=" + std::to_string(config.checkpoint_interval) +
", last_checkpoint_written=" + std::to_string(last_checkpoint_written) +
", returning=" + std::to_string(last_checkpoint_written));
return last_checkpoint_written;
}
std::optional<uint32_t> poet::ControlModule::getRollbackTarget(
@ -116,7 +125,10 @@ std::optional<uint32_t> poet::ControlModule::getRollbackTarget(
if (metrics_history.empty()) {
MSG("No error history yet, skipping rollback check.");
rollback_enabled = false;
// flush_request = false;
return std::nullopt;
}
// Skip threshold checking if already in rollback/stabilization phase
if (rollback_enabled) {
return std::nullopt;
}
@ -147,7 +159,7 @@ std::optional<uint32_t> poet::ControlModule::getRollbackTarget(
}
}
rollback_enabled = false;
// flush_request = false;
flush_request = false;
return std::nullopt;
}
@ -156,6 +168,11 @@ void poet::ControlModule::computeErrorMetrics(
std::vector<std::vector<double>> &surrogate_values,
const std::vector<std::string> &species) {
// Skip metric computation if already in rollback/stabilization phase
if (rollback_enabled) {
return;
}
const uint32_t n_cells = reference_values.size();
SpeciesErrorMetrics metrics(n_cells, species.size(), global_iteration,
@ -166,23 +183,18 @@ void poet::ControlModule::computeErrorMetrics(
metrics.id.push_back(reference_values[cell_i][0]);
for (size_t sp_i = 0; sp_i < species.size(); sp_i++) {
const double ref_value = reference_values[cell_i][sp_i];
const double sur_value = surrogate_values[cell_i][sp_i];
const double ref_value = reference_values[cell_i][sp_i + 1];
const double sur_value = surrogate_values[cell_i][sp_i + 1];
const double ZERO_ABS = config.zero_abs;
if (std::isnan(ref_value) || std::isnan(sur_value)) {
metrics.mape[cell_i][sp_i] = 0.0;
metrics.rrmse[cell_i][sp_i] = 0.0;
continue;
}
if (std::abs(ref_value) < ZERO_ABS) {
if (std::abs(sur_value) >= ZERO_ABS) {
metrics.mape[cell_i][sp_i] = 1.0;
metrics.mape[cell_i][sp_i] = 100.0;
metrics.rrmse[cell_i][sp_i] = 1.0;
} else {
metrics.mape[cell_i][sp_i] = 0.0;
metrics.rrmse[cell_i][sp_i] = 0.0;
}
} else {
double alpha = 1.0 - (sur_value / ref_value);
@ -191,11 +203,7 @@ void poet::ControlModule::computeErrorMetrics(
}
}
}
std::cout << "[DEBUG] metrics.id.size()=" << metrics.id.size() << std::endl;
metrics_history.push_back(metrics);
std::cout << "[DEBUG] metricsHistory.size()=" << metrics_history.size()
<< std::endl;
}
void poet::ControlModule::processCheckpoint(
@ -221,15 +229,10 @@ bool poet::ControlModule::shouldBcastFlags() {
if (global_iteration == 1) {
return true;
}
if (stab_phase_ended) {
stab_phase_ended = false;
return true;
}
if (flush_request) {
return true;
}
return false;
}

View File

@ -4,6 +4,7 @@
#include "Base/Macros.hpp"
#include "Chemistry/ChemistryModule.hpp"
#include "Transport/DiffusionModule.hpp"
#include "IO/HDF5Functions.hpp"
#include "poet.hpp"
#include <cstdint>
@ -37,6 +38,8 @@ struct SpeciesErrorMetrics {
rollback_count(rb_count) {}
};
class ControlModule {
public:

View File

@ -1,9 +1,19 @@
#pragma once
#include <string>
#include "Datatypes.hpp"
#include <string>
#include <vector>
namespace poet {
struct SpeciesErrorMetrics;
}
int write_checkpoint(const std::string &dir_path, const std::string &file_name, struct Checkpoint_s &&checkpoint);
/// Writes a checkpoint to the HDF5 file <dir_path>/<file_name>.
///
/// The file receives the datasets "/MetaParam/Iterations", "/Grid/Names" and
/// "/Grid/Chemistry", taken from the checkpoint's iteration and field.
///
/// @param dir_path   Directory that must already exist.
/// @param file_name  Name of the HDF5 file inside dir_path.
/// @param checkpoint Checkpoint data to store.
/// @return 0 on success, -1 if dir_path does not exist.
int write_checkpoint(const std::string &dir_path, const std::string &file_name,
struct Checkpoint_s &&checkpoint);
int read_checkpoint(const std::string &dir_path, const std::string &file_name, struct Checkpoint_s &checkpoint);
/// Reads a checkpoint from the HDF5 file <dir_path>/<file_name>.
///
/// Loads "/MetaParam/Iterations" into checkpoint.iteration and
/// "/Grid/Chemistry" into checkpoint.field.
///
/// @param dir_path   Directory containing the checkpoint file.
/// @param file_name  Name of the HDF5 file inside dir_path.
/// @param checkpoint Output: filled with the loaded iteration and field.
/// @return 0 on success, -1 if the file does not exist.
int read_checkpoint(const std::string &dir_path, const std::string &file_name,
struct Checkpoint_s &checkpoint);
/// Writes the complete error-metric history to the HDF5 file
/// <dir_path>/<file_name>.
///
/// Each history entry is stored in its own group named
/// "iter_<iteration>_<rollback_count>", holding the datasets "cell_id",
/// "mape" and "rrmse" plus a "meta" dataset whose attributes are
/// species_names, iteration, rollback_count, n_cells and n_species.
///
/// @param metrics_history One entry per recorded control iteration.
/// @param species_names   Species names, stored as an attribute per group.
/// @param dir_path        Directory that must already exist.
/// @param file_name       Name of the HDF5 file inside dir_path.
/// @return 0 on success, -1 if dir_path does not exist.
int write_metrics(const std::vector<poet::SpeciesErrorMetrics> &metrics_history,
const std::vector<std::string> &species_names,
const std::string &dir_path, const std::string &file_name);

View File

@ -1,62 +1,99 @@
#include "IO/StatsIO.hpp"
// #include "IO/StatsIO.hpp"
#include "Control/ControlModule.hpp"
#include "IO/HDF5Functions.hpp"
#include <cstdint>
#include <filesystem>
#include <fstream>
#include <highfive/H5Easy.hpp>
#include <highfive/highfive.hpp>
#include <iomanip>
#include <iostream>
#include <string>
#include <iomanip>
#include <filesystem>
#include <vector>
namespace poet
{
void writeStatsToCSV(const std::vector<SpeciesErrorMetrics> &all_stats,
const std::vector<std::string> &species_names,
const std::string &out_dir,
const std::string &filename)
{
std::filesystem::path full_path = std::filesystem::path(out_dir) / filename;
namespace fs = std::filesystem;
std::ofstream out(full_path);
if (!out.is_open())
{
std::cerr << "Could not open " << filename << " !" << std::endl;
return;
}
int write_metrics(const std::vector<poet::SpeciesErrorMetrics> &metrics_history,
const std::vector<std::string> &species_names,
const std::string &dir_path, const std::string &file_name) {
// header: CellID, Iteration, Rollback, Species, MAPE, RRMSE
out << std::left << std::setw(15) << "CellID"
<< std::setw(15) << "Iteration"
<< std::setw(15) << "Rollback"
<< std::setw(15) << "Species"
<< std::setw(15) << "MAPE"
<< std::setw(15) << "RRMSE" << "\n";
if (!fs::exists(dir_path)) {
std::cerr << "Directory does not exist: " << dir_path << std::endl;
return -1;
}
fs::path file_path = fs::path(dir_path) / file_name;
out << std::string(90, '-') << "\n";
H5Easy::File file(file_path, H5Easy::File::Truncate);
// data rows: iterate over iterations
for (size_t iter_idx = 0; iter_idx < all_stats.size(); ++iter_idx)
{
const auto &metrics = all_stats[iter_idx];
// Iterate over cells
for (size_t cell_idx = 0; cell_idx < metrics.id.size(); ++cell_idx)
{
// Iterate over species for this cell
for (size_t species_idx = 0; species_idx < species_names.size(); ++species_idx)
{
out << std::left
<< std::setw(15) << metrics.id[cell_idx]
<< std::setw(15) << metrics.iteration
<< std::setw(15) << metrics.rollback_count
<< std::setw(15) << species_names[species_idx]
<< std::setw(15) << metrics.mape[cell_idx][species_idx]
<< std::setw(15) << metrics.rrmse[cell_idx][species_idx]
<< "\n";
}
}
out << "\n";
}
for (size_t idx = 0; idx < metrics_history.size(); ++idx) {
const auto &metrics = metrics_history[idx];
std::string grp = "iter_" + std::to_string(metrics.iteration) + "_" +
std::to_string(metrics.rollback_count);
out.close();
std::cout << "Error metrics written to " << out_dir << "/" << filename << "\n";
}
size_t n_cells = metrics.id.size();
size_t n_species = metrics.mape[0].size();
H5Easy::dump(file, grp + "/meta", 0);
// Attach attributes
H5Easy::dumpAttribute(file, grp + "/meta", "species_names", species_names);
H5Easy::dumpAttribute(file, grp + "/meta", "iteration", metrics.iteration);
H5Easy::dumpAttribute(file, grp + "/meta", "rollback_count",
metrics.rollback_count);
H5Easy::dumpAttribute(file, grp + "/meta", "n_cells", n_cells);
H5Easy::dumpAttribute(file, grp + "/meta", "n_species", n_species);
// ─────────────────────────────────────────────
// 2. Real datasets
// ─────────────────────────────────────────────
H5Easy::dump(file, grp + "/cell_id", metrics.id);
H5Easy::dump(file, grp + "/mape", metrics.mape);
H5Easy::dump(file, grp + "/rrmse", metrics.rrmse);
}
return 0;
}
/// Writes the error-metric history as a human-readable table to
/// <out_dir>/<filename>.
///
/// Despite the name, the output is a fixed-width, left-aligned text table
/// (not comma-separated) with one row per (cell, species) pair and the
/// columns CellID, Iteration, Rollback, Species, MAPE and RRMSE. A blank line
/// separates the rows of consecutive history entries.
///
/// @param all_stats     One entry per recorded control iteration.
/// @param species_names Species names; also determines how many species
///                      columns of mape/rrmse are read for each cell.
/// @param out_dir       Directory the file is created in.
/// @param filename      Name of the output file inside out_dir.
///
/// Failures are reported on std::cerr; nothing is thrown or returned.
void writeStatsToCSV(const std::vector<poet::SpeciesErrorMetrics> &all_stats,
                     const std::vector<std::string> &species_names,
                     const std::string &out_dir, const std::string &filename) {
  constexpr int kColWidth = 15;

  const std::filesystem::path full_path =
      std::filesystem::path(out_dir) / filename;

  std::ofstream out(full_path);
  if (!out.is_open()) {
    std::cerr << "Could not open " << filename << " !" << std::endl;
    return;
  }

  // header: CellID, Iteration, Rollback, Species, MAPE, RRMSE
  out << std::left << std::setw(kColWidth) << "CellID" << std::setw(kColWidth)
      << "Iteration" << std::setw(kColWidth) << "Rollback"
      << std::setw(kColWidth) << "Species" << std::setw(kColWidth) << "MAPE"
      << std::setw(kColWidth) << "RRMSE"
      << "\n";
  out << std::string(90, '-') << "\n";

  // One section per recorded iteration, one row per (cell, species) pair.
  for (const auto &metrics : all_stats) {
    for (size_t cell_idx = 0; cell_idx < metrics.id.size(); ++cell_idx) {
      for (size_t species_idx = 0; species_idx < species_names.size();
           ++species_idx) {
        out << std::left << std::setw(kColWidth) << metrics.id[cell_idx]
            << std::setw(kColWidth) << metrics.iteration
            << std::setw(kColWidth) << metrics.rollback_count
            << std::setw(kColWidth) << species_names[species_idx]
            << std::setw(kColWidth) << metrics.mape[cell_idx][species_idx]
            << std::setw(kColWidth) << metrics.rrmse[cell_idx][species_idx]
            << "\n";
      }
    }
    out << "\n";
  }

  out.close();

  // A failed write (e.g. disk full) or a failed close leaves failbit/badbit
  // set on the stream; do not claim success in that case.
  if (out.fail()) {
    std::cerr << "Failed to write error metrics to " << out_dir << "/"
              << filename << "\n";
    return;
  }

  std::cout << "Error metrics written to " << out_dir << "/" << filename
            << "\n";
}
// namespace poet

View File

@ -1,10 +1,6 @@
#include <string>
#include "Control/ControlModule.hpp"
namespace poet
{
void writeStatsToCSV(const std::vector<SpeciesErrorMetrics> &all_stats,
const std::vector<std::string> &species_names,
const std::string &out_dir,
const std::string &filename);
} // namespace poet
/// Writes the error-metric history as a fixed-width text table to
/// <out_dir>/<filename>, one row per (cell, species) pair with the columns
/// CellID, Iteration, Rollback, Species, MAPE and RRMSE.
///
/// @param all_stats     One entry per recorded control iteration.
/// @param species_names Species names printed in the Species column.
/// @param out_dir       Directory the file is created in.
/// @param filename      Name of the output file inside out_dir.
void writeStatsToCSV(const std::vector<poet::SpeciesErrorMetrics> &all_stats,
const std::vector<std::string> &species_names,
const std::string &out_dir, const std::string &filename);
// namespace poet

View File

@ -1,42 +1,45 @@
#include "IO/Datatypes.hpp"
#include <cstdint>
#include <highfive/H5Easy.hpp>
#include <filesystem>
#include <highfive/H5Easy.hpp>
namespace fs = std::filesystem;
int write_checkpoint(const std::string &dir_path, const std::string &file_name, struct Checkpoint_s &&checkpoint){
if (!fs::exists(dir_path)) {
std::cerr << "Directory does not exist: " << dir_path << std::endl;
return -1;
}
fs::path file_path = fs::path(dir_path) / file_name;
// TODO: errorhandling
H5Easy::File file(file_path, H5Easy::File::Overwrite);
int write_checkpoint(const std::string &dir_path, const std::string &file_name,
struct Checkpoint_s &&checkpoint) {
H5Easy::dump(file, "/MetaParam/Iterations", checkpoint.iteration);
H5Easy::dump(file, "/Grid/Names", checkpoint.field.GetProps());
H5Easy::dump(file, "/Grid/Chemistry", checkpoint.field.As2DVector());
if (!fs::exists(dir_path)) {
std::cerr << "Directory does not exist: " << dir_path << std::endl;
return -1;
}
fs::path file_path = fs::path(dir_path) / file_name;
// TODO: errorhandling
H5Easy::File file(file_path, H5Easy::File::Overwrite);
return 0;
H5Easy::dump(file, "/MetaParam/Iterations", checkpoint.iteration);
H5Easy::dump(file, "/Grid/Names", checkpoint.field.GetProps());
H5Easy::dump(file, "/Grid/Chemistry", checkpoint.field.As2DVector());
return 0;
}
int read_checkpoint(const std::string &dir_path, const std::string &file_name, struct Checkpoint_s &checkpoint){
fs::path file_path = fs::path(dir_path) / file_name;
int read_checkpoint(const std::string &dir_path, const std::string &file_name,
struct Checkpoint_s &checkpoint) {
if (!fs::exists(file_path)) {
std::cerr << "File does not exist: " << file_path << std::endl;
return -1;
}
H5Easy::File file(file_path, H5Easy::File::ReadOnly);
fs::path file_path = fs::path(dir_path) / file_name;
checkpoint.iteration = H5Easy::load<uint32_t>(file, "/MetaParam/Iterations");
if (!fs::exists(file_path)) {
std::cerr << "File does not exist: " << file_path << std::endl;
return -1;
}
checkpoint.field = H5Easy::load<std::vector<std::vector<double>>>(file, "/Grid/Chemistry");
H5Easy::File file(file_path, H5Easy::File::ReadOnly);
return 0;
checkpoint.iteration = H5Easy::load<uint32_t>(file, "/MetaParam/Iterations");
checkpoint.field =
H5Easy::load<std::vector<std::vector<double>>>(file, "/Grid/Chemistry");
return 0;
}