diff --git a/bin/run_poet.sh b/bin/run_poet.sh index 7ae67fa84..97b9a33fa 100644 --- a/bin/run_poet.sh +++ b/bin/run_poet.sh @@ -1,5 +1,5 @@ #!/bin/bash -#SBATCH --job-name=proto2_eps0035_no_rb_v2 +#SBATCH --job-name=proto2_eps0035 #SBATCH --output=proto2_eps0035_no_rb_v2_%j.out #SBATCH --error=proto2_eps0035_no_rb_v2%j.err #SBATCH --partition=long @@ -15,5 +15,5 @@ module purge module load cmake gcc openmpi #mpirun -n 144 ./poet dolo_fgcs_3.R dolo_fgcs_3.qs2 dolo_only_pqc -mpirun -n 144 ./poet --interp --rds dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 proto2_eps0035_no_rb_v2 +mpirun -n 144 ./poet --interp --rds dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 proto2_eps0035 #mpirun -n 144 ./poet --interp barite_fgcs_4_new/barite_fgcs_4_new_rt.R barite_fgcs_4_new/barite_fgcs_4_new.qs2 barite \ No newline at end of file diff --git a/share/poet/barite/barite_het.qs2 b/share/poet/barite/barite_het.qs2 index 1534d8aa0..f868a3731 100644 Binary files a/share/poet/barite/barite_het.qs2 and b/share/poet/barite/barite_het.qs2 differ diff --git a/share/poet/surfex/PoetEGU_surfex_500.qs2 b/share/poet/surfex/PoetEGU_surfex_500.qs2 index 19c7da936..3288084dc 100644 Binary files a/share/poet/surfex/PoetEGU_surfex_500.qs2 and b/share/poet/surfex/PoetEGU_surfex_500.qs2 differ diff --git a/src/Control/ControlModule.cpp b/src/Control/ControlModule.cpp index 41d24eb8d..394cb429b 100644 --- a/src/Control/ControlModule.cpp +++ b/src/Control/ControlModule.cpp @@ -3,13 +3,12 @@ #include "IO/HDF5Functions.hpp" #include "IO/StatsIO.hpp" #include +#include -poet::ControlModule::ControlModule(const ControlConfig &config_) - : config(config_) {} +poet::ControlModule::ControlModule(const ControlConfig &config_) : config(config_) {} void poet::ControlModule::beginIteration(ChemistryModule &chem, uint32_t &iter, - const bool &dht_enabled, - const bool &interp_enabled) { + const bool &dht_enabled, const bool &interp_enabled) { global_iteration = iter; double prep_a, prep_b; @@ -22,35 +21,42 @@ void 
poet::ControlModule::beginIteration(ChemistryModule &chem, uint32_t &iter, } /* Disables dht and/or interp during stabilzation phase */ -void poet::ControlModule::updateStabilizationPhase(ChemistryModule &chem, - bool dht_enabled, +void poet::ControlModule::updateStabilizationPhase(ChemistryModule &chem, bool dht_enabled, bool interp_enabled) { - if (rollback_enabled) { - if (disable_surr_counter > 0) { - --disable_surr_counter; - flush_request = false; - std::cout << "Rollback counter: " << std::to_string(disable_surr_counter) - << std::endl; - } else { + bool in_warmup = (global_iteration <= config.stab_interval); + bool rb_limit_reached = (rollback_count >= 3); + + if (rollback_enabled && disable_surr_counter > 0) { + --disable_surr_counter; + std::cout << "Rollback counter: " << std::to_string(disable_surr_counter) << std::endl; + if (disable_surr_counter == 0) { rollback_enabled = false; - std::cout << "Rollback stabilization complete, re-enabling surrogate." - << std::endl; } + flush_request = false; } - if (global_iteration <= config.stab_interval || rollback_enabled) { - chem.SetStabEnabled(true); + /* disable surrogates during warmup, active rollback or after limit */ + if (in_warmup || rollback_enabled || rb_limit_reached) { + chem.SetStabEnabled(!rb_limit_reached); chem.SetDhtEnabled(false); chem.SetInterpEnabled(false); - } else { - chem.SetStabEnabled(false); - chem.SetDhtEnabled(dht_enabled); - chem.SetInterpEnabled(interp_enabled); + + if (rb_limit_reached) { + std::cout << "Interpolation completly disabled." << std::endl; + } else { + std::cout << "In stabilization phase." << std::endl; + } + return; } + + /* enable user-requested surrogates */ + chem.SetStabEnabled(false); + chem.SetDhtEnabled(dht_enabled); + chem.SetInterpEnabled(interp_enabled); + std::cout << "Interpolating." 
<< std::endl; } -void poet::ControlModule::writeCheckpoint(DiffusionModule &diffusion, - uint32_t &iter, +void poet::ControlModule::writeCheckpoint(DiffusionModule &diffusion, uint32_t &iter, const std::string &out_dir) { if (global_iteration % config.checkpoint_interval == 0) { double w_check_a = MPI_Wtime(); @@ -62,44 +68,41 @@ void poet::ControlModule::writeCheckpoint(DiffusionModule &diffusion, } } -void poet::ControlModule::readCheckpoint(DiffusionModule &diffusion, - uint32_t &current_iter, - uint32_t rollback_iter, - const std::string &out_dir) { +void poet::ControlModule::readCheckpoint(DiffusionModule &diffusion, uint32_t &current_iter, + uint32_t rollback_iter, const std::string &out_dir) { double r_check_a, r_check_b; r_check_a = MPI_Wtime(); Checkpoint_s checkpoint_read{.field = diffusion.getField()}; - read_checkpoint(out_dir, - "checkpoint" + std::to_string(rollback_iter) + ".hdf5", - checkpoint_read); + read_checkpoint(out_dir, "checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read); current_iter = checkpoint_read.iteration; r_check_b = MPI_Wtime(); r_check_t += r_check_b - r_check_a; } -void poet::ControlModule::writeErrorMetrics( - uint32_t &iter, const std::string &out_dir, - const std::vector &species) { +void poet::ControlModule::writeErrorMetrics(uint32_t &iter, const std::string &out_dir, + const std::vector &species) { + + if (rollback_count >= 3) { + return; + } double stats_a = MPI_Wtime(); - writeSpeciesStatsToCSV(metrics_history, species, out_dir, - "species_overview.csv"); - write_metrics(cell_metrics_history, species, out_dir, - "metrics_overview.hdf5"); + writeSpeciesStatsToCSV(metrics_history, species, out_dir, "species_overview.csv"); + write_metrics(cell_metrics_history, species, out_dir, "metrics_overview.hdf5"); double stats_b = MPI_Wtime(); this->stats_t += stats_b - stats_a; } uint32_t poet::ControlModule::getRollbackIter() { - uint32_t last_iter = ((global_iteration - 1) / config.checkpoint_interval) * - 
config.checkpoint_interval; + uint32_t last_iter = + ((global_iteration - 1) / config.checkpoint_interval) * config.checkpoint_interval; return last_iter; } -std::optional poet::ControlModule::getRollbackTarget( - const std::vector &species) { +std::optional +poet::ControlModule::getRollbackTarget(const std::vector &species) { double r_check_a, r_check_b; /* Skip threshold checking if already in stabilization phase*/ @@ -114,16 +117,16 @@ std::optional poet::ControlModule::getRollbackTarget( /* check bounds of threshold vector*/ if (sp_i >= config.mape_threshold.size()) { - std::cerr << "Warning: No threshold defined for species " << species[sp_i] - << " at index " << std::to_string(sp_i) << std::endl; + std::cerr << "Warning: No threshold defined for species " << species[sp_i] << " at index " + << std::to_string(sp_i) << std::endl; continue; } if (s_hist.mape[sp_i] > config.mape_threshold[sp_i]) { const auto &c_hist = cell_metrics_history.back(); - auto max_it = std::max_element( - c_hist.mape.begin(), c_hist.mape.end(), - [sp_i](const auto &a, const auto &b) { return a[sp_i] < b[sp_i]; }); + auto max_it = + std::max_element(c_hist.mape.begin(), c_hist.mape.end(), + [sp_i](const auto &a, const auto &b) { return a[sp_i] < b[sp_i]; }); size_t max_idx = std::distance(c_hist.mape.begin(), max_it); uint32_t cell_id = c_hist.id[max_idx]; @@ -134,10 +137,8 @@ std::optional poet::ControlModule::getRollbackTarget( std::cout << "Threshold exceeded for " << species[sp_i] << " with species-level MAPE = " << std::to_string(s_hist.mape[sp_i]) - << " exceeding threshold = " - << std::to_string(config.mape_threshold[sp_i]) + << " exceeding threshold = " << std::to_string(config.mape_threshold[sp_i]) << ". 
Worst cell: ID=" << std::to_string(cell_id) - << ", species= " << species[sp_i] << " with MAPE=" << std::to_string(cell_mape) << std::endl; return getRollbackIter(); @@ -146,17 +147,20 @@ std::optional poet::ControlModule::getRollbackTarget( return std::nullopt; } -void poet::ControlModule::computeErrorMetrics( - std::vector> &reference_values, - std::vector> &surrogate_values, - const std::vector &species, const uint32_t size_per_prop) { +void poet::ControlModule::computeErrorMetrics(std::vector> &reference_values, + std::vector> &surrogate_values, + const std::vector &species, + const uint32_t size_per_prop) { + + if (rollback_count >= 3) { + return; + } const uint32_t n_cells = reference_values.size(); const uint32_t n_species = species.size(); const double ZERO_ABS = config.zero_abs; - CellErrorMetrics c_metrics(n_cells, n_species, global_iteration, - rollback_count); + CellErrorMetrics c_metrics(n_cells, n_species, global_iteration, rollback_count); SpeciesErrorMetrics s_metrics(n_species, global_iteration, rollback_count); std::vector species_err_sum(n_species, 0.0); @@ -194,11 +198,10 @@ void poet::ControlModule::computeErrorMetrics( c_metrics.rrmse[cell_i][sp_i] = alpha * alpha; // Log extreme MAPE values for debugging if (c_metrics.mape[cell_i][sp_i] > 100.0) { - std::cout << "WARNING: High MAPE detected - Cell=" - << c_metrics.id[cell_i] << ", Species=" << species[sp_i] - << ", MAPE=" << c_metrics.mape[cell_i][sp_i] - << "%, Ref=" << ref_value << ", Sur=" << sur_value - << ", Alpha=" << alpha << std::endl; + std::cout << "WARNING: High MAPE detected - Cell=" << c_metrics.id[cell_i] + << ", Species=" << species[sp_i] << ", MAPE=" << c_metrics.mape[cell_i][sp_i] + << "%, Ref=" << ref_value << ", Sur=" << sur_value << ", Alpha=" << alpha + << std::endl; } } } @@ -208,19 +211,43 @@ void poet::ControlModule::computeErrorMetrics( s_metrics.rrmse[sp_i] = std::sqrt(species_sqr_sum[sp_i] / size_per_prop); } + // Sort cell metrics by ID + std::vector 
indices(n_cells); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), + [&c_metrics](size_t a, size_t b) { return c_metrics.id[a] < c_metrics.id[b]; }); + + // Reorder cell metrics based on sorted indices + std::vector sorted_ids(n_cells); + std::vector> sorted_mape(n_cells, std::vector(n_species)); + std::vector> sorted_rrmse(n_cells, std::vector(n_species)); + + for (size_t i = 0; i < n_cells; i++) { + sorted_ids[i] = c_metrics.id[indices[i]]; + sorted_mape[i] = c_metrics.mape[indices[i]]; + sorted_rrmse[i] = c_metrics.rrmse[indices[i]]; + } + + c_metrics.id = std::move(sorted_ids); + c_metrics.mape = std::move(sorted_mape); + c_metrics.rrmse = std::move(sorted_rrmse); + metrics_history.push_back(s_metrics); cell_metrics_history.push_back(c_metrics); } -void poet::ControlModule::processCheckpoint( - DiffusionModule &diffusion, uint32_t &current_iter, - const std::string &out_dir, const std::vector &species) { +void poet::ControlModule::processCheckpoint(DiffusionModule &diffusion, uint32_t &current_iter, + const std::string &out_dir, + const std::vector &species) { // Use max_rollbacks from config, default to 3 if not set // uint32_t max_rollbacks = // (config.max_rollbacks > 0) ? config.max_rollbacks : 3; - if (flush_request /* && rollback_count < 3 */) { + if (rollback_count >= 3) { + return; + } + if (flush_request && rollback_count < 3) { uint32_t target = getRollbackIter(); readCheckpoint(diffusion, current_iter, target, out_dir); @@ -228,8 +255,7 @@ void poet::ControlModule::processCheckpoint( rollback_count++; disable_surr_counter = config.stab_interval; - std::cout << "Restored checkpoint " << std::to_string(target) - << ", surrogate disabled for " + std::cout << "Restored checkpoint " << std::to_string(target) << ", surrogate disabled for " << std::to_string(config.stab_interval) << std::endl; } else { writeCheckpoint(diffusion, global_iteration, out_dir);