mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-15 12:28:22 +01:00
added rollback_limit and fixed rollback logic
This commit is contained in:
parent
ac6bff6b97
commit
1de30ad0db
@ -1,5 +1,5 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=proto2_eps0035_no_rb_v2
|
||||
#SBATCH --job-name=proto2_eps0035
|
||||
#SBATCH --output=proto2_eps0035_no_rb_v2_%j.out
|
||||
#SBATCH --error=proto2_eps0035_no_rb_v2%j.err
|
||||
#SBATCH --partition=long
|
||||
@ -15,5 +15,5 @@ module purge
|
||||
module load cmake gcc openmpi
|
||||
|
||||
#mpirun -n 144 ./poet dolo_fgcs_3.R dolo_fgcs_3.qs2 dolo_only_pqc
|
||||
mpirun -n 144 ./poet --interp --rds dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 proto2_eps0035_no_rb_v2
|
||||
mpirun -n 144 ./poet --interp --rds dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 proto2_eps0035
|
||||
#mpirun -n 144 ./poet --interp barite_fgcs_4_new/barite_fgcs_4_new_rt.R barite_fgcs_4_new/barite_fgcs_4_new.qs2 barite
|
||||
Binary file not shown.
Binary file not shown.
@ -3,13 +3,12 @@
|
||||
#include "IO/HDF5Functions.hpp"
|
||||
#include "IO/StatsIO.hpp"
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
|
||||
poet::ControlModule::ControlModule(const ControlConfig &config_)
|
||||
: config(config_) {}
|
||||
poet::ControlModule::ControlModule(const ControlConfig &config_) : config(config_) {}
|
||||
|
||||
void poet::ControlModule::beginIteration(ChemistryModule &chem, uint32_t &iter,
|
||||
const bool &dht_enabled,
|
||||
const bool &interp_enabled) {
|
||||
const bool &dht_enabled, const bool &interp_enabled) {
|
||||
|
||||
global_iteration = iter;
|
||||
double prep_a, prep_b;
|
||||
@ -22,35 +21,42 @@ void poet::ControlModule::beginIteration(ChemistryModule &chem, uint32_t &iter,
|
||||
}
|
||||
|
||||
/* Disables dht and/or interp during stabilization phase */
|
||||
void poet::ControlModule::updateStabilizationPhase(ChemistryModule &chem,
|
||||
bool dht_enabled,
|
||||
void poet::ControlModule::updateStabilizationPhase(ChemistryModule &chem, bool dht_enabled,
|
||||
bool interp_enabled) {
|
||||
if (rollback_enabled) {
|
||||
if (disable_surr_counter > 0) {
|
||||
--disable_surr_counter;
|
||||
flush_request = false;
|
||||
std::cout << "Rollback counter: " << std::to_string(disable_surr_counter)
|
||||
<< std::endl;
|
||||
} else {
|
||||
bool in_warmup = (global_iteration <= config.stab_interval);
|
||||
bool rb_limit_reached = (rollback_count >= 3);
|
||||
|
||||
if (rollback_enabled && disable_surr_counter > 0) {
|
||||
--disable_surr_counter;
|
||||
std::cout << "Rollback counter: " << std::to_string(disable_surr_counter) << std::endl;
|
||||
if (disable_surr_counter == 0) {
|
||||
rollback_enabled = false;
|
||||
std::cout << "Rollback stabilization complete, re-enabling surrogate."
|
||||
<< std::endl;
|
||||
}
|
||||
flush_request = false;
|
||||
}
|
||||
|
||||
if (global_iteration <= config.stab_interval || rollback_enabled) {
|
||||
chem.SetStabEnabled(true);
|
||||
/* disable surrogates during warmup, active rollback or after limit */
|
||||
if (in_warmup || rollback_enabled || rb_limit_reached) {
|
||||
chem.SetStabEnabled(!rb_limit_reached);
|
||||
chem.SetDhtEnabled(false);
|
||||
chem.SetInterpEnabled(false);
|
||||
} else {
|
||||
chem.SetStabEnabled(false);
|
||||
chem.SetDhtEnabled(dht_enabled);
|
||||
chem.SetInterpEnabled(interp_enabled);
|
||||
|
||||
if (rb_limit_reached) {
|
||||
std::cout << "Interpolation completly disabled." << std::endl;
|
||||
} else {
|
||||
std::cout << "In stabilization phase." << std::endl;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
/* enable user-requested surrogates */
|
||||
chem.SetStabEnabled(false);
|
||||
chem.SetDhtEnabled(dht_enabled);
|
||||
chem.SetInterpEnabled(interp_enabled);
|
||||
std::cout << "Interpolating." << std::endl;
|
||||
}
|
||||
|
||||
void poet::ControlModule::writeCheckpoint(DiffusionModule &diffusion,
|
||||
uint32_t &iter,
|
||||
void poet::ControlModule::writeCheckpoint(DiffusionModule &diffusion, uint32_t &iter,
|
||||
const std::string &out_dir) {
|
||||
if (global_iteration % config.checkpoint_interval == 0) {
|
||||
double w_check_a = MPI_Wtime();
|
||||
@ -62,44 +68,41 @@ void poet::ControlModule::writeCheckpoint(DiffusionModule &diffusion,
|
||||
}
|
||||
}
|
||||
|
||||
void poet::ControlModule::readCheckpoint(DiffusionModule &diffusion,
|
||||
uint32_t &current_iter,
|
||||
uint32_t rollback_iter,
|
||||
const std::string &out_dir) {
|
||||
void poet::ControlModule::readCheckpoint(DiffusionModule &diffusion, uint32_t &current_iter,
|
||||
uint32_t rollback_iter, const std::string &out_dir) {
|
||||
double r_check_a, r_check_b;
|
||||
|
||||
r_check_a = MPI_Wtime();
|
||||
Checkpoint_s checkpoint_read{.field = diffusion.getField()};
|
||||
read_checkpoint(out_dir,
|
||||
"checkpoint" + std::to_string(rollback_iter) + ".hdf5",
|
||||
checkpoint_read);
|
||||
read_checkpoint(out_dir, "checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read);
|
||||
current_iter = checkpoint_read.iteration;
|
||||
r_check_b = MPI_Wtime();
|
||||
r_check_t += r_check_b - r_check_a;
|
||||
}
|
||||
|
||||
void poet::ControlModule::writeErrorMetrics(
|
||||
uint32_t &iter, const std::string &out_dir,
|
||||
const std::vector<std::string> &species) {
|
||||
void poet::ControlModule::writeErrorMetrics(uint32_t &iter, const std::string &out_dir,
|
||||
const std::vector<std::string> &species) {
|
||||
|
||||
if (rollback_count >= 3) {
|
||||
return;
|
||||
}
|
||||
|
||||
double stats_a = MPI_Wtime();
|
||||
writeSpeciesStatsToCSV(metrics_history, species, out_dir,
|
||||
"species_overview.csv");
|
||||
write_metrics(cell_metrics_history, species, out_dir,
|
||||
"metrics_overview.hdf5");
|
||||
writeSpeciesStatsToCSV(metrics_history, species, out_dir, "species_overview.csv");
|
||||
write_metrics(cell_metrics_history, species, out_dir, "metrics_overview.hdf5");
|
||||
double stats_b = MPI_Wtime();
|
||||
this->stats_t += stats_b - stats_a;
|
||||
}
|
||||
|
||||
uint32_t poet::ControlModule::getRollbackIter() {
|
||||
|
||||
uint32_t last_iter = ((global_iteration - 1) / config.checkpoint_interval) *
|
||||
config.checkpoint_interval;
|
||||
uint32_t last_iter =
|
||||
((global_iteration - 1) / config.checkpoint_interval) * config.checkpoint_interval;
|
||||
return last_iter;
|
||||
}
|
||||
|
||||
std::optional<uint32_t> poet::ControlModule::getRollbackTarget(
|
||||
const std::vector<std::string> &species) {
|
||||
std::optional<uint32_t>
|
||||
poet::ControlModule::getRollbackTarget(const std::vector<std::string> &species) {
|
||||
double r_check_a, r_check_b;
|
||||
|
||||
/* Skip threshold checking if already in stabilization phase*/
|
||||
@ -114,16 +117,16 @@ std::optional<uint32_t> poet::ControlModule::getRollbackTarget(
|
||||
|
||||
/* check bounds of threshold vector*/
|
||||
if (sp_i >= config.mape_threshold.size()) {
|
||||
std::cerr << "Warning: No threshold defined for species " << species[sp_i]
|
||||
<< " at index " << std::to_string(sp_i) << std::endl;
|
||||
std::cerr << "Warning: No threshold defined for species " << species[sp_i] << " at index "
|
||||
<< std::to_string(sp_i) << std::endl;
|
||||
continue;
|
||||
}
|
||||
if (s_hist.mape[sp_i] > config.mape_threshold[sp_i]) {
|
||||
|
||||
const auto &c_hist = cell_metrics_history.back();
|
||||
auto max_it = std::max_element(
|
||||
c_hist.mape.begin(), c_hist.mape.end(),
|
||||
[sp_i](const auto &a, const auto &b) { return a[sp_i] < b[sp_i]; });
|
||||
auto max_it =
|
||||
std::max_element(c_hist.mape.begin(), c_hist.mape.end(),
|
||||
[sp_i](const auto &a, const auto &b) { return a[sp_i] < b[sp_i]; });
|
||||
|
||||
size_t max_idx = std::distance(c_hist.mape.begin(), max_it);
|
||||
uint32_t cell_id = c_hist.id[max_idx];
|
||||
@ -134,10 +137,8 @@ std::optional<uint32_t> poet::ControlModule::getRollbackTarget(
|
||||
|
||||
std::cout << "Threshold exceeded for " << species[sp_i]
|
||||
<< " with species-level MAPE = " << std::to_string(s_hist.mape[sp_i])
|
||||
<< " exceeding threshold = "
|
||||
<< std::to_string(config.mape_threshold[sp_i])
|
||||
<< " exceeding threshold = " << std::to_string(config.mape_threshold[sp_i])
|
||||
<< ". Worst cell: ID=" << std::to_string(cell_id)
|
||||
<< ", species= " << species[sp_i]
|
||||
<< " with MAPE=" << std::to_string(cell_mape) << std::endl;
|
||||
|
||||
return getRollbackIter();
|
||||
@ -146,17 +147,20 @@ std::optional<uint32_t> poet::ControlModule::getRollbackTarget(
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
void poet::ControlModule::computeErrorMetrics(
|
||||
std::vector<std::vector<double>> &reference_values,
|
||||
std::vector<std::vector<double>> &surrogate_values,
|
||||
const std::vector<std::string> &species, const uint32_t size_per_prop) {
|
||||
void poet::ControlModule::computeErrorMetrics(std::vector<std::vector<double>> &reference_values,
|
||||
std::vector<std::vector<double>> &surrogate_values,
|
||||
const std::vector<std::string> &species,
|
||||
const uint32_t size_per_prop) {
|
||||
|
||||
if (rollback_count >= 3) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t n_cells = reference_values.size();
|
||||
const uint32_t n_species = species.size();
|
||||
const double ZERO_ABS = config.zero_abs;
|
||||
|
||||
CellErrorMetrics c_metrics(n_cells, n_species, global_iteration,
|
||||
rollback_count);
|
||||
CellErrorMetrics c_metrics(n_cells, n_species, global_iteration, rollback_count);
|
||||
SpeciesErrorMetrics s_metrics(n_species, global_iteration, rollback_count);
|
||||
|
||||
std::vector<double> species_err_sum(n_species, 0.0);
|
||||
@ -194,11 +198,10 @@ void poet::ControlModule::computeErrorMetrics(
|
||||
c_metrics.rrmse[cell_i][sp_i] = alpha * alpha;
|
||||
// Log extreme MAPE values for debugging
|
||||
if (c_metrics.mape[cell_i][sp_i] > 100.0) {
|
||||
std::cout << "WARNING: High MAPE detected - Cell="
|
||||
<< c_metrics.id[cell_i] << ", Species=" << species[sp_i]
|
||||
<< ", MAPE=" << c_metrics.mape[cell_i][sp_i]
|
||||
<< "%, Ref=" << ref_value << ", Sur=" << sur_value
|
||||
<< ", Alpha=" << alpha << std::endl;
|
||||
std::cout << "WARNING: High MAPE detected - Cell=" << c_metrics.id[cell_i]
|
||||
<< ", Species=" << species[sp_i] << ", MAPE=" << c_metrics.mape[cell_i][sp_i]
|
||||
<< "%, Ref=" << ref_value << ", Sur=" << sur_value << ", Alpha=" << alpha
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -208,19 +211,43 @@ void poet::ControlModule::computeErrorMetrics(
|
||||
s_metrics.rrmse[sp_i] = std::sqrt(species_sqr_sum[sp_i] / size_per_prop);
|
||||
}
|
||||
|
||||
// Sort cell metrics by ID
|
||||
std::vector<size_t> indices(n_cells);
|
||||
std::iota(indices.begin(), indices.end(), 0);
|
||||
std::sort(indices.begin(), indices.end(),
|
||||
[&c_metrics](size_t a, size_t b) { return c_metrics.id[a] < c_metrics.id[b]; });
|
||||
|
||||
// Reorder cell metrics based on sorted indices
|
||||
std::vector<uint32_t> sorted_ids(n_cells);
|
||||
std::vector<std::vector<double>> sorted_mape(n_cells, std::vector<double>(n_species));
|
||||
std::vector<std::vector<double>> sorted_rrmse(n_cells, std::vector<double>(n_species));
|
||||
|
||||
for (size_t i = 0; i < n_cells; i++) {
|
||||
sorted_ids[i] = c_metrics.id[indices[i]];
|
||||
sorted_mape[i] = c_metrics.mape[indices[i]];
|
||||
sorted_rrmse[i] = c_metrics.rrmse[indices[i]];
|
||||
}
|
||||
|
||||
c_metrics.id = std::move(sorted_ids);
|
||||
c_metrics.mape = std::move(sorted_mape);
|
||||
c_metrics.rrmse = std::move(sorted_rrmse);
|
||||
|
||||
metrics_history.push_back(s_metrics);
|
||||
cell_metrics_history.push_back(c_metrics);
|
||||
}
|
||||
|
||||
void poet::ControlModule::processCheckpoint(
|
||||
DiffusionModule &diffusion, uint32_t &current_iter,
|
||||
const std::string &out_dir, const std::vector<std::string> &species) {
|
||||
void poet::ControlModule::processCheckpoint(DiffusionModule &diffusion, uint32_t &current_iter,
|
||||
const std::string &out_dir,
|
||||
const std::vector<std::string> &species) {
|
||||
|
||||
// Use max_rollbacks from config, default to 3 if not set
|
||||
// uint32_t max_rollbacks =
|
||||
// (config.max_rollbacks > 0) ? config.max_rollbacks : 3;
|
||||
|
||||
if (flush_request /* && rollback_count < 3 */) {
|
||||
if (rollback_count >= 3) {
|
||||
return;
|
||||
}
|
||||
if (flush_request && rollback_count < 3) {
|
||||
uint32_t target = getRollbackIter();
|
||||
readCheckpoint(diffusion, current_iter, target, out_dir);
|
||||
|
||||
@ -228,8 +255,7 @@ void poet::ControlModule::processCheckpoint(
|
||||
rollback_count++;
|
||||
disable_surr_counter = config.stab_interval;
|
||||
|
||||
std::cout << "Restored checkpoint " << std::to_string(target)
|
||||
<< ", surrogate disabled for "
|
||||
std::cout << "Restored checkpoint " << std::to_string(target) << ", surrogate disabled for "
|
||||
<< std::to_string(config.stab_interval) << std::endl;
|
||||
} else {
|
||||
writeCheckpoint(diffusion, global_iteration, out_dir);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user