feat: Add control logic for surrogate validation with selective PHREEQC re-execution, error metrics (MAPE/RRMSE), and automatic checkpoint/rollback

This commit is contained in:
rastogi 2025-11-06 20:55:02 +01:00
parent d0faa8e6c3
commit 1f70dc4070
20 changed files with 598 additions and 628 deletions

9
.gitignore vendored
View File

@ -152,9 +152,12 @@ lib/
include/ include/
# But keep these specific files # But keep these specific files
!bin/barite_fgcs_2.pqi !bin/compare_qs2.R
!bin/barite_fgcs_2.qs2 !bin/barite_fgcs_3.pqi
!bin/barite_fgcs_2.R !bin/barite_fgcs_4_rt.R
!bin/barite_fgcs_4.R
!bin/barite_fgcs_4.qs2
!bin/db_barite.dat
!bin/dol.pqi !bin/dol.pqi
!bin/dolo_fgcs_3.qs2 !bin/dolo_fgcs_3.qs2
!bin/dolo_fgcs_3.R !bin/dolo_fgcs_3.R

View File

@ -115,21 +115,21 @@ setup <- list(
Chemistry = chemistry_setup # Parameters related to the chemistry process Chemistry = chemistry_setup # Parameters related to the chemistry process
) )
iterations <- 20 iterations <- 200
dt <- 100 dt <- 100
checkpoint_interval <- 10 checkpoint_interval <- 10
control_interval <- 10 #control_interval <- 10
mape_threshold <- rep(3.5e-3, 13) mape_threshold <- rep(3.5e-3, 13)
ctrl_cell_ids <- seq(0, (rows*cols)/2 - 1, by = rows+1) ctrl_cell_ids <- seq(0, (rows*cols)/2 - 1, by = rows+1)
#out_save <- seq(50, iterations, by = 50) out_save <- seq(50, iterations, by = 50)
list( list(
timesteps = rep(dt, iterations), timesteps = rep(dt, iterations),
store_result = TRUE, store_result = FALSE,
#out_save = out_save, out_save = out_save,
checkpoint_interval = checkpoint_interval, checkpoint_interval = checkpoint_interval,
control_interval = control_interval, #control_interval = control_interval,
mape_threshold = mape_threshold, mape_threshold = mape_threshold,
ctrl_cell_ids = ctrl_cell_ids ctrl_cell_ids = ctrl_cell_ids
) )

Binary file not shown.

BIN
bin/poet

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -33,6 +33,7 @@ add_library(POETLib
Chemistry/SurrogateModels/HashFunctions.cpp Chemistry/SurrogateModels/HashFunctions.cpp
Chemistry/SurrogateModels/InterpolationModule.cpp Chemistry/SurrogateModels/InterpolationModule.cpp
Chemistry/SurrogateModels/ProximityHashTable.cpp Chemistry/SurrogateModels/ProximityHashTable.cpp
Control/ControlModule.cpp
) )
set(POET_TUG_APPROACH "Implicit" CACHE STRING "tug numerical approach to use") set(POET_TUG_APPROACH "Implicit" CACHE STRING "tug numerical approach to use")

View File

@ -18,6 +18,7 @@
#include <memory> #include <memory>
#include <mpi.h> #include <mpi.h>
#include <string> #include <string>
#include <unordered_set>
#include <vector> #include <vector>
namespace poet { namespace poet {
@ -259,13 +260,17 @@ public:
void SetControlModule(poet::ControlModule *ctrl) { control_module = ctrl; } void SetControlModule(poet::ControlModule *ctrl) { control_module = ctrl; }
void SetDhtEnabled(bool enabled) { dht_enabled = enabled; } void SetDhtEnabled(bool enabled) { this->dht_enabled = enabled; }
bool GetDhtEnabled() const { return dht_enabled; } bool GetDhtEnabled() const { return this->dht_enabled; }
void SetInterpEnabled(bool enabled) { interp_enabled = enabled; } void SetInterpEnabled(bool enabled) { this->interp_enabled = enabled; }
bool GetInterpEnabled() const { return interp_enabled; } bool GetInterpEnabled() const { return interp_enabled; }
void SetWarmupEnabled(bool enabled) { warmup_enabled = enabled; } void SetWarmupEnabled(bool enabled) { this->warmup_enabled = enabled; }
void SetControlCellIds(const std::vector<uint32_t> &ids) {
this->ctrl_cell_ids = std::unordered_set<uint32_t>(ids.begin(), ids.end());
}
protected: protected:
void initializeDHT(uint32_t size_mb, void initializeDHT(uint32_t size_mb,
@ -323,6 +328,7 @@ protected:
double dht_fill = 0.; double dht_fill = 0.;
double idle_t = 0.; double idle_t = 0.;
double ctrl_t = 0.; double ctrl_t = 0.;
double ctrl_phreeqc_t = 0.;
}; };
struct worker_info_s { struct worker_info_s {
@ -364,6 +370,10 @@ protected:
void WorkerRunWorkPackage(WorkPackage &work_package, double dSimTime, void WorkerRunWorkPackage(WorkPackage &work_package, double dSimTime,
double dTimestep); double dTimestep);
void ProcessControlWorkPackage(std::vector<std::vector<double>> &input,
double current_sim_time, double dt,
struct worker_s &timings);
std::vector<uint32_t> CalculateWPSizesVector(uint32_t n_cells, std::vector<uint32_t> CalculateWPSizesVector(uint32_t n_cells,
uint32_t wp_size) const; uint32_t wp_size) const;
std::vector<double> shuffleField(const std::vector<double> &in_field, std::vector<double> shuffleField(const std::vector<double> &in_field,
@ -444,7 +454,8 @@ protected:
bool control_enabled{false}; bool control_enabled{false};
bool warmup_enabled{false}; bool warmup_enabled{false};
// std::vector<double> sur_shuffled; std::unordered_set<uint32_t> ctrl_cell_ids;
std::vector<std::vector<double>> control_batch;
}; };
} // namespace poet } // namespace poet

View File

@ -341,32 +341,33 @@ inline void poet::ChemistryModule::MasterRecvPkgs(worker_list_t &w_list,
MPI_Get_count(&probe_status, MPI_DOUBLE, &size); MPI_Get_count(&probe_status, MPI_DOUBLE, &size);
MPI_Recv(w_list[p - 1].send_addr, size, MPI_DOUBLE, p, LOOP_WORK, MPI_Recv(w_list[p - 1].send_addr, size, MPI_DOUBLE, p, LOOP_WORK,
this->group_comm, MPI_STATUS_IGNORE); this->group_comm, MPI_STATUS_IGNORE);
handled = true; // Only LOOP_WORK completes a work package
w_list[p - 1].has_work = 0;
pkg_to_recv -= 1;
free_workers++;
break; break;
} }
case LOOP_CTRL: { case LOOP_CTRL: {
recv_ctrl_a = MPI_Wtime(); recv_ctrl_a = MPI_Wtime();
/* layout of buffer is [phreeqc][surrogate] */
MPI_Get_count(&probe_status, MPI_DOUBLE, &size); MPI_Get_count(&probe_status, MPI_DOUBLE, &size);
recv_buffer.resize(size); recv_buffer.resize(size);
MPI_Recv(recv_buffer.data(), size, MPI_DOUBLE, p, LOOP_CTRL, MPI_Recv(recv_buffer.data(), size, MPI_DOUBLE, p, LOOP_CTRL,
this->group_comm, MPI_STATUS_IGNORE); this->group_comm, MPI_STATUS_IGNORE);
int half = size / 2;
std::copy(recv_buffer.begin(), recv_buffer.begin() + half,
w_list[p - 1].send_addr);
/*
if (w_list[p - 1].surrogate_addr == nullptr) {
throw std::runtime_error("MasterRecvPkgs: surrogate_addr is null");
}*/
std::copy(recv_buffer.begin() + (size / 2), recv_buffer.begin() + size,
w_list[p - 1].surrogate_addr);
recv_ctrl_b = MPI_Wtime(); recv_ctrl_b = MPI_Wtime();
recv_ctrl_t += recv_ctrl_b - recv_ctrl_a; recv_ctrl_t += recv_ctrl_b - recv_ctrl_a;
handled = true; // Collect PHREEQC rows for control cells
const std::size_t cells_per_batch =
static_cast<std::size_t>(size) /
static_cast<std::size_t>(this->prop_count);
for (std::size_t i = 0; i < cells_per_batch; i++) {
std::vector<double> cell_output(
recv_buffer.begin() + this->prop_count * i,
recv_buffer.begin() + this->prop_count * (i + 1));
this->control_batch.push_back(std::move(cell_output));
}
break; break;
} }
default: { default: {
@ -374,11 +375,6 @@ inline void poet::ChemistryModule::MasterRecvPkgs(worker_list_t &w_list,
std::to_string(probe_status.MPI_TAG)); std::to_string(probe_status.MPI_TAG));
} }
} }
if (handled) {
w_list[p - 1].has_work = 0;
pkg_to_recv -= 1;
free_workers++;
}
} }
} }
@ -450,11 +446,6 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
MPI_Barrier(this->group_comm); MPI_Barrier(this->group_comm);
this->control_enabled = this->control_module->getControlIntervalEnabled();
if (this->control_enabled) {
this->mpi_surr_buffer.assign(this->n_cells * this->prop_count, 0.0);
}
static uint32_t iteration = 0; static uint32_t iteration = 0;
/* start time measurement of sequential part */ /* start time measurement of sequential part */
@ -473,9 +464,7 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
pkg_to_recv = wp_sizes_vector.size(); pkg_to_recv = wp_sizes_vector.size();
workpointer_t work_pointer = mpi_buffer.begin(); workpointer_t work_pointer = mpi_buffer.begin();
workpointer_t sur_pointer = this->mpi_surr_buffer.begin(); workpointer_t sur_pointer = mpi_buffer.begin();
//(this->control_enabled ? this->mpi_surr_buffer.begin()
// : mpi_buffer.end());
worker_list_t worker_list(this->comm_size - 1); worker_list_t worker_list(this->comm_size - 1);
free_workers = this->comm_size - 1; free_workers = this->comm_size - 1;
@ -523,28 +512,36 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
chem_field = out_vec; chem_field = out_vec;
/* do master stuff */ /* do master stuff */
if (this->control_enabled) { if (!this->control_batch.empty()) {
std::cout << "[Master] Control logic enabled for this iteration." std::cout << "[Master] Processing " << this->control_batch.size()
<< std::endl; << " control cells for comparison." << std::endl;
std::vector<double> sur_unshuffled{mpi_surr_buffer};
shuf_a = MPI_Wtime(); /* using mpi-buffer because we need cell-major layout*/
unshuffleField(this->mpi_surr_buffer, this->n_cells, this->prop_count, std::vector<std::vector<double>> surrogate_batch;
wp_sizes_vector.size(), sur_unshuffled); surrogate_batch.reserve(this->control_batch.size());
shuf_b = MPI_Wtime();
this->shuf_t += shuf_b - shuf_a;
size_t N = out_vec.size(); for (const auto &element : this->control_batch) {
if (N != sur_unshuffled.size()) {
std::cerr << "[MASTER DBG] size mismatch out_vec=" << N for (size_t i = 0; i < this->n_cells; i++) {
<< " sur_unshuffled=" << sur_unshuffled.size() << std::endl; uint32_t curr_cell_id = mpi_buffer[this->prop_count * i];
if (curr_cell_id == element[0]) {
std::vector<double> surrogate_output(
mpi_buffer.begin() + this->prop_count * i,
mpi_buffer.begin() + this->prop_count * (i + 1));
surrogate_batch.push_back(surrogate_output);
break;
}
}
} }
metrics_a = MPI_Wtime(); metrics_a = MPI_Wtime();
control_module->computeSpeciesErrorMetrics(out_vec, sur_unshuffled, control_module->computeSpeciesErrorMetrics(this->control_batch, surrogate_batch, 1);
this->n_cells);
metrics_b = MPI_Wtime(); metrics_b = MPI_Wtime();
this->metrics_t += metrics_b - metrics_a; this->metrics_t += metrics_b - metrics_a;
// Clear for next control iteration
this->control_batch.clear();
} }
/* start time measurement of master chemistry */ /* start time measurement of master chemistry */

View File

@ -59,37 +59,6 @@ void poet::ChemistryModule::WorkerLoop() {
MPI_INT, 0, this->group_comm); MPI_INT, 0, this->group_comm);
break; break;
} }
/*
case CHEM_WARMUP_PHASE: {
int warmup_flag = 0;
ChemBCast(&warmup_flag, 1, MPI_INT);
this->warmup_enabled = (warmup_flag == 1);
//std::cout << "Warmup phase is " << this->warmup_enabled << std::endl;
break;
}
case CHEM_DHT_ENABLE: {
int dht_flag = 0;
ChemBCast(&dht_flag, 1, MPI_INT);
this->dht_enabled = (dht_flag == 1);
//std::cout << "DHT_enabled is " << this->dht_enabled << std::endl;
break;
}
case CHEM_IP_ENABLE: {
int interp_flag = 0;
ChemBCast(&interp_flag, 1, MPI_INT);
this->interp_enabled = (interp_flag == 1);
;
std::cout << "Interp_enabled is " << this->interp_enabled << std::endl;
break;
}
case CHEM_CTRL_ENABLE: {
int control_flag = 0;
ChemBCast(&control_flag, 1, MPI_INT);
this->control_enabled = (control_flag == 1);
std::cout << "Control_enabled is " << this->control_enabled << std::endl;
break;
}
*/
case CHEM_WORK_LOOP: { case CHEM_WORK_LOOP: {
WorkerProcessPkgs(timings, iteration); WorkerProcessPkgs(timings, iteration);
break; break;
@ -147,6 +116,36 @@ void poet::ChemistryModule::WorkerProcessPkgs(struct worker_s &timings,
} }
} }
void poet::ChemistryModule::ProcessControlWorkPackage(
std::vector<std::vector<double>> &input, double current_sim_time, double dt,
struct worker_s &timings) {
double phreeqc_start, phreeqc_end;
if (input.empty()) {
return;
}
poet::WorkPackage control_wp(input.size());
control_wp.input = input;
std::vector<double> mpi_buffer(control_wp.size * this->prop_count);
phreeqc_start = MPI_Wtime();
WorkerRunWorkPackage(control_wp, current_sim_time, dt);
phreeqc_end = MPI_Wtime();
timings.ctrl_phreeqc_t += phreeqc_end - phreeqc_start;
for (std::size_t wp_i = 0; wp_i < control_wp.size; wp_i++) {
std::copy(control_wp.output[wp_i].begin(), control_wp.output[wp_i].end(),
mpi_buffer.begin() + this->prop_count * wp_i);
}
MPI_Request send_req;
MPI_Isend(mpi_buffer.data(), mpi_buffer.size(), MPI_DOUBLE, 0, LOOP_CTRL,
MPI_COMM_WORLD, &send_req);
MPI_Wait(&send_req, MPI_STATUS_IGNORE);
}
void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
int double_count, int double_count,
struct worker_s &timings) { struct worker_s &timings) {
@ -165,6 +164,9 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
int flags; int flags;
std::vector<double> mpi_buffer(count); std::vector<double> mpi_buffer(count);
static int control_cells_processed = 0;
static std::vector<std::vector<double>> control_batch;
/* receive */ /* receive */
MPI_Recv(mpi_buffer.data(), count, MPI_DOUBLE, 0, LOOP_WORK, this->group_comm, MPI_Recv(mpi_buffer.data(), count, MPI_DOUBLE, 0, LOOP_WORK, this->group_comm,
MPI_STATUS_IGNORE); MPI_STATUS_IGNORE);
@ -234,31 +236,28 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
} }
} }
/* if control iteration: create copy surrogate results (output and mappings) /* process cells to be monitored in a seperate workpackage */
and then set them to zero, give this to phreeqc */
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
uint32_t cell_id = s_curr_wp.input[wp_i][0];
if (this->ctrl_cell_ids.find(cell_id) != this->ctrl_cell_ids.end() &&
s_curr_wp.mapping[wp_i] != CHEM_PQC) {
control_batch.push_back(s_curr_wp.input[wp_i]);
control_cells_processed++;
poet::WorkPackage s_curr_wp_control = s_curr_wp; if (control_batch.size() == s_curr_wp.size ||
control_cells_processed == this->ctrl_cell_ids.size()) {
ProcessControlWorkPackage(control_batch, current_sim_time, dt, timings);
control_batch.clear();
/* control_cells_processed = 0;
if (control_enabled) { }
ctrl_cp_start = MPI_Wtime();
for (std::size_t wp_i = 0; wp_i < s_curr_wp_control.size; wp_i++) {
s_curr_wp_control.output[wp_i] = std::vector<double>(this->prop_count, 0.0);
s_curr_wp_control.mapping[wp_i] = CHEM_PQC;
} }
ctrl_cp_end = MPI_Wtime();
timings.ctrl_t += ctrl_cp_end - ctrl_cp_start;
} }
*/
phreeqc_time_start = MPI_Wtime(); phreeqc_time_start = MPI_Wtime();
WorkerRunWorkPackage(control_enabled ? s_curr_wp_control : s_curr_wp, WorkerRunWorkPackage(s_curr_wp, current_sim_time, dt);
current_sim_time, dt);
phreeqc_time_end = MPI_Wtime(); phreeqc_time_end = MPI_Wtime();
@ -267,53 +266,15 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
mpi_buffer.begin() + this->prop_count * wp_i); mpi_buffer.begin() + this->prop_count * wp_i);
} }
/*
if (control_enabled) {
ctrl_start = MPI_Wtime();
std::size_t sur_wp_offset = s_curr_wp.size * this->prop_count;
mpi_buffer.resize(count + sur_wp_offset);
for (std::size_t wp_i = 0; wp_i < s_curr_wp_control.size; wp_i++) {
std::copy(s_curr_wp_control.output[wp_i].begin(),
s_curr_wp_control.output[wp_i].end(),
mpi_buffer.begin() + this->prop_count * wp_i);
}
// s_curr_wp only contains the interpolated data
// copy surrogate output after the the pqc output, mpi_buffer[pqc][interp]
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
// only copy if surrogate was used
if (s_curr_wp.mapping[wp_i] != CHEM_PQC) {
std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(),
mpi_buffer.begin() + sur_wp_offset + this->prop_count * wp_i);
} else {
// if pqc was used, copy pqc results again
std::copy(s_curr_wp_control.output[wp_i].begin(),
s_curr_wp_control.output[wp_i].end(),
mpi_buffer.begin() + sur_wp_offset + this->prop_count * wp_i);
}
}
count += sur_wp_offset;
ctrl_end = MPI_Wtime();
timings.ctrl_t += ctrl_end - ctrl_start;
} else {
}
*/
/* send results to master */ /* send results to master */
MPI_Request send_req; MPI_Request send_req;
MPI_Isend(mpi_buffer.data(), count, MPI_DOUBLE, 0, LOOP_WORK, MPI_COMM_WORLD,
int mpi_tag = control_enabled ? LOOP_CTRL : LOOP_WORK;
MPI_Isend(mpi_buffer.data(), count, MPI_DOUBLE, 0, mpi_tag, MPI_COMM_WORLD,
&send_req); &send_req);
if (dht_enabled || interp_enabled || warmup_enabled) { if (dht_enabled || interp_enabled || warmup_enabled) {
/* write results to DHT */ /* write results to DHT */
dht_fill_start = MPI_Wtime(); dht_fill_start = MPI_Wtime();
dht->fillDHT(control_enabled ? s_curr_wp_control : s_curr_wp); dht->fillDHT(s_curr_wp);
dht_fill_end = MPI_Wtime(); dht_fill_end = MPI_Wtime();
if (interp_enabled || warmup_enabled) { if (interp_enabled || warmup_enabled) {

View File

@ -13,20 +13,10 @@ void poet::ControlModule::updateControlIteration(const uint32_t &iter,
double prep_a, prep_b; double prep_a, prep_b;
prep_a = MPI_Wtime(); prep_a = MPI_Wtime();
if (control_interval == 0) {
control_interval_enabled = false;
return;
}
global_iteration = iter; global_iteration = iter;
initiateWarmupPhase(dht_enabled, interp_enabled); initiateWarmupPhase(dht_enabled, interp_enabled);
control_interval_enabled =
(control_interval > 0 && iter % control_interval == 0);
if (control_interval_enabled) {
MSG("[Control] Control interval enabled at iteration " +
std::to_string(iter));
}
prep_b = MPI_Wtime(); prep_b = MPI_Wtime();
this->prep_t += prep_b - prep_a; this->prep_t += prep_b - prep_a;
} }
@ -35,21 +25,20 @@ void poet::ControlModule::initiateWarmupPhase(bool dht_enabled,
bool interp_enabled) { bool interp_enabled) {
// user requested DHT/INTEP? keep them disabled but enable warmup-phase so // user requested DHT/INTEP? keep them disabled but enable warmup-phase so
if (global_iteration <= control_interval || rollback_enabled) { if (rollback_enabled) {
chem->SetWarmupEnabled(true); chem->SetWarmupEnabled(true);
chem->SetDhtEnabled(false); chem->SetDhtEnabled(false);
chem->SetInterpEnabled(false); chem->SetInterpEnabled(false);
MSG("Warmup enabled until next control interval at iteration " +
std::to_string(control_interval) + ".");
if (rollback_enabled) { MSG("Warmup enabled until next control interval at iteration " +
std::to_string(penalty_interval) + ".");
if (sur_disabled_counter > 0) { if (sur_disabled_counter > 0) {
--sur_disabled_counter; --sur_disabled_counter;
MSG("Rollback counter: " + std::to_string(sur_disabled_counter)); MSG("Rollback counter: " + std::to_string(sur_disabled_counter));
} else { } else {
rollback_enabled = false; rollback_enabled = false;
} }
}
return; return;
} }
@ -58,23 +47,23 @@ void poet::ControlModule::initiateWarmupPhase(bool dht_enabled,
chem->SetInterpEnabled(interp_enabled); chem->SetInterpEnabled(interp_enabled);
} }
void poet::ControlModule::applyControlLogic(ChemistryModule &chem, void poet::ControlModule::applyControlLogic(DiffusionModule &diffusion,
uint32_t &iter) { uint32_t &iter) {
if (!control_interval_enabled) {
return;
}
writeCheckpointAndMetrics(chem, iter);
if (checkAndRollback(chem, iter) && rollback_count < 4) { writeCheckpointAndMetrics(diffusion, iter);
if (checkAndRollback(diffusion, iter) && rollback_count < 3) {
rollback_enabled = true; rollback_enabled = true;
rollback_count++; rollback_count++;
sur_disabled_counter = control_interval; sur_disabled_counter = penalty_interval;
MSG("Interpolation disabled for the next " + MSG("Interpolation disabled for the next " +
std::to_string(control_interval) + "."); std::to_string(penalty_interval) + ".");
} }
} }
void poet::ControlModule::writeCheckpointAndMetrics(ChemistryModule &chem, void poet::ControlModule::writeCheckpointAndMetrics(DiffusionModule &diffusion,
uint32_t iter) { uint32_t iter) {
double w_check_a, w_check_b, stats_a, stats_b; double w_check_a, w_check_b, stats_a, stats_b;
@ -82,7 +71,7 @@ void poet::ControlModule::writeCheckpointAndMetrics(ChemistryModule &chem,
w_check_a = MPI_Wtime(); w_check_a = MPI_Wtime();
write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5", write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5",
{.field = chem.getField(), .iteration = iter}); {.field = diffusion.getField(), .iteration = iter});
w_check_b = MPI_Wtime(); w_check_b = MPI_Wtime();
this->w_check_t += w_check_b - w_check_a; this->w_check_t += w_check_b - w_check_a;
@ -93,7 +82,7 @@ void poet::ControlModule::writeCheckpointAndMetrics(ChemistryModule &chem,
this->stats_t += stats_b - stats_a; this->stats_t += stats_b - stats_a;
} }
bool poet::ControlModule::checkAndRollback(ChemistryModule &chem, bool poet::ControlModule::checkAndRollback(DiffusionModule &diffusion,
uint32_t &iter) { uint32_t &iter) {
double r_check_a, r_check_b; double r_check_a, r_check_b;
@ -119,7 +108,7 @@ bool poet::ControlModule::checkAndRollback(ChemistryModule &chem,
" → rolling back to iteration " + std::to_string(rollback_iter)); " → rolling back to iteration " + std::to_string(rollback_iter));
r_check_a = MPI_Wtime(); r_check_a = MPI_Wtime();
Checkpoint_s checkpoint_read{.field = chem.getField()}; Checkpoint_s checkpoint_read{.field = diffusion.getField()};
read_checkpoint(out_dir, read_checkpoint(out_dir,
"checkpoint" + std::to_string(rollback_iter) + ".hdf5", "checkpoint" + std::to_string(rollback_iter) + ".hdf5",
checkpoint_read); checkpoint_read);
@ -135,8 +124,9 @@ bool poet::ControlModule::checkAndRollback(ChemistryModule &chem,
} }
void poet::ControlModule::computeSpeciesErrorMetrics( void poet::ControlModule::computeSpeciesErrorMetrics(
const std::vector<double> &reference_values, std::vector<std::vector<double>> &reference_values,
const std::vector<double> &surrogate_values, const uint32_t size_per_prop) { std::vector<std::vector<double>> &surrogate_values,
const uint32_t size_per_prop) {
SpeciesErrorMetrics metrics(this->species_names.size(), global_iteration, SpeciesErrorMetrics metrics(this->species_names.size(), global_iteration,
rollback_count); rollback_count);
@ -148,25 +138,16 @@ void poet::ControlModule::computeSpeciesErrorMetrics(
return; return;
} }
const std::size_t expected = // Loop over species (rows in the data structure)
static_cast<std::size_t>(this->species_names.size()) * size_per_prop; for (size_t species_idx = 0; species_idx < reference_values.size(); species_idx++) {
if (reference_values.size() < expected) {
std::cerr << "[CTRL ERROR] input vectors too small: expected >= "
<< expected << " entries, got " << reference_values.size()
<< "\n";
return;
}
for (uint32_t i = 0; i < this->species_names.size(); ++i) {
double err_sum = 0.0; double err_sum = 0.0;
double sqr_err_sum = 0.0; double sqr_err_sum = 0.0;
uint32_t base_idx = i * size_per_prop; uint32_t count = 0;
int count = 0; // Loop over control cells (columns in the data structure)
for (size_t cell_idx = 0; cell_idx < size_per_prop; cell_idx++) {
for (uint32_t j = 0; j < size_per_prop; ++j) { const double ref_value = reference_values[species_idx][cell_idx];
const double ref_value = reference_values[base_idx + j]; const double sur_value = surrogate_values[species_idx][cell_idx];
const double sur_value = surrogate_values[base_idx + j];
const double ZERO_ABS = 1e-13; const double ZERO_ABS = 1e-13;
if (std::isnan(ref_value) || std::isnan(sur_value)) { if (std::isnan(ref_value) || std::isnan(sur_value)) {
@ -177,17 +158,28 @@ void poet::ControlModule::computeSpeciesErrorMetrics(
if (std::abs(sur_value) >= ZERO_ABS) { if (std::abs(sur_value) >= ZERO_ABS) {
err_sum += 1.0; err_sum += 1.0;
sqr_err_sum += 1.0; sqr_err_sum += 1.0;
count++;
} }
// Both zero: skip (don't increment count)
} }
// Both zero: skip
else { else {
double alpha = 1.0 - (sur_value / ref_value); double alpha = 1.0 - (sur_value / ref_value);
err_sum += std::abs(alpha); err_sum += std::abs(alpha);
sqr_err_sum += alpha * alpha; sqr_err_sum += alpha * alpha;
count++;
} }
} }
metrics.mape[i] = 100.0 * (err_sum / size_per_prop);
metrics.rrmse[i] = std::sqrt(sqr_err_sum / size_per_prop); // Store metrics for this species after processing all cells
if (count > 0) {
metrics.mape[species_idx] = 100.0 * (err_sum / size_per_prop);
metrics.rrmse[species_idx] = std::sqrt(sqr_err_sum / size_per_prop);
} else {
metrics.mape[species_idx] = 0.0;
metrics.rrmse[species_idx] = 0.0;
} }
}
// Push metrics to history once after processing all species
metricsHistory.push_back(metrics); metricsHistory.push_back(metrics);
} }

View File

@ -3,6 +3,7 @@
#include "Base/Macros.hpp" #include "Base/Macros.hpp"
#include "Chemistry/ChemistryModule.hpp" #include "Chemistry/ChemistryModule.hpp"
#include "Transport/DiffusionModule.hpp"
#include "poet.hpp" #include "poet.hpp"
#include <cstdint> #include <cstdint>
@ -12,6 +13,7 @@
namespace poet { namespace poet {
class ChemistryModule; class ChemistryModule;
class DiffusionModule;
class ControlModule { class ControlModule {
@ -27,7 +29,7 @@ public:
void initiateWarmupPhase(bool dht_enabled, bool interp_enabled); void initiateWarmupPhase(bool dht_enabled, bool interp_enabled);
bool checkAndRollback(ChemistryModule &chem, uint32_t &iter); bool checkAndRollback(DiffusionModule &diffusion, uint32_t &iter);
struct SpeciesErrorMetrics { struct SpeciesErrorMetrics {
std::vector<double> mape; std::vector<double> mape;
@ -40,8 +42,9 @@ public:
rollback_count(counter) {} rollback_count(counter) {}
}; };
void computeSpeciesErrorMetrics(const std::vector<double> &reference_values, void computeSpeciesErrorMetrics(
const std::vector<double> &surrogate_values, std::vector<std::vector<double>> &reference_values,
std::vector<std::vector<double>> &surrogate_values,
const uint32_t size_per_prop); const uint32_t size_per_prop);
std::vector<SpeciesErrorMetrics> metricsHistory; std::vector<SpeciesErrorMetrics> metricsHistory;
@ -49,7 +52,6 @@ public:
struct ControlSetup { struct ControlSetup {
std::string out_dir; std::string out_dir;
std::uint32_t checkpoint_interval; std::uint32_t checkpoint_interval;
std::uint32_t control_interval;
std::vector<std::string> species_names; std::vector<std::string> species_names;
std::vector<double> mape_threshold; std::vector<double> mape_threshold;
std::vector<uint32_t> ctrl_cell_ids; std::vector<uint32_t> ctrl_cell_ids;
@ -58,26 +60,19 @@ public:
void enableControlLogic(const ControlSetup &setup) { void enableControlLogic(const ControlSetup &setup) {
this->out_dir = setup.out_dir; this->out_dir = setup.out_dir;
this->checkpoint_interval = setup.checkpoint_interval; this->checkpoint_interval = setup.checkpoint_interval;
this->control_interval = setup.control_interval;
this->species_names = setup.species_names; this->species_names = setup.species_names;
this->mape_threshold = setup.mape_threshold; this->mape_threshold = setup.mape_threshold;
this->ctrl_cell_ids = setup.ctrl_cell_ids; this->ctrl_cell_ids = setup.ctrl_cell_ids;
} }
bool getControlIntervalEnabled() const { void applyControlLogic(DiffusionModule &diffusion, uint32_t &iter);
return this->control_interval_enabled;
}
void applyControlLogic(ChemistryModule &chem, uint32_t &iter); void writeCheckpointAndMetrics(DiffusionModule &diffusion, uint32_t iter);
void writeCheckpointAndMetrics(ChemistryModule &chem, uint32_t iter);
auto getGlobalIteration() const noexcept { return global_iteration; } auto getGlobalIteration() const noexcept { return global_iteration; }
void setChemistryModule(poet::ChemistryModule *c) { chem = c; } void setChemistryModule(poet::ChemistryModule *c) { chem = c; }
auto getControlInterval() const { return this->control_interval; }
std::vector<double> getMapeThreshold() const { return this->mape_threshold; } std::vector<double> getMapeThreshold() const { return this->mape_threshold; }
std::vector<uint32_t> getCtrlCellIds() const { return this->ctrl_cell_ids; } std::vector<uint32_t> getCtrlCellIds() const { return this->ctrl_cell_ids; }
@ -94,12 +89,11 @@ public:
private: private:
bool rollback_enabled = false; bool rollback_enabled = false;
bool control_interval_enabled = false;
poet::ChemistryModule *chem = nullptr; poet::ChemistryModule *chem = nullptr;
std::uint32_t penalty_interval = 50;
std::uint32_t checkpoint_interval = 0; std::uint32_t checkpoint_interval = 0;
std::uint32_t control_interval = 0;
std::uint32_t global_iteration = 0; std::uint32_t global_iteration = 0;
std::uint32_t rollback_count = 0; std::uint32_t rollback_count = 0;
std::uint32_t sur_disabled_counter = 0; std::uint32_t sur_disabled_counter = 0;

View File

@ -3,7 +3,7 @@
#include <string> #include <string>
#include "Datatypes.hpp" #include "Datatypes.hpp"
int write_checkpoint(const std::string &file_path, struct Checkpoint_s &&checkpoint);
int read_checkpoint(const std::string &file_path, struct Checkpoint_s &checkpoint); int write_checkpoint(const std::string &dir_path, const std::string &file_name, struct Checkpoint_s &&checkpoint);
int read_checkpoint(const std::string &dir_path, const std::string &file_name, struct Checkpoint_s &checkpoint);

View File

@ -1,13 +1,20 @@
#include "IO/Datatypes.hpp" #include "IO/Datatypes.hpp"
#include <cstdint> #include <cstdint>
#include <highfive/H5Easy.hpp> #include <highfive/H5Easy.hpp>
#include <filesystem>
int write_checkpoint(const std::string &file_path, struct Checkpoint_s &&checkpoint){ namespace fs = std::filesystem;
int write_checkpoint(const std::string &dir_path, const std::string &file_name, struct Checkpoint_s &&checkpoint){
if (!fs::exists(dir_path)) {
std::cerr << "Directory does not exist: " << dir_path << std::endl;
return -1;
}
fs::path file_path = fs::path(dir_path) / file_name;
// TODO: errorhandling // TODO: errorhandling
H5Easy::File file(file_path, H5Easy::File::Overwrite); H5Easy::File file(file_path, H5Easy::File::Overwrite);
H5Easy::dump(file, "/MetaParam/Iterations", checkpoint.iteration); H5Easy::dump(file, "/MetaParam/Iterations", checkpoint.iteration);
H5Easy::dump(file, "/Grid/Names", checkpoint.field.GetProps()); H5Easy::dump(file, "/Grid/Names", checkpoint.field.GetProps());
H5Easy::dump(file, "/Grid/Chemistry", checkpoint.field.As2DVector()); H5Easy::dump(file, "/Grid/Chemistry", checkpoint.field.As2DVector());
@ -15,7 +22,14 @@ int write_checkpoint(const std::string &file_path, struct Checkpoint_s &&checkpo
return 0; return 0;
} }
int read_checkpoint(const std::string &file_path, struct Checkpoint_s &checkpoint){ int read_checkpoint(const std::string &dir_path, const std::string &file_name, struct Checkpoint_s &checkpoint){
fs::path file_path = fs::path(dir_path) / file_name;
if (!fs::exists(file_path)) {
std::cerr << "File does not exist: " << file_path << std::endl;
return -1;
}
H5Easy::File file(file_path, H5Easy::File::ReadOnly); H5Easy::File file(file_path, H5Easy::File::ReadOnly);

View File

@ -252,14 +252,12 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("timesteps")); Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("timesteps"));
params.checkpoint_interval = params.checkpoint_interval =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval")); Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval"));
params.control_interval =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("control_interval"));
params.mape_threshold = Rcpp::as<std::vector<double>>( params.mape_threshold = Rcpp::as<std::vector<double>>(
global_rt_setup->operator[]("mape_threshold")); global_rt_setup->operator[]("mape_threshold"));
params.ctrl_cell_ids = Rcpp::as<std::vector<uint32_t>>( params.ctrl_cell_ids = Rcpp::as<std::vector<uint32_t>>(
global_rt_setup->operator[]("ctrl_cell_ids")); global_rt_setup->operator[]("ctrl_cell_ids"));
catch (const std::exception &e) { } catch (const std::exception &e) {
ERRMSG("Error while parsing R scripts: " + std::string(e.what())); ERRMSG("Error while parsing R scripts: " + std::string(e.what()));
return ParseRet::PARSER_ERROR; return ParseRet::PARSER_ERROR;
} }
@ -269,8 +267,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
// HACK: this is a step back as the order and also the count of fields is // HACK: this is a step back as the order and also the count of fields is
// predefined, but it will change in the future // predefined, but it will change in the future
void call_master_iter_end(RInside & R, const Field &trans, void call_master_iter_end(RInside &R, const Field &trans, const Field &chem) {
const Field &chem) {
R["TMP"] = Rcpp::wrap(trans.AsVector()); R["TMP"] = Rcpp::wrap(trans.AsVector());
R["TMP_PROPS"] = Rcpp::wrap(trans.GetProps()); R["TMP_PROPS"] = Rcpp::wrap(trans.GetProps());
R.parseEval(std::string("state_T <- setNames(data.frame(matrix(TMP, nrow=" + R.parseEval(std::string("state_T <- setNames(data.frame(matrix(TMP, nrow=" +
@ -287,8 +284,8 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
*global_rt_setup = R["setup"]; *global_rt_setup = R["setup"];
} }
static Rcpp::List RunMasterLoop( static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
RInsidePOET & R, RuntimeParameters & params, DiffusionModule & diffusion, DiffusionModule &diffusion,
ChemistryModule &chem, ControlModule &control) { ChemistryModule &chem, ControlModule &control) {
/* Iteration Count is dynamic, retrieving value from R (is only needed by /* Iteration Count is dynamic, retrieving value from R (is only needed by
@ -414,7 +411,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" + MSG("End of *coupling* iteration " + std::to_string(iter) + "/" +
std::to_string(maxiter)); std::to_string(maxiter));
control.applyControlLogic(chem, iter); control.applyControlLogic(diffusion, iter);
// MSG(); // MSG();
} // END SIMULATION LOOP } // END SIMULATION LOOP
@ -433,10 +430,8 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
if (params.use_dht) { if (params.use_dht) {
chem_profiling["dht_hits"] = Rcpp::wrap(chem.GetWorkerDHTHits()); chem_profiling["dht_hits"] = Rcpp::wrap(chem.GetWorkerDHTHits());
chem_profiling["dht_evictions"] = chem_profiling["dht_evictions"] = Rcpp::wrap(chem.GetWorkerDHTEvictions());
Rcpp::wrap(chem.GetWorkerDHTEvictions()); chem_profiling["dht_get_time"] = Rcpp::wrap(chem.GetWorkerDHTGetTimings());
chem_profiling["dht_get_time"] =
Rcpp::wrap(chem.GetWorkerDHTGetTimings());
chem_profiling["dht_fill_time"] = chem_profiling["dht_fill_time"] =
Rcpp::wrap(chem.GetWorkerDHTFillTimings()); Rcpp::wrap(chem.GetWorkerDHTFillTimings());
} }
@ -452,8 +447,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
Rcpp::wrap(chem.GetWorkerInterpolationFunctionCallTimings()); Rcpp::wrap(chem.GetWorkerInterpolationFunctionCallTimings());
chem_profiling["interp_calls"] = chem_profiling["interp_calls"] =
Rcpp::wrap(chem.GetWorkerInterpolationCalls()); Rcpp::wrap(chem.GetWorkerInterpolationCalls());
chem_profiling["interp_cached"] = chem_profiling["interp_cached"] = Rcpp::wrap(chem.GetWorkerPHTCacheHits());
Rcpp::wrap(chem.GetWorkerPHTCacheHits());
} }
Rcpp::List profiling; Rcpp::List profiling;
@ -466,7 +460,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
return profiling; return profiling;
} }
static void getControlCellIds(const vector<uint32_t> &ids, int root, static void getControlCellIds(std::vector<std::uint32_t> &ids, int root,
MPI_Comm comm) { MPI_Comm comm) {
std::uint32_t n_ids = 0; std::uint32_t n_ids = 0;
int rank; int rank;
@ -477,7 +471,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
n_ids = ids.size(); n_ids = ids.size();
} }
// broadcast size of id vector // broadcast size of id vector
MPI_Bcast(n_ids, 1, MPI_UINT32_T, root, comm); MPI_Bcast(&n_ids, 1, MPI_UINT32_T, root, comm);
// worker // worker
if (!is_master) { if (!is_master) {
@ -489,7 +483,6 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
} }
} }
std::vector<std::string> getSpeciesNames(const Field &&field, int root, std::vector<std::string> getSpeciesNames(const Field &&field, int root,
MPI_Comm comm) { MPI_Comm comm) {
std::uint32_t n_elements; std::uint32_t n_elements;
@ -508,8 +501,8 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
for (std::uint32_t i = 0; i < n_elements; i++) { for (std::uint32_t i = 0; i < n_elements; i++) {
n_string_size = field.GetProps()[i].size(); n_string_size = field.GetProps()[i].size();
MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD); MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
MPI_Bcast(const_cast<char *>(field.GetProps()[i].c_str()), MPI_Bcast(const_cast<char *>(field.GetProps()[i].c_str()), n_string_size,
n_string_size, MPI_CHAR, root, MPI_COMM_WORLD); MPI_CHAR, root, MPI_COMM_WORLD);
} }
return field.GetProps(); return field.GetProps();
@ -645,9 +638,13 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
chemistry.masterEnableSurrogates(surr_setup); chemistry.masterEnableSurrogates(surr_setup);
/* broadcast control cell ids before simulation starts */
getControlCellIds(run_params.ctrl_cell_ids, 0, MPI_COMM_WORLD);
chemistry.SetControlCellIds(run_params.ctrl_cell_ids);
const ControlModule::ControlSetup ctrl_setup = { const ControlModule::ControlSetup ctrl_setup = {
run_params.out_dir, // added run_params.out_dir, // added
run_params.checkpoint_interval, run_params.control_interval, run_params.checkpoint_interval,
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
run_params.mape_threshold}; run_params.mape_threshold};
@ -710,8 +707,8 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
"'/timings.', setup$out_ext));"; "'/timings.', setup$out_ext));";
R.parseEval(r_vis_code); R.parseEval(r_vis_code);
MSG("Done! Results are stored as R objects into <" + MSG("Done! Results are stored as R objects into <" + run_params.out_dir +
run_params.out_dir + "/timings." + run_params.out_ext); "/timings." + run_params.out_ext);
} }
} }

View File

@ -54,7 +54,7 @@ struct RuntimeParameters {
std::uint32_t checkpoint_interval = 0; std::uint32_t checkpoint_interval = 0;
std::uint32_t control_interval = 0; std::uint32_t control_interval = 0;
std::vector<double> mape_threshold; std::vector<double> mape_threshold;
std::vector<double> ctrl_cell_ids; std::vector<uint32t_t> ctrl_cell_ids;
static constexpr std::uint32_t WORK_PACKAGE_SIZE_DEFAULT = 32; static constexpr std::uint32_t WORK_PACKAGE_SIZE_DEFAULT = 32;
std::uint32_t work_package_size = WORK_PACKAGE_SIZE_DEFAULT; std::uint32_t work_package_size = WORK_PACKAGE_SIZE_DEFAULT;