Compare commits

...

4 Commits

Author SHA1 Message Date
rastogi
d21ef0b070 Updated .gitignore 2025-11-17 19:49:21 +01:00
rastogi
a34f7d14b6 Clean up .gitignore 2025-11-17 19:33:30 +01:00
rastogi
1ef61cad33 Separate packing of work package outputs into MPI buffer and add worker-side control flush logic 2025-11-17 19:25:02 +01:00
rastogi
374671e6f9 Clean up .gitignore 2025-11-17 19:17:11 +01:00
10 changed files with 360 additions and 291 deletions

13
.gitignore vendored
View File

@ -150,15 +150,4 @@ bin/*
share/
lib/
include/
# But keep these specific files
!bin/barite_fgcs_2.pqi
!bin/barite_fgcs_2.qs2
!bin/barite_fgcs_2.R
!bin/dolo/
!bin/plot/
!bin/dolo_fgcs_3.qs2
!bin/dolo_fgcs_3.R
!bin/dolo_fgcs.pqi
!bin/phreeqc_kin.dat
!bin/run_poet.sh
/.ai/

View File

@ -102,6 +102,8 @@ public:
this->base_totals = setup.base_totals;
this->ctr_file_out_dir = setup.dht_out_dir;
if (this->dht_enabled || this->interp_enabled) {
this->initializeDHT(setup.dht_size_mb, this->params.dht_species,
setup.has_het_ids);
@ -257,7 +259,7 @@ public:
std::vector<int> ai_surrogate_validity_vector;
void SetControlModule(poet::ControlModule *ctrl) { control_module = ctrl; }
void SetControlModule(poet::ControlModule *ctrl) { control = ctrl; }
void SetDhtEnabled(bool enabled) { dht_enabled = enabled; }
bool GetDhtEnabled() const { return dht_enabled; }
@ -265,7 +267,7 @@ public:
void SetInterpEnabled(bool enabled) { interp_enabled = enabled; }
bool GetInterpEnabled() const { return interp_enabled; }
void SetWarmupEnabled(bool enabled) { warmup_enabled = enabled; }
void SetStabEnabled(bool enabled) { stab_enabled = enabled; }
protected:
void initializeDHT(uint32_t size_mb,
@ -280,13 +282,13 @@ protected:
enum {
CHEM_FIELD_INIT,
//CHEM_DHT_ENABLE,
CHEM_DHT_ENABLE,
CHEM_IP_ENABLE,
CHEM_CTRL_ENABLE,
CHEM_CTRL_FLAGS,
CHEM_DHT_SIGNIF_VEC,
CHEM_DHT_SNAPS,
CHEM_DHT_READ_FILE,
//CHEM_WARMUP_PHASE, // Control flag
//CHEM_CTRL_ENABLE, // Control flag
//CHEM_IP_ENABLE,
CHEM_IP_MIN_ENTRIES,
CHEM_IP_SIGNIF_VEC,
CHEM_WORK_LOOP,
@ -295,6 +297,9 @@ protected:
CHEM_AI_BCAST_VALIDITY
};
/* broadcasted only every control iteration */
enum { DHT_ENABLE = 1u << 0, IP_ENABLE = 1u << 1, STAB_ENABLE = 1u << 2 };
enum { LOOP_WORK, LOOP_END, LOOP_CTRL };
enum {
@ -378,9 +383,11 @@ protected:
void BCastStringVec(std::vector<std::string> &io);
int packResultsIntoBuffer(std::vector<double> &mpi_buffer, int base_count,
const WorkPackage &wp,
const WorkPackage &wp_control);
void copyPkgs(const WorkPackage &wp, std::vector<double> &mpi_buffer,
std::size_t offset = 0);
void copyCtrlPkgs(const WorkPackage &pqc_wp, const WorkPackage &surr_wp,
std::vector<double> &mpi_bufffer, int &count);
int comm_size, comm_rank;
MPI_Comm group_comm;
@ -400,7 +407,7 @@ protected:
bool ai_surrogate_enabled{false};
static constexpr uint32_t BUFFER_OFFSET = 6;
static constexpr uint32_t BUFFER_OFFSET = 5;
inline void ChemBCast(void *buf, int count, MPI_Datatype datatype) const {
MPI_Bcast(buf, count, datatype, 0, this->group_comm);
@ -410,6 +417,22 @@ protected:
ChemBCast(&type, 1, MPI_INT);
}
std::string ctr_file_out_dir;
/**
 * Pack the three surrogate switches into a single integer bitmask.
 *
 * @param dht    when true, the DHT_ENABLE bit is set
 * @param interp when true, the IP_ENABLE bit is set
 * @param stab   when true, the STAB_ENABLE bit is set
 * @return bitwise OR of the selected bits; 0 when all arguments are false
 */
inline int buildControlPacket(bool dht, bool interp, bool stab) {
  const int dht_bit = dht ? DHT_ENABLE : 0;
  const int interp_bit = interp ? IP_ENABLE : 0;
  const int stab_bit = stab ? STAB_ENABLE : 0;
  return dht_bit | interp_bit | stab_bit;
}
/**
 * Test a bitmask for a flag.
 *
 * @param flags bitmask to inspect
 * @param type  bit (or bits) to look for
 * @return true when at least one bit of `type` is set in `flags`
 */
inline bool hasFlag(int flags, int type) {
  const int masked = flags & type;
  return masked != 0;
}
double simtime = 0.;
double idle_t = 0.;
double seq_t = 0.;
@ -437,14 +460,12 @@ protected:
std::unique_ptr<PhreeqcRunner> pqc_runner;
poet::ControlModule *control_module = nullptr;
poet::ControlModule *control = nullptr;
std::vector<double> mpi_surr_buffer;
bool control_enabled{false};
bool warmup_enabled{false};
// std::vector<double> sur_shuffled;
bool stab_enabled{false};
};
} // namespace poet

View File

@ -280,10 +280,11 @@ inline void poet::ChemistryModule::MasterSendPkgs(
// current work package start location in field
send_buffer[end_of_wp + 4] = wp_start_index;
// control flags (bitmask)
int flags = (this->interp_enabled ? 1 : 0) | (this->dht_enabled ? 2 : 0) |
(this->warmup_enabled ? 4 : 0) |
(this->control_enabled ? 8 : 0);
send_buffer[end_of_wp + 5] = static_cast<double>(flags);
/* int flags = (this->interp_enabled ? 1 : 0) | (this->dht_enabled ? 2 :
0) | (this->warmup_enabled ? 4 : 0) | (this->control_enabled ? 8 : 0);
send_buffer[end_of_wp + 5] = static_cast<double>(flags);
*/
/* ATTENTION Worker p has rank p+1 */
// MPI_Send(send_buffer, end_of_wp + BUFFER_OFFSET, MPI_DOUBLE, p + 1,
@ -445,16 +446,29 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
MPI_INT);
}
/* broadcast control state once every iteration */
ftype = CHEM_CTRL_ENABLE;
PropagateFunctionType(ftype);
int ctrl =
(this->control_enabled = this->control->getControlIntervalEnabled()) ? 1
: 0;
ChemBCast(&ctrl, 1, MPI_INT);
if (control->shouldBcastFlags()) {
int ftype = CHEM_CTRL_FLAGS;
PropagateFunctionType(ftype);
uint32_t ctrl_flags = buildControlPacket(
this->dht_enabled, this->interp_enabled, this->stab_enabled);
ChemBCast(&ctrl_flags, 1, MPI_INT);
this->mpi_surr_buffer.assign(this->n_cells * this->prop_count, 0.0);
}
ftype = CHEM_WORK_LOOP;
PropagateFunctionType(ftype);
MPI_Barrier(this->group_comm);
this->control_enabled = this->control_module->getControlIntervalEnabled();
if (this->control_enabled) {
this->mpi_surr_buffer.assign(this->n_cells * this->prop_count, 0.0);
}
static uint32_t iteration = 0;
/* start time measurement of sequential part */
@ -466,7 +480,10 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count,
wp_sizes_vector.size());
//this->mpi_surr_buffer.resize(mpi_buffer.size());
// Only resize surrogate buffer if control is enabled
if (this->control_enabled) {
this->mpi_surr_buffer.resize(mpi_buffer.size());
}
/* setup local variables */
pkg_to_send = wp_sizes_vector.size();
@ -541,9 +558,10 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
}
metrics_a = MPI_Wtime();
control_module->computeSpeciesErrorMetrics(out_vec, sur_unshuffled,
this->n_cells);
control->computeErrorMetrics(out_vec, sur_unshuffled, this->n_cells,
prop_names);
metrics_b = MPI_Wtime();
this->metrics_t += metrics_b - metrics_a;
}
@ -556,9 +574,15 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
/* end time measurement of whole chemistry simulation */
std::optional<uint32_t> target = std::nullopt;
if (this->control_enabled) {
target = control->getRollbackTarget(prop_names);
}
int flush = (this->control_enabled && target.has_value()) ? 1 : 0;
/* advise workers to end chemistry iteration */
for (int i = 1; i < this->comm_size; i++) {
MPI_Send(NULL, 0, MPI_DOUBLE, i, LOOP_END, this->group_comm);
MPI_Send(&flush, 1, MPI_INT, i, LOOP_END, this->group_comm);
}
this->simtime += dt;

View File

@ -82,14 +82,22 @@ void poet::ChemistryModule::WorkerLoop() {
std::cout << "Interp_enabled is " << this->interp_enabled << std::endl;
break;
}
*/
case CHEM_CTRL_ENABLE: {
int control_flag = 0;
ChemBCast(&control_flag, 1, MPI_INT);
this->control_enabled = (control_flag == 1);
std::cout << "Control_enabled is " << this->control_enabled << std::endl;
int ctrl = 0;
ChemBCast(&ctrl, 1, MPI_INT);
this->control_enabled = (ctrl == 1);
break;
}
*/
case CHEM_CTRL_FLAGS: {
int flags = 0;
ChemBCast(&flags, 1, MPI_INT);
this->dht_enabled = hasFlag(flags, DHT_ENABLE);
this->interp_enabled = hasFlag(flags, IP_ENABLE);
this->stab_enabled = hasFlag(flags, STAB_ENABLE);
break;
}
case CHEM_WORK_LOOP: {
WorkerProcessPkgs(timings, iteration);
break;
@ -146,6 +154,38 @@ void poet::ChemistryModule::WorkerProcessPkgs(struct worker_s &timings,
}
}
}
/**
 * Copy the per-cell outputs of a work package into a flat buffer.
 *
 * The output of cell `i` is written starting at position
 * `offset + prop_count * i` of `mpi_buffer`. The buffer is not resized here,
 * so the caller has to provide enough room.
 *
 * @param wp         work package whose `output` rows are copied
 * @param mpi_buffer destination buffer
 * @param offset     position in `mpi_buffer` at which the first cell starts
 */
void poet::ChemistryModule::copyPkgs(const WorkPackage &wp,
                                     std::vector<double> &mpi_buffer,
                                     std::size_t offset) {
  std::size_t cell = 0;
  while (cell < wp.size) {
    const auto &row = wp.output[cell];
    auto dest = mpi_buffer.begin() + offset + this->prop_count * cell;
    std::copy(row.begin(), row.end(), dest);
    ++cell;
  }
}
void poet::ChemistryModule::copyCtrlPkgs(const WorkPackage &pqc_wp,
const WorkPackage &surr_wp,
std::vector<double> &mpi_buffer,
int &count) {
std::size_t wp_offset = surr_wp.size * this->prop_count;
mpi_buffer.resize(count + wp_offset);
copyPkgs(pqc_wp, mpi_buffer);
// s_curr_wp only contains the interpolated data
// copy surrogate output after the the pqc output, mpi_buffer[pqc][interp]
for (std::size_t wp_i = 0; wp_i < surr_wp.size; wp_i++) {
if (surr_wp.mapping[wp_i] != CHEM_PQC) {
// only copy if surrogate was used
copyPkgs(surr_wp, mpi_buffer, wp_offset);
} else {
// if pqc was used, copy pqc results again
copyPkgs(pqc_wp, mpi_buffer, wp_offset);
}
}
count += wp_offset;
}
void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
int double_count,
@ -190,17 +230,21 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
wp_start_index = mpi_buffer[count + 4];
// read packed control flags
/*
flags = static_cast<int>(mpi_buffer[count + 5]);
this->interp_enabled = (flags & 1) != 0;
this->dht_enabled = (flags & 2) != 0;
this->warmup_enabled = (flags & 4) != 0;
this->control_enabled = (flags & 8) != 0;
*/
/*std::cout << "warmup_enabled is " << warmup_enabled << ", control_enabled is
"
<< control_enabled << ", dht_enabled is "
<< dht_enabled << ", interp_enabled is " << interp_enabled
<< std::endl;*/
/*
std::cout << "warmup_enabled is " << stab_enabled << ", control_enabled is "
<< control_enabled << ", dht_enabled is " << dht_enabled
<< ", interp_enabled is " << interp_enabled << std::endl;
*/
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
s_curr_wp.input[wp_i] =
@ -209,7 +253,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
}
// std::cout << this->comm_rank << ":" << counter++ << std::endl;
if (dht_enabled || interp_enabled || warmup_enabled) {
if (dht_enabled || interp_enabled || stab_enabled) {
dht->prepareKeys(s_curr_wp.input, dt);
}
@ -259,39 +303,11 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
if (control_enabled) {
ctrl_start = MPI_Wtime();
std::size_t sur_wp_offset = s_curr_wp.size * this->prop_count;
mpi_buffer.resize(count + sur_wp_offset);
for (std::size_t wp_i = 0; wp_i < s_curr_wp_control.size; wp_i++) {
std::copy(s_curr_wp_control.output[wp_i].begin(),
s_curr_wp_control.output[wp_i].end(),
mpi_buffer.begin() + this->prop_count * wp_i);
}
// s_curr_wp only contains the interpolated data
// copy surrogate output after the the pqc output, mpi_buffer[pqc][interp]
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
// only copy if surrogate was used
if (s_curr_wp.mapping[wp_i] != CHEM_PQC) {
std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(),
mpi_buffer.begin() + sur_wp_offset + this->prop_count * wp_i);
} else {
// if pqc was used, copy pqc results again
std::copy(s_curr_wp_control.output[wp_i].begin(),
s_curr_wp_control.output[wp_i].end(),
mpi_buffer.begin() + sur_wp_offset + this->prop_count * wp_i);
}
}
count += sur_wp_offset;
copyCtrlPkgs(s_curr_wp_control, s_curr_wp, mpi_buffer, count);
ctrl_end = MPI_Wtime();
timings.ctrl_t += ctrl_end - ctrl_start;
} else {
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(),
mpi_buffer.begin() + this->prop_count * wp_i);
}
copyPkgs(s_curr_wp, mpi_buffer);
}
/* send results to master */
@ -301,13 +317,13 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
MPI_Isend(mpi_buffer.data(), count, MPI_DOUBLE, 0, mpi_tag, MPI_COMM_WORLD,
&send_req);
if (dht_enabled || interp_enabled || warmup_enabled) {
if (dht_enabled || interp_enabled || stab_enabled) {
/* write results to DHT */
dht_fill_start = MPI_Wtime();
dht->fillDHT(control_enabled ? s_curr_wp_control : s_curr_wp);
dht_fill_end = MPI_Wtime();
if (interp_enabled || warmup_enabled) {
if (interp_enabled || stab_enabled) {
interp->writePairs();
}
timings.dht_fill += dht_fill_end - dht_fill_start;
@ -317,10 +333,20 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
MPI_Wait(&send_req, MPI_STATUS_IGNORE);
}
void poet::ChemistryModule::WorkerPostIter(MPI_Status &prope_status,
void poet::ChemistryModule::WorkerPostIter(MPI_Status &probe_status,
uint32_t iteration) {
MPI_Recv(NULL, 0, MPI_DOUBLE, 0, LOOP_END, this->group_comm,
MPI_STATUS_IGNORE);
int size, flush = 0;
MPI_Get_count(&probe_status, MPI_INT, &size);
if (size == 1) {
MPI_Recv(&flush, size, MPI_INT, probe_status.MPI_SOURCE, LOOP_END,
this->group_comm, MPI_STATUS_IGNORE);
} else {
MPI_Recv(NULL, 0, MPI_INT, probe_status.MPI_SOURCE, LOOP_END,
this->group_comm, MPI_STATUS_IGNORE);
}
if (this->dht_enabled) {
dht_hits.push_back(dht->getHits());
@ -346,7 +372,7 @@ void poet::ChemistryModule::WorkerPostIter(MPI_Status &prope_status,
const auto max_mean_idx =
DHT_get_used_idx_factor(this->interp->getDHTObject(), 1);
if (max_mean_idx >= 2) {
if (max_mean_idx >= 2 || flush) {
DHT_flush(this->interp->getDHTObject());
DHT_flush(this->dht->getDHT());
if (this->comm_rank == 2) {

View File

@ -4,178 +4,166 @@
#include "IO/StatsIO.hpp"
#include <cmath>
void poet::ControlModule::updateControlIteration(const uint32_t &iter,
const bool &dht_enabled,
const bool &interp_enabled) {
/**
 * Construct a control module bound to a chemistry module.
 *
 * @param config_ settings used to initialize the `config` member
 * @param chem_   chemistry module pointer stored in `chem`; must not be null
 *                (checked with assert)
 */
poet::ControlModule::ControlModule(const ControlConfig &config_,
ChemistryModule *chem_)
: config(config_), chem(chem_) {
assert(chem && "ChemistryModule pointer must not be null");
}
void poet::ControlModule::beginIteration(const uint32_t &iter,
const bool &dht_enabled,
const bool &interp_enabled) {
/* dht_enabled and interp_enabled are user settings set before starting the
 * simulation */
double prep_a, prep_b;
prep_a = MPI_Wtime();
if (control_interval == 0) {
if (config.control_interval == 0) {
control_interval_enabled = false;
return;
}
global_iteration = iter;
initiateWarmupPhase(dht_enabled, interp_enabled);
updateStabilizationPhase(dht_enabled, interp_enabled);
control_interval_enabled =
(control_interval > 0 && iter % control_interval == 0);
(config.control_interval > 0 && (iter % config.control_interval == 0));
if (control_interval_enabled) {
MSG("[Control] Control interval enabled at iteration " +
std::to_string(iter));
}
prep_b = MPI_Wtime();
this->prep_t += prep_b - prep_a;
}
void poet::ControlModule::initiateWarmupPhase(bool dht_enabled,
bool interp_enabled) {
void poet::ControlModule::updateStabilizationPhase(bool dht_enabled,
bool interp_enabled) {
if (rollback_enabled) {
if (disable_surr_counter > 0) {
--disable_surr_counter;
MSG("Rollback counter: " + std::to_string(disable_surr_counter));
} else {
rollback_enabled = false;
}
}
// user requested DHT/INTEP? keep them disabled but enable warmup-phase so
if (global_iteration <= control_interval || rollback_enabled) {
chem->SetWarmupEnabled(true);
if (global_iteration <= config.control_interval || rollback_enabled) {
chem->SetStabEnabled(true);
chem->SetDhtEnabled(false);
chem->SetInterpEnabled(false);
MSG("Warmup enabled until next control interval at iteration " +
std::to_string(control_interval) + ".");
if (rollback_enabled) {
if (sur_disabled_counter > 0) {
--sur_disabled_counter;
MSG("Rollback counter: " + std::to_string(sur_disabled_counter));
} else {
rollback_enabled = false;
}
}
return;
}
chem->SetWarmupEnabled(false);
chem->SetStabEnabled(false);
chem->SetDhtEnabled(dht_enabled);
chem->SetInterpEnabled(interp_enabled);
}
void poet::ControlModule::applyControlLogic(ChemistryModule &chem,
uint32_t &iter) {
if (!control_interval_enabled) {
return;
}
writeCheckpointAndMetrics(chem, iter);
if (checkAndRollback(chem, iter) && rollback_count < 3) {
rollback_enabled = true;
rollback_count++;
sur_disabled_counter = control_interval;
MSG("Interpolation disabled for the next " +
std::to_string(control_interval) + ".");
}
}
void poet::ControlModule::writeCheckpointAndMetrics(ChemistryModule &chem,
uint32_t iter) {
double w_check_a, w_check_b, stats_a, stats_b;
MSG("Writing checkpoint of iteration " + std::to_string(iter));
void poet::ControlModule::writeCheckpoint(uint32_t &iter,
const std::string &out_dir) {
double w_check_a, w_check_b;
w_check_a = MPI_Wtime();
write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5",
{.field = chem.getField(), .iteration = iter});
{.field = chem->getField(), .iteration = iter});
w_check_b = MPI_Wtime();
this->w_check_t += w_check_b - w_check_a;
last_checkpoint_written = iter;
}
void poet::ControlModule::readCheckpoint(uint32_t &current_iter,
uint32_t rollback_iter,
const std::string &out_dir) {
double r_check_a, r_check_b;
r_check_a = MPI_Wtime();
Checkpoint_s checkpoint_read{.field = chem->getField()};
read_checkpoint(out_dir,
"checkpoint" + std::to_string(rollback_iter) + ".hdf5",
checkpoint_read);
current_iter = checkpoint_read.iteration;
r_check_b = MPI_Wtime();
r_check_t += r_check_b - r_check_a;
}
void poet::ControlModule::writeErrorMetrics(
const std::string &out_dir, const std::vector<std::string> &species) {
double stats_a, stats_b;
stats_a = MPI_Wtime();
writeStatsToCSV(metricsHistory, species_names, out_dir, "stats_overview");
writeStatsToCSV(metrics_history, species, out_dir, "metrics_overview");
stats_b = MPI_Wtime();
this->stats_t += stats_b - stats_a;
}
bool poet::ControlModule::checkAndRollback(ChemistryModule &chem,
uint32_t &iter) {
/**
 * Determine the iteration of the checkpoint to roll back to.
 *
 * The candidate is the largest multiple of `config.checkpoint_interval` that
 * is not greater than `global_iteration - 1`; the result is capped at
 * `last_checkpoint_written`.
 *
 * @return the candidate iteration, or `last_checkpoint_written` if the
 *         candidate is larger or if the checkpoint interval is 0
 */
uint32_t poet::ControlModule::getRollbackIter() {
  const uint32_t interval = config.checkpoint_interval;

  // An unset interval (default 0) must not reach the division below; fall
  // back to the only checkpoint this module knows about.
  if (interval == 0) {
    return last_checkpoint_written;
  }

  const uint32_t last_iter = ((global_iteration - 1) / interval) * interval;
  return (last_iter <= last_checkpoint_written) ? last_iter
                                                : last_checkpoint_written;
}
std::optional<uint32_t> poet::ControlModule::getRollbackTarget(
const std::vector<std::string> &species) {
double r_check_a, r_check_b;
if (metricsHistory.empty()) {
MSG("No error history yet; skipping rollback check.");
return false;
if (metrics_history.empty()) {
MSG("No error history yet, skipping rollback check.");
flush_request = false;
return std::nullopt;
}
if (rollback_count > 3) {
MSG("Rollback limit reached, skipping rollback.");
flush_request = false;
return std::nullopt;
}
const auto &mape = metricsHistory.back().mape;
for (uint32_t i = 0; i < species_names.size(); ++i) {
if (mape[i] == 0) {
const auto &mape = metrics_history.back().mape;
for (uint32_t i = 0; i < species.size(); ++i) {
// skip Charge
if (i == 4 || mape[i] == 0) {
continue;
}
if (mape[i] > config.mape_threshold[i]) {
if (last_checkpoint_written == 0) {
MSG(" Threshold exceeded but no checkpoint exists yet.");
return std::nullopt;
}
if (mape[i] > mape_threshold[i]) {
uint32_t rollback_iter =
((iter - 1) / checkpoint_interval) * checkpoint_interval;
flush_request = true;
MSG("[THRESHOLD EXCEEDED] " + species_names[i] +
MSG("T hreshold exceeded " + species[i] +
" has MAPE = " + std::to_string(mape[i]) +
" exceeding threshold = " + std::to_string(mape_threshold[i]) +
" → rolling back to iteration " + std::to_string(rollback_iter));
r_check_a = MPI_Wtime();
Checkpoint_s checkpoint_read{.field = chem.getField()};
read_checkpoint(out_dir,
"checkpoint" + std::to_string(rollback_iter) + ".hdf5",
checkpoint_read);
iter = checkpoint_read.iteration;
r_check_b = MPI_Wtime();
r_check_t += r_check_b - r_check_a;
return true;
" exceeding threshold = " + std::to_string(config.mape_threshold[i]));
return getRollbackIter();
}
}
MSG("All species are within their MAPE thresholds.");
return false;
flush_request = false;
return std::nullopt;
}
void poet::ControlModule::computeSpeciesErrorMetrics(
void poet::ControlModule::computeErrorMetrics(
const std::vector<double> &reference_values,
const std::vector<double> &surrogate_values, const uint32_t size_per_prop) {
const std::vector<double> &surrogate_values, const uint32_t size_per_prop,
const std::vector<std::string> &species) {
SpeciesErrorMetrics metrics(this->species_names.size(), global_iteration,
rollback_count);
SpeciesErrorMetrics metrics(species.size(), global_iteration, rollback_count);
if (reference_values.size() != surrogate_values.size()) {
MSG(" Reference and surrogate vectors differ in size: " +
std::to_string(reference_values.size()) + " vs " +
std::to_string(surrogate_values.size()));
return;
}
const std::size_t expected =
static_cast<std::size_t>(this->species_names.size()) * size_per_prop;
if (reference_values.size() < expected) {
std::cerr << "[CTRL ERROR] input vectors too small: expected >= "
<< expected << " entries, got " << reference_values.size()
<< "\n";
return;
}
for (uint32_t i = 0; i < this->species_names.size(); ++i) {
for (uint32_t i = 0; i < species.size(); ++i) {
double err_sum = 0.0;
double sqr_err_sum = 0.0;
uint32_t base_idx = i * size_per_prop;
int count = 0;
for (uint32_t j = 0; j < size_per_prop; ++j) {
const double ref_value = reference_values[base_idx + j];
const double sur_value = surrogate_values[base_idx + j];
const double ZERO_ABS = 1e-13;
const double ZERO_ABS = config.zero_abs;
if (std::isnan(ref_value) || std::isnan(sur_value)) {
continue;
}
if (std::abs(ref_value) < ZERO_ABS) {
if (std::abs(sur_value) >= ZERO_ABS) {
err_sum += 1.0;
@ -192,5 +180,35 @@ void poet::ControlModule::computeSpeciesErrorMetrics(
metrics.mape[i] = 100.0 * (err_sum / size_per_prop);
metrics.rrmse[i] = std::sqrt(sqr_err_sum / size_per_prop);
}
metricsHistory.push_back(metrics);
metrics_history.push_back(metrics);
}
/**
 * Either restore an earlier checkpoint or write a new one at a control
 * iteration.
 *
 * Does nothing unless the control interval is enabled. When a flush was
 * requested and fewer than three rollbacks happened so far, the checkpoint
 * returned by getRollbackIter() is read back, the rollback state is armed and
 * `disable_surr_counter` is set to the control interval. Otherwise a
 * checkpoint for `global_iteration` is written.
 *
 * @param current_iter in/out: overwritten by readCheckpoint on rollback
 * @param out_dir      directory used for reading/writing checkpoints
 * @param species      not used by this function
 */
void poet::ControlModule::processCheckpoint(
    uint32_t &current_iter, const std::string &out_dir,
    const std::vector<std::string> &species) {
  if (!control_interval_enabled) {
    return;
  }

  const bool do_rollback = flush_request && rollback_count < 3;
  if (!do_rollback) {
    writeCheckpoint(global_iteration, out_dir);
    return;
  }

  uint32_t target = getRollbackIter();
  readCheckpoint(current_iter, target, out_dir);

  rollback_enabled = true;
  rollback_count++;
  disable_surr_counter = config.control_interval;

  MSG("Restored checkpoint " + std::to_string(target) +
      ", surrogates disabled for " + std::to_string(config.control_interval));
}
/**
 * Decide whether the control flags have to be broadcast in this iteration.
 *
 * @return true for iteration 1 and for every iteration whose remainder
 *         modulo `config.control_interval` is 1; false otherwise, including
 *         when the control interval is 0
 */
bool poet::ControlModule::shouldBcastFlags() const {
  if (global_iteration == 1) {
    return true;
  }

  // With a control interval of 0 the modulo below would divide by zero
  // (beginIteration() returns early in that case, leaving global_iteration
  // untouched), so report "no broadcast" instead.
  if (config.control_interval == 0) {
    return false;
  }

  // NOTE(review): for control_interval == 1 the remainder is always 0, so
  // only iteration 1 triggers a broadcast - confirm this is intended.
  return global_iteration % config.control_interval == 1;
}

View File

@ -4,8 +4,8 @@
#include "Base/Macros.hpp"
#include "Chemistry/ChemistryModule.hpp"
#include "poet.hpp"
#include <cstdint>
#include <optional>
#include <string>
#include <vector>
@ -13,104 +13,93 @@ namespace poet {
class ChemistryModule;
/** Settings consumed by ControlModule. */
struct ControlConfig {
  /// Control steps run when `iter % control_interval == 0`; 0 disables them.
  uint32_t control_interval{0};
  /// Rollback targets are rounded down to a multiple of this value.
  uint32_t checkpoint_interval{0};
  /// Absolute values below this are treated as zero in the error metrics.
  double zero_abs{0.0};
  /// Per-species limits that the MAPE values are compared against.
  std::vector<double> mape_threshold{};
};
/** Per-species error values recorded for one iteration. */
struct SpeciesErrorMetrics {
  /// One MAPE entry per species.
  std::vector<double> mape;
  /// One RRMSE entry per species.
  std::vector<double> rrmse;
  /// Iteration these metrics were recorded for.
  uint32_t iteration = 0;
  /// Rollback count at the time of recording.
  uint32_t rollback_count = 0;

  /**
   * Create zero-filled metric vectors.
   *
   * @param n_species number of entries in `mape` and `rrmse`
   * @param iter      value stored in `iteration`
   * @param rb_count  value stored in `rollback_count`
   */
  SpeciesErrorMetrics(uint32_t n_species, uint32_t iter, uint32_t rb_count)
      : mape(std::vector<double>(n_species, 0.0)),
        rrmse(std::vector<double>(n_species, 0.0)),
        iteration{iter},
        rollback_count{rb_count} {}
};
class ControlModule {
public:
/* Control configuration*/
explicit ControlModule(const ControlConfig &config, ChemistryModule *chem);
// std::uint32_t global_iter = 0;
// std::uint32_t sur_disabled_counter = 0;
// std::uint32_t rollback_counter = 0;
void beginIteration(const uint32_t &iter, const bool &dht_enabled,
const bool &interp_enaled);
void updateControlIteration(const uint32_t &iter, const bool &dht_enabled,
const bool &interp_enaled);
void writeCheckpoint(uint32_t &iter, const std::string &out_dir);
void initiateWarmupPhase(bool dht_enabled, bool interp_enabled);
void writeErrorMetrics(const std::string &out_dir,
const std::vector<std::string> &species);
bool checkAndRollback(ChemistryModule &chem, uint32_t &iter);
std::optional<uint32_t> getRollbackTarget();
struct SpeciesErrorMetrics {
std::vector<double> mape;
std::vector<double> rrmse;
uint32_t iteration; // iterations in simulation after rollbacks
uint32_t rollback_count;
void computeErrorMetrics(const std::vector<double> &reference_values,
const std::vector<double> &surrogate_values,
const uint32_t size_per_prop,
const std::vector<std::string> &species);
SpeciesErrorMetrics(uint32_t species_count, uint32_t iter, uint32_t counter)
: mape(species_count, 0.0), rrmse(species_count, 0.0), iteration(iter),
rollback_count(counter) {}
};
void processCheckpoint(uint32_t &current_iter,
const std::string &out_dir,
const std::vector<std::string> &species);
void computeSpeciesErrorMetrics(const std::vector<double> &reference_values,
const std::vector<double> &surrogate_values,
const uint32_t size_per_prop);
std::vector<SpeciesErrorMetrics> metricsHistory;
struct ControlSetup {
std::string out_dir;
std::uint32_t checkpoint_interval;
std::uint32_t control_interval;
std::vector<std::string> species_names;
std::vector<double> mape_threshold;
};
void enableControlLogic(const ControlSetup &setup) {
this->out_dir = setup.out_dir;
this->checkpoint_interval = setup.checkpoint_interval;
this->control_interval = setup.control_interval;
this->species_names = setup.species_names;
this->mape_threshold = setup.mape_threshold;
}
std::optional<uint32_t>
getRollbackTarget(const std::vector<std::string> &species);
bool shouldBcastFlags() const;
bool getControlIntervalEnabled() const {
return this->control_interval_enabled;
}
void applyControlLogic(ChemistryModule &chem, uint32_t &iter);
void writeCheckpointAndMetrics(ChemistryModule &chem, uint32_t iter);
auto getGlobalIteration() const noexcept { return global_iteration; }
void setChemistryModule(poet::ChemistryModule *c) { chem = c; }
auto getControlInterval() const { return this->control_interval; }
std::vector<double> getMapeThreshold() const { return this->mape_threshold; }
bool getFlushRequest() const { return flush_request; }
void clearFlushRequest() { flush_request = false; }
/* Profiling getters */
auto getUpdateCtrlLogicTime() const { return this->prep_t; }
auto getWriteCheckpointTime() const { return this->w_check_t; }
auto getReadCheckpointTime() const { return this->r_check_t; }
auto getWriteMetricsTime() const { return this->stats_t; }
double getUpdateCtrlLogicTime() const { return prep_t; }
double getWriteCheckpointTime() const { return w_check_t; }
double getReadCheckpointTime() const { return r_check_t; }
double getWriteMetricsTime() const { return stats_t; }
private:
bool rollback_enabled = false;
bool control_interval_enabled = false;
void updateStabilizationPhase(bool dht_enabled, bool interp_enabled);
poet::ChemistryModule *chem = nullptr;
void readCheckpoint(uint32_t &current_iter,
uint32_t rollback_iter, const std::string &out_dir);
uint32_t getRollbackIter();
ControlConfig config;
ChemistryModule *chem = nullptr;
std::uint32_t checkpoint_interval = 0;
std::uint32_t control_interval = 0;
std::uint32_t global_iteration = 0;
std::uint32_t rollback_count = 0;
std::uint32_t sur_disabled_counter = 0;
std::vector<double> mape_threshold;
std::uint32_t disable_surr_counter = 0;
std::uint32_t last_checkpoint_written = 0;
std::vector<std::string> species_names;
std::string out_dir;
bool rollback_enabled = false;
bool control_interval_enabled = false;
bool flush_request = false;
bool bcast_flags = false;
std::vector<SpeciesErrorMetrics> metrics_history;
double prep_t = 0.;
double r_check_t = 0.;
double w_check_t = 0.;
double stats_t = 0.;
/* Buffer for shuffled surrogate data */
std::vector<double> sur_shuffled;
};
} // namespace poet

View File

@ -7,7 +7,7 @@
namespace poet
{
void writeStatsToCSV(const std::vector<ControlModule::SpeciesErrorMetrics> &all_stats,
void writeStatsToCSV(const std::vector<SpeciesErrorMetrics> &all_stats,
const std::vector<std::string> &species_names,
const std::string &out_dir,
const std::string &filename)

View File

@ -3,7 +3,7 @@
namespace poet
{
void writeStatsToCSV(const std::vector<ControlModule::SpeciesErrorMetrics> &all_stats,
void writeStatsToCSV(const std::vector<SpeciesErrorMetrics> &all_stats,
const std::vector<std::string> &species_names,
const std::string &out_dir,
const std::string &filename);

View File

@ -255,6 +255,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval"));
params.mape_threshold = Rcpp::as<std::vector<double>>(
global_rt_setup->operator[]("mape_threshold"));
params.zero_abs = Rcpp::as<double>(global_rt_setup->operator[]("zero_abs"));
} catch (const std::exception &e) {
ERRMSG("Error while parsing R scripts: " + std::string(e.what()));
return ParseRet::PARSER_ERROR;
@ -300,7 +301,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
double dSimTime{0};
for (uint32_t iter = 1; iter < maxiter + 1; iter++) {
control.updateControlIteration(iter, params.use_dht, params.use_interp);
control.beginIteration(iter, params.use_dht, params.use_interp);
double start_t = MPI_Wtime();
@ -410,7 +411,10 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" +
std::to_string(maxiter));
control.applyControlLogic(chem, iter);
if (control.getControlIntervalEnabled()) {
control.processCheckpoint(iter, params.out_dir, chem.getField().GetProps());
control.writeErrorMetrics(params.out_dir, chem.getField().GetProps());
}
// MSG();
} // END SIMULATION LOOP
@ -435,16 +439,13 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
ctrl_profiling["write_stats"] = control.getWriteMetricsTime();
ctrl_profiling["ctrl_logic_master"] = control.getUpdateCtrlLogicTime();
ctrl_profiling["recv_data_master"] = chem.GetMasterRecvCtrlDataTime();
ctrl_profiling["worker"] =
Rcpp::wrap(chem.GetWorkerControlTimings());
ctrl_profiling["worker"] = Rcpp::wrap(chem.GetWorkerControlTimings());
//if (params.use_dht) {
chem_profiling["dht_hits"] = Rcpp::wrap(chem.GetWorkerDHTHits());
chem_profiling["dht_evictions"] = Rcpp::wrap(chem.GetWorkerDHTEvictions());
chem_profiling["dht_get_time"] = Rcpp::wrap(chem.GetWorkerDHTGetTimings());
chem_profiling["dht_fill_time"] =
Rcpp::wrap(chem.GetWorkerDHTFillTimings());
// if (params.use_dht) {
chem_profiling["dht_hits"] = Rcpp::wrap(chem.GetWorkerDHTHits());
chem_profiling["dht_evictions"] = Rcpp::wrap(chem.GetWorkerDHTEvictions());
chem_profiling["dht_get_time"] = Rcpp::wrap(chem.GetWorkerDHTGetTimings());
chem_profiling["dht_fill_time"] = Rcpp::wrap(chem.GetWorkerDHTFillTimings());
//}
if (params.use_interp) {
@ -607,10 +608,10 @@ int main(int argc, char *argv[]) {
ChemistryModule chemistry(run_params.work_package_size,
init_list.getChemistryInit(), MPI_COMM_WORLD);
ControlModule control;
chemistry.SetControlModule(&control);
control.setChemistryModule(&chemistry);
// ControlModule control;
// chemistry.SetControlModule(&control);
// control.setChemistryModule(&chemistry);
const ChemistryModule::SurrogateSetup surr_setup = {
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
@ -628,15 +629,6 @@ int main(int argc, char *argv[]) {
chemistry.masterEnableSurrogates(surr_setup);
const ControlModule::ControlSetup ctrl_setup = {
run_params.out_dir, // added
run_params.checkpoint_interval,
run_params.control_interval,
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
run_params.mape_threshold};
control.enableControlLogic(ctrl_setup);
if (MY_RANK > 0) {
chemistry.WorkerLoop();
} else {
@ -680,7 +672,16 @@ int main(int argc, char *argv[]) {
chemistry.masterSetField(init_list.getInitialGrid());
Rcpp::List profiling = RunMasterLoop(R, run_params, diffusion, chemistry, control);
ControlConfig config(run_params.control_interval,
run_params.checkpoint_interval, run_params.zero_abs,
run_params.mape_threshold);
ControlModule control(config, &chemistry);
chemistry.SetControlModule(&control);
Rcpp::List profiling =
RunMasterLoop(R, run_params, diffusion, chemistry, control);
MSG("finished simulation loop");

View File

@ -54,6 +54,7 @@ struct RuntimeParameters {
std::uint32_t checkpoint_interval = 0;
std::uint32_t control_interval = 0;
std::vector<double> mape_threshold;
double zero_abs = 0.0;
static constexpr std::uint32_t WORK_PACKAGE_SIZE_DEFAULT = 32;
std::uint32_t work_package_size = WORK_PACKAGE_SIZE_DEFAULT;