mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-15 20:38:23 +01:00
added rollback_limit and fixed rollback logic
This commit is contained in:
parent
ac6bff6b97
commit
1de30ad0db
@ -1,5 +1,5 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
#SBATCH --job-name=proto2_eps0035_no_rb_v2
|
#SBATCH --job-name=proto2_eps0035
|
||||||
#SBATCH --output=proto2_eps0035_no_rb_v2_%j.out
|
#SBATCH --output=proto2_eps0035_no_rb_v2_%j.out
|
||||||
#SBATCH --error=proto2_eps0035_no_rb_v2%j.err
|
#SBATCH --error=proto2_eps0035_no_rb_v2%j.err
|
||||||
#SBATCH --partition=long
|
#SBATCH --partition=long
|
||||||
@ -15,5 +15,5 @@ module purge
|
|||||||
module load cmake gcc openmpi
|
module load cmake gcc openmpi
|
||||||
|
|
||||||
#mpirun -n 144 ./poet dolo_fgcs_3.R dolo_fgcs_3.qs2 dolo_only_pqc
|
#mpirun -n 144 ./poet dolo_fgcs_3.R dolo_fgcs_3.qs2 dolo_only_pqc
|
||||||
mpirun -n 144 ./poet --interp --rds dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 proto2_eps0035_no_rb_v2
|
mpirun -n 144 ./poet --interp --rds dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 proto2_eps0035
|
||||||
#mpirun -n 144 ./poet --interp barite_fgcs_4_new/barite_fgcs_4_new_rt.R barite_fgcs_4_new/barite_fgcs_4_new.qs2 barite
|
#mpirun -n 144 ./poet --interp barite_fgcs_4_new/barite_fgcs_4_new_rt.R barite_fgcs_4_new/barite_fgcs_4_new.qs2 barite
|
||||||
Binary file not shown.
Binary file not shown.
@ -3,13 +3,12 @@
|
|||||||
#include "IO/HDF5Functions.hpp"
|
#include "IO/HDF5Functions.hpp"
|
||||||
#include "IO/StatsIO.hpp"
|
#include "IO/StatsIO.hpp"
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
poet::ControlModule::ControlModule(const ControlConfig &config_)
|
poet::ControlModule::ControlModule(const ControlConfig &config_) : config(config_) {}
|
||||||
: config(config_) {}
|
|
||||||
|
|
||||||
void poet::ControlModule::beginIteration(ChemistryModule &chem, uint32_t &iter,
|
void poet::ControlModule::beginIteration(ChemistryModule &chem, uint32_t &iter,
|
||||||
const bool &dht_enabled,
|
const bool &dht_enabled, const bool &interp_enabled) {
|
||||||
const bool &interp_enabled) {
|
|
||||||
|
|
||||||
global_iteration = iter;
|
global_iteration = iter;
|
||||||
double prep_a, prep_b;
|
double prep_a, prep_b;
|
||||||
@ -22,35 +21,42 @@ void poet::ControlModule::beginIteration(ChemistryModule &chem, uint32_t &iter,
|
|||||||
}
|
}
|
||||||
|
|
||||||
/* Disables dht and/or interp during stabilzation phase */
|
/* Disables dht and/or interp during stabilzation phase */
|
||||||
void poet::ControlModule::updateStabilizationPhase(ChemistryModule &chem,
|
void poet::ControlModule::updateStabilizationPhase(ChemistryModule &chem, bool dht_enabled,
|
||||||
bool dht_enabled,
|
|
||||||
bool interp_enabled) {
|
bool interp_enabled) {
|
||||||
if (rollback_enabled) {
|
bool in_warmup = (global_iteration <= config.stab_interval);
|
||||||
if (disable_surr_counter > 0) {
|
bool rb_limit_reached = (rollback_count >= 3);
|
||||||
--disable_surr_counter;
|
|
||||||
flush_request = false;
|
if (rollback_enabled && disable_surr_counter > 0) {
|
||||||
std::cout << "Rollback counter: " << std::to_string(disable_surr_counter)
|
--disable_surr_counter;
|
||||||
<< std::endl;
|
std::cout << "Rollback counter: " << std::to_string(disable_surr_counter) << std::endl;
|
||||||
} else {
|
if (disable_surr_counter == 0) {
|
||||||
rollback_enabled = false;
|
rollback_enabled = false;
|
||||||
std::cout << "Rollback stabilization complete, re-enabling surrogate."
|
|
||||||
<< std::endl;
|
|
||||||
}
|
}
|
||||||
|
flush_request = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (global_iteration <= config.stab_interval || rollback_enabled) {
|
/* disable surrogates during warmup, active rollback or after limit */
|
||||||
chem.SetStabEnabled(true);
|
if (in_warmup || rollback_enabled || rb_limit_reached) {
|
||||||
|
chem.SetStabEnabled(!rb_limit_reached);
|
||||||
chem.SetDhtEnabled(false);
|
chem.SetDhtEnabled(false);
|
||||||
chem.SetInterpEnabled(false);
|
chem.SetInterpEnabled(false);
|
||||||
} else {
|
|
||||||
chem.SetStabEnabled(false);
|
if (rb_limit_reached) {
|
||||||
chem.SetDhtEnabled(dht_enabled);
|
std::cout << "Interpolation completly disabled." << std::endl;
|
||||||
chem.SetInterpEnabled(interp_enabled);
|
} else {
|
||||||
|
std::cout << "In stabilization phase." << std::endl;
|
||||||
|
}
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* enable user-requested surrogates */
|
||||||
|
chem.SetStabEnabled(false);
|
||||||
|
chem.SetDhtEnabled(dht_enabled);
|
||||||
|
chem.SetInterpEnabled(interp_enabled);
|
||||||
|
std::cout << "Interpolating." << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
void poet::ControlModule::writeCheckpoint(DiffusionModule &diffusion,
|
void poet::ControlModule::writeCheckpoint(DiffusionModule &diffusion, uint32_t &iter,
|
||||||
uint32_t &iter,
|
|
||||||
const std::string &out_dir) {
|
const std::string &out_dir) {
|
||||||
if (global_iteration % config.checkpoint_interval == 0) {
|
if (global_iteration % config.checkpoint_interval == 0) {
|
||||||
double w_check_a = MPI_Wtime();
|
double w_check_a = MPI_Wtime();
|
||||||
@ -62,44 +68,41 @@ void poet::ControlModule::writeCheckpoint(DiffusionModule &diffusion,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void poet::ControlModule::readCheckpoint(DiffusionModule &diffusion,
|
void poet::ControlModule::readCheckpoint(DiffusionModule &diffusion, uint32_t ¤t_iter,
|
||||||
uint32_t ¤t_iter,
|
uint32_t rollback_iter, const std::string &out_dir) {
|
||||||
uint32_t rollback_iter,
|
|
||||||
const std::string &out_dir) {
|
|
||||||
double r_check_a, r_check_b;
|
double r_check_a, r_check_b;
|
||||||
|
|
||||||
r_check_a = MPI_Wtime();
|
r_check_a = MPI_Wtime();
|
||||||
Checkpoint_s checkpoint_read{.field = diffusion.getField()};
|
Checkpoint_s checkpoint_read{.field = diffusion.getField()};
|
||||||
read_checkpoint(out_dir,
|
read_checkpoint(out_dir, "checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read);
|
||||||
"checkpoint" + std::to_string(rollback_iter) + ".hdf5",
|
|
||||||
checkpoint_read);
|
|
||||||
current_iter = checkpoint_read.iteration;
|
current_iter = checkpoint_read.iteration;
|
||||||
r_check_b = MPI_Wtime();
|
r_check_b = MPI_Wtime();
|
||||||
r_check_t += r_check_b - r_check_a;
|
r_check_t += r_check_b - r_check_a;
|
||||||
}
|
}
|
||||||
|
|
||||||
void poet::ControlModule::writeErrorMetrics(
|
void poet::ControlModule::writeErrorMetrics(uint32_t &iter, const std::string &out_dir,
|
||||||
uint32_t &iter, const std::string &out_dir,
|
const std::vector<std::string> &species) {
|
||||||
const std::vector<std::string> &species) {
|
|
||||||
|
if (rollback_count >= 3) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
double stats_a = MPI_Wtime();
|
double stats_a = MPI_Wtime();
|
||||||
writeSpeciesStatsToCSV(metrics_history, species, out_dir,
|
writeSpeciesStatsToCSV(metrics_history, species, out_dir, "species_overview.csv");
|
||||||
"species_overview.csv");
|
write_metrics(cell_metrics_history, species, out_dir, "metrics_overview.hdf5");
|
||||||
write_metrics(cell_metrics_history, species, out_dir,
|
|
||||||
"metrics_overview.hdf5");
|
|
||||||
double stats_b = MPI_Wtime();
|
double stats_b = MPI_Wtime();
|
||||||
this->stats_t += stats_b - stats_a;
|
this->stats_t += stats_b - stats_a;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t poet::ControlModule::getRollbackIter() {
|
uint32_t poet::ControlModule::getRollbackIter() {
|
||||||
|
|
||||||
uint32_t last_iter = ((global_iteration - 1) / config.checkpoint_interval) *
|
uint32_t last_iter =
|
||||||
config.checkpoint_interval;
|
((global_iteration - 1) / config.checkpoint_interval) * config.checkpoint_interval;
|
||||||
return last_iter;
|
return last_iter;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<uint32_t> poet::ControlModule::getRollbackTarget(
|
std::optional<uint32_t>
|
||||||
const std::vector<std::string> &species) {
|
poet::ControlModule::getRollbackTarget(const std::vector<std::string> &species) {
|
||||||
double r_check_a, r_check_b;
|
double r_check_a, r_check_b;
|
||||||
|
|
||||||
/* Skip threshold checking if already in stabilization phase*/
|
/* Skip threshold checking if already in stabilization phase*/
|
||||||
@ -114,16 +117,16 @@ std::optional<uint32_t> poet::ControlModule::getRollbackTarget(
|
|||||||
|
|
||||||
/* check bounds of threshold vector*/
|
/* check bounds of threshold vector*/
|
||||||
if (sp_i >= config.mape_threshold.size()) {
|
if (sp_i >= config.mape_threshold.size()) {
|
||||||
std::cerr << "Warning: No threshold defined for species " << species[sp_i]
|
std::cerr << "Warning: No threshold defined for species " << species[sp_i] << " at index "
|
||||||
<< " at index " << std::to_string(sp_i) << std::endl;
|
<< std::to_string(sp_i) << std::endl;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (s_hist.mape[sp_i] > config.mape_threshold[sp_i]) {
|
if (s_hist.mape[sp_i] > config.mape_threshold[sp_i]) {
|
||||||
|
|
||||||
const auto &c_hist = cell_metrics_history.back();
|
const auto &c_hist = cell_metrics_history.back();
|
||||||
auto max_it = std::max_element(
|
auto max_it =
|
||||||
c_hist.mape.begin(), c_hist.mape.end(),
|
std::max_element(c_hist.mape.begin(), c_hist.mape.end(),
|
||||||
[sp_i](const auto &a, const auto &b) { return a[sp_i] < b[sp_i]; });
|
[sp_i](const auto &a, const auto &b) { return a[sp_i] < b[sp_i]; });
|
||||||
|
|
||||||
size_t max_idx = std::distance(c_hist.mape.begin(), max_it);
|
size_t max_idx = std::distance(c_hist.mape.begin(), max_it);
|
||||||
uint32_t cell_id = c_hist.id[max_idx];
|
uint32_t cell_id = c_hist.id[max_idx];
|
||||||
@ -134,10 +137,8 @@ std::optional<uint32_t> poet::ControlModule::getRollbackTarget(
|
|||||||
|
|
||||||
std::cout << "Threshold exceeded for " << species[sp_i]
|
std::cout << "Threshold exceeded for " << species[sp_i]
|
||||||
<< " with species-level MAPE = " << std::to_string(s_hist.mape[sp_i])
|
<< " with species-level MAPE = " << std::to_string(s_hist.mape[sp_i])
|
||||||
<< " exceeding threshold = "
|
<< " exceeding threshold = " << std::to_string(config.mape_threshold[sp_i])
|
||||||
<< std::to_string(config.mape_threshold[sp_i])
|
|
||||||
<< ". Worst cell: ID=" << std::to_string(cell_id)
|
<< ". Worst cell: ID=" << std::to_string(cell_id)
|
||||||
<< ", species= " << species[sp_i]
|
|
||||||
<< " with MAPE=" << std::to_string(cell_mape) << std::endl;
|
<< " with MAPE=" << std::to_string(cell_mape) << std::endl;
|
||||||
|
|
||||||
return getRollbackIter();
|
return getRollbackIter();
|
||||||
@ -146,17 +147,20 @@ std::optional<uint32_t> poet::ControlModule::getRollbackTarget(
|
|||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
void poet::ControlModule::computeErrorMetrics(
|
void poet::ControlModule::computeErrorMetrics(std::vector<std::vector<double>> &reference_values,
|
||||||
std::vector<std::vector<double>> &reference_values,
|
std::vector<std::vector<double>> &surrogate_values,
|
||||||
std::vector<std::vector<double>> &surrogate_values,
|
const std::vector<std::string> &species,
|
||||||
const std::vector<std::string> &species, const uint32_t size_per_prop) {
|
const uint32_t size_per_prop) {
|
||||||
|
|
||||||
|
if (rollback_count >= 3) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const uint32_t n_cells = reference_values.size();
|
const uint32_t n_cells = reference_values.size();
|
||||||
const uint32_t n_species = species.size();
|
const uint32_t n_species = species.size();
|
||||||
const double ZERO_ABS = config.zero_abs;
|
const double ZERO_ABS = config.zero_abs;
|
||||||
|
|
||||||
CellErrorMetrics c_metrics(n_cells, n_species, global_iteration,
|
CellErrorMetrics c_metrics(n_cells, n_species, global_iteration, rollback_count);
|
||||||
rollback_count);
|
|
||||||
SpeciesErrorMetrics s_metrics(n_species, global_iteration, rollback_count);
|
SpeciesErrorMetrics s_metrics(n_species, global_iteration, rollback_count);
|
||||||
|
|
||||||
std::vector<double> species_err_sum(n_species, 0.0);
|
std::vector<double> species_err_sum(n_species, 0.0);
|
||||||
@ -194,11 +198,10 @@ void poet::ControlModule::computeErrorMetrics(
|
|||||||
c_metrics.rrmse[cell_i][sp_i] = alpha * alpha;
|
c_metrics.rrmse[cell_i][sp_i] = alpha * alpha;
|
||||||
// Log extreme MAPE values for debugging
|
// Log extreme MAPE values for debugging
|
||||||
if (c_metrics.mape[cell_i][sp_i] > 100.0) {
|
if (c_metrics.mape[cell_i][sp_i] > 100.0) {
|
||||||
std::cout << "WARNING: High MAPE detected - Cell="
|
std::cout << "WARNING: High MAPE detected - Cell=" << c_metrics.id[cell_i]
|
||||||
<< c_metrics.id[cell_i] << ", Species=" << species[sp_i]
|
<< ", Species=" << species[sp_i] << ", MAPE=" << c_metrics.mape[cell_i][sp_i]
|
||||||
<< ", MAPE=" << c_metrics.mape[cell_i][sp_i]
|
<< "%, Ref=" << ref_value << ", Sur=" << sur_value << ", Alpha=" << alpha
|
||||||
<< "%, Ref=" << ref_value << ", Sur=" << sur_value
|
<< std::endl;
|
||||||
<< ", Alpha=" << alpha << std::endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -208,19 +211,43 @@ void poet::ControlModule::computeErrorMetrics(
|
|||||||
s_metrics.rrmse[sp_i] = std::sqrt(species_sqr_sum[sp_i] / size_per_prop);
|
s_metrics.rrmse[sp_i] = std::sqrt(species_sqr_sum[sp_i] / size_per_prop);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sort cell metrics by ID
|
||||||
|
std::vector<size_t> indices(n_cells);
|
||||||
|
std::iota(indices.begin(), indices.end(), 0);
|
||||||
|
std::sort(indices.begin(), indices.end(),
|
||||||
|
[&c_metrics](size_t a, size_t b) { return c_metrics.id[a] < c_metrics.id[b]; });
|
||||||
|
|
||||||
|
// Reorder cell metrics based on sorted indices
|
||||||
|
std::vector<uint32_t> sorted_ids(n_cells);
|
||||||
|
std::vector<std::vector<double>> sorted_mape(n_cells, std::vector<double>(n_species));
|
||||||
|
std::vector<std::vector<double>> sorted_rrmse(n_cells, std::vector<double>(n_species));
|
||||||
|
|
||||||
|
for (size_t i = 0; i < n_cells; i++) {
|
||||||
|
sorted_ids[i] = c_metrics.id[indices[i]];
|
||||||
|
sorted_mape[i] = c_metrics.mape[indices[i]];
|
||||||
|
sorted_rrmse[i] = c_metrics.rrmse[indices[i]];
|
||||||
|
}
|
||||||
|
|
||||||
|
c_metrics.id = std::move(sorted_ids);
|
||||||
|
c_metrics.mape = std::move(sorted_mape);
|
||||||
|
c_metrics.rrmse = std::move(sorted_rrmse);
|
||||||
|
|
||||||
metrics_history.push_back(s_metrics);
|
metrics_history.push_back(s_metrics);
|
||||||
cell_metrics_history.push_back(c_metrics);
|
cell_metrics_history.push_back(c_metrics);
|
||||||
}
|
}
|
||||||
|
|
||||||
void poet::ControlModule::processCheckpoint(
|
void poet::ControlModule::processCheckpoint(DiffusionModule &diffusion, uint32_t ¤t_iter,
|
||||||
DiffusionModule &diffusion, uint32_t ¤t_iter,
|
const std::string &out_dir,
|
||||||
const std::string &out_dir, const std::vector<std::string> &species) {
|
const std::vector<std::string> &species) {
|
||||||
|
|
||||||
// Use max_rollbacks from config, default to 3 if not set
|
// Use max_rollbacks from config, default to 3 if not set
|
||||||
// uint32_t max_rollbacks =
|
// uint32_t max_rollbacks =
|
||||||
// (config.max_rollbacks > 0) ? config.max_rollbacks : 3;
|
// (config.max_rollbacks > 0) ? config.max_rollbacks : 3;
|
||||||
|
|
||||||
if (flush_request /* && rollback_count < 3 */) {
|
if (rollback_count >= 3) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (flush_request && rollback_count < 3) {
|
||||||
uint32_t target = getRollbackIter();
|
uint32_t target = getRollbackIter();
|
||||||
readCheckpoint(diffusion, current_iter, target, out_dir);
|
readCheckpoint(diffusion, current_iter, target, out_dir);
|
||||||
|
|
||||||
@ -228,8 +255,7 @@ void poet::ControlModule::processCheckpoint(
|
|||||||
rollback_count++;
|
rollback_count++;
|
||||||
disable_surr_counter = config.stab_interval;
|
disable_surr_counter = config.stab_interval;
|
||||||
|
|
||||||
std::cout << "Restored checkpoint " << std::to_string(target)
|
std::cout << "Restored checkpoint " << std::to_string(target) << ", surrogate disabled for "
|
||||||
<< ", surrogate disabled for "
|
|
||||||
<< std::to_string(config.stab_interval) << std::endl;
|
<< std::to_string(config.stab_interval) << std::endl;
|
||||||
} else {
|
} else {
|
||||||
writeCheckpoint(diffusion, global_iteration, out_dir);
|
writeCheckpoint(diffusion, global_iteration, out_dir);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user