migrate: separate control logic from ChemistryModule into dedicated ControlModule

This commit is contained in:
rastogi 2025-10-19 11:49:52 +02:00 committed by Max Lübke
parent 354ce2e1bb
commit 71269166ea
10 changed files with 771 additions and 686 deletions

View File

@ -33,6 +33,7 @@ add_library(POETLib
Chemistry/SurrogateModels/HashFunctions.cpp Chemistry/SurrogateModels/HashFunctions.cpp
Chemistry/SurrogateModels/InterpolationModule.cpp Chemistry/SurrogateModels/InterpolationModule.cpp
Chemistry/SurrogateModels/ProximityHashTable.cpp Chemistry/SurrogateModels/ProximityHashTable.cpp
Control/ControlModule.cpp
) )
set(POET_TUG_APPROACH "Implicit" CACHE STRING "tug numerical approach to use") set(POET_TUG_APPROACH "Implicit" CACHE STRING "tug numerical approach to use")

View File

@ -4,17 +4,14 @@
#include "DataStructures/Field.hpp" #include "DataStructures/Field.hpp"
#include "DataStructures/NamedVector.hpp" #include "DataStructures/NamedVector.hpp"
#include "ChemistryDefs.hpp" #include "ChemistryDefs.hpp"
#include "Control/ControlModule.hpp"
#include "Init/InitialList.hpp" #include "Init/InitialList.hpp"
#include "NameDouble.h" #include "NameDouble.h"
#include "SurrogateModels/DHT_Wrapper.hpp" #include "SurrogateModels/DHT_Wrapper.hpp"
#include "SurrogateModels/Interpolation.hpp" #include "SurrogateModels/Interpolation.hpp"
#include "poet.hpp"
#include "PhreeqcRunner.hpp" #include "PhreeqcRunner.hpp"
#include <array> #include <array>
#include <cstdint> #include <cstdint>
#include <map> #include <map>
@ -23,14 +20,13 @@
#include <string> #include <string>
#include <vector> #include <vector>
namespace poet namespace poet {
{ class ControlModule;
/** /**
* \brief Wrapper around PhreeqcRM to provide POET specific parallelization with * \brief Wrapper around PhreeqcRM to provide POET specific parallelization with
* easy access. * easy access.
*/ */
class ChemistryModule class ChemistryModule {
{
public: public:
/** /**
* Creates a new instance of Chemistry module with given grid cell count, work * Creates a new instance of Chemistry module with given grid cell count, work
@ -73,14 +69,12 @@ namespace poet
*/ */
auto GetChemistryTime() const { return this->chem_t; } auto GetChemistryTime() const { return this->chem_t; }
void setFilePadding(std::uint32_t maxiter) void setFilePadding(std::uint32_t maxiter) {
{
this->file_pad = this->file_pad =
static_cast<std::uint8_t>(std::ceil(std::log10(maxiter + 1))); static_cast<std::uint8_t>(std::ceil(std::log10(maxiter + 1)));
} }
struct SurrogateSetup struct SurrogateSetup {
{
std::vector<std::string> prop_names; std::vector<std::string> prop_names;
std::array<double, 2> base_totals; std::array<double, 2> base_totals;
bool has_het_ids; bool has_het_ids;
@ -97,8 +91,7 @@ namespace poet
bool ai_surrogate_enabled; bool ai_surrogate_enabled;
}; };
void masterEnableSurrogates(const SurrogateSetup &setup) void masterEnableSurrogates(const SurrogateSetup &setup) {
{
// FIXME: This is a hack to get the prop_names and prop_count from the setup // FIXME: This is a hack to get the prop_names and prop_count from the setup
this->prop_names = setup.prop_names; this->prop_names = setup.prop_names;
this->prop_count = setup.prop_names.size(); this->prop_count = setup.prop_names.size();
@ -109,19 +102,16 @@ namespace poet
this->base_totals = setup.base_totals; this->base_totals = setup.base_totals;
if (this->dht_enabled || this->interp_enabled) if (this->dht_enabled || this->interp_enabled) {
{
this->initializeDHT(setup.dht_size_mb, this->params.dht_species, this->initializeDHT(setup.dht_size_mb, this->params.dht_species,
setup.has_het_ids); setup.has_het_ids);
if (setup.dht_snaps != DHT_SNAPS_DISABLED) if (setup.dht_snaps != DHT_SNAPS_DISABLED) {
{
this->setDHTSnapshots(setup.dht_snaps, setup.dht_out_dir); this->setDHTSnapshots(setup.dht_snaps, setup.dht_out_dir);
} }
} }
if (this->interp_enabled) if (this->interp_enabled) {
{
this->initializeInterp(setup.interp_bucket_size, setup.interp_size_mb, this->initializeInterp(setup.interp_bucket_size, setup.interp_size_mb,
setup.interp_min_entries, setup.interp_min_entries,
this->params.interp_species); this->params.interp_species);
@ -143,8 +133,7 @@ namespace poet
/** /**
* Enumerating DHT file options * Enumerating DHT file options
*/ */
enum enum {
{
DHT_SNAPS_DISABLED = 0, //!< disabled file output DHT_SNAPS_DISABLED = 0, //!< disabled file output
DHT_SNAPS_SIMEND, //!< only output of snapshot after simulation DHT_SNAPS_SIMEND, //!< only output of snapshot after simulation
DHT_SNAPS_ITEREND //!< output snapshots after each iteration DHT_SNAPS_ITEREND //!< output snapshots after each iteration
@ -185,7 +174,6 @@ namespace poet
*/ */
auto GetMasterLoopTime() const { return this->send_recv_t; } auto GetMasterLoopTime() const { return this->send_recv_t; }
auto GetMasterCtrlLogicTime() const { return this->ctrl_t; } auto GetMasterCtrlLogicTime() const { return this->ctrl_t; }
auto GetMasterCtrlBcastTime() const { return this->bcast_ctrl_t; } auto GetMasterCtrlBcastTime() const { return this->bcast_ctrl_t; }
@ -249,8 +237,7 @@ namespace poet
* *
* \param enabled True if print progressbar, false if not. * \param enabled True if print progressbar, false if not.
*/ */
void setProgressBarPrintout(bool enabled) void setProgressBarPrintout(bool enabled) {
{
this->print_progessbar = enabled; this->print_progessbar = enabled;
}; };
@ -270,30 +257,6 @@ namespace poet
std::vector<int> ai_surrogate_validity_vector; std::vector<int> ai_surrogate_validity_vector;
RuntimeParameters *runtime_params = nullptr;
struct SimulationErrorStats
{
std::vector<double> mape;
std::vector<double> rrmse;
uint32_t iteration; // iterations in simulation after rollbacks
uint32_t rollback_count;
SimulationErrorStats(size_t species_count, uint32_t iter, uint32_t counter)
: mape(species_count, 0.0),
rrmse(species_count, 0.0),
iteration(iter),
rollback_count(counter){}
};
std::vector<SimulationErrorStats> error_history;
static void computeSpeciesErrors(const std::vector<double> &reference_values,
const std::vector<double> &surrogate_values,
uint32_t size_per_prop,
uint32_t species_count,
SimulationErrorStats &species_error_stats);
protected: protected:
void initializeDHT(uint32_t size_mb, void initializeDHT(uint32_t size_mb,
const NamedVector<std::uint32_t> &key_species, const NamedVector<std::uint32_t> &key_species,
@ -305,14 +268,13 @@ namespace poet
std::uint32_t min_entries, std::uint32_t min_entries,
const NamedVector<std::uint32_t> &key_species); const NamedVector<std::uint32_t> &key_species);
enum enum {
{
CHEM_FIELD_INIT, CHEM_FIELD_INIT,
CHEM_DHT_ENABLE, CHEM_DHT_ENABLE,
CHEM_DHT_SIGNIF_VEC, CHEM_DHT_SIGNIF_VEC,
CHEM_DHT_SNAPS, CHEM_DHT_SNAPS,
CHEM_DHT_READ_FILE, CHEM_DHT_READ_FILE,
CHEM_INTERP, CHEM_IP, // Control Flag
CHEM_IP_ENABLE, CHEM_IP_ENABLE,
CHEM_IP_MIN_ENTRIES, CHEM_IP_MIN_ENTRIES,
CHEM_IP_SIGNIF_VEC, CHEM_IP_SIGNIF_VEC,
@ -322,15 +284,9 @@ namespace poet
CHEM_AI_BCAST_VALIDITY CHEM_AI_BCAST_VALIDITY
}; };
enum enum { LOOP_WORK, LOOP_END, LOOP_CTRL };
{
LOOP_WORK,
LOOP_END,
LOOP_CTRL
};
enum enum {
{
WORKER_PHREEQC, WORKER_PHREEQC,
WORKER_CTRL_ITER, WORKER_CTRL_ITER,
WORKER_DHT_GET, WORKER_DHT_GET,
@ -350,8 +306,7 @@ namespace poet
std::vector<uint32_t> dht_hits; std::vector<uint32_t> dht_hits;
std::vector<uint32_t> dht_evictions; std::vector<uint32_t> dht_evictions;
struct worker_s struct worker_s {
{
double phreeqc_t = 0.; double phreeqc_t = 0.;
double dht_get = 0.; double dht_get = 0.;
double dht_fill = 0.; double dht_fill = 0.;
@ -359,8 +314,7 @@ namespace poet
double ctrl_t = 0.; double ctrl_t = 0.;
}; };
struct worker_info_s struct worker_info_s {
{
char has_work = 0; char has_work = 0;
double *send_addr; double *send_addr;
double *surrogate_addr; double *surrogate_addr;
@ -372,9 +326,10 @@ namespace poet
void MasterRunParallel(double dt); void MasterRunParallel(double dt);
void MasterRunSequential(); void MasterRunSequential();
void MasterSendPkgs(worker_list_t &w_list, workpointer_t &work_pointer, workpointer_t &sur_pointer, void MasterSendPkgs(worker_list_t &w_list, workpointer_t &work_pointer,
int &pkg_to_send, int &count_pkgs, int &free_workers, workpointer_t &sur_pointer, int &pkg_to_send,
double dt, uint32_t iteration, uint32_t control_iteration, int &count_pkgs, int &free_workers, double dt,
uint32_t iteration, uint32_t control_iteration,
const std::vector<uint32_t> &wp_sizes_vector); const std::vector<uint32_t> &wp_sizes_vector);
void MasterRecvPkgs(worker_list_t &w_list, int &pkg_to_recv, bool to_send, void MasterRecvPkgs(worker_list_t &w_list, int &pkg_to_recv, bool to_send,
int &free_workers); int &free_workers);
@ -433,13 +388,11 @@ namespace poet
static constexpr uint32_t BUFFER_OFFSET = 6; static constexpr uint32_t BUFFER_OFFSET = 6;
inline void ChemBCast(void *buf, int count, MPI_Datatype datatype) const inline void ChemBCast(void *buf, int count, MPI_Datatype datatype) const {
{
MPI_Bcast(buf, count, datatype, 0, this->group_comm); MPI_Bcast(buf, count, datatype, 0, this->group_comm);
} }
inline void PropagateFunctionType(int &type) const inline void PropagateFunctionType(int &type) const {
{
ChemBCast(&type, 1, MPI_INT); ChemBCast(&type, 1, MPI_INT);
} }
double simtime = 0.; double simtime = 0.;
@ -469,7 +422,9 @@ namespace poet
std::unique_ptr<PhreeqcRunner> pqc_runner; std::unique_ptr<PhreeqcRunner> pqc_runner;
std::vector<double> sur_shuffled; std::unique_ptr<poet::ControlModule> ctrl_module;
//std::vector<double> sur_shuffled;
}; };
} // namespace poet } // namespace poet

View File

@ -3,7 +3,6 @@
#include <algorithm> #include <algorithm>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <iomanip>
#include <mpi.h> #include <mpi.h>
#include <vector> #include <vector>
@ -166,39 +165,6 @@ std::vector<uint32_t> poet::ChemistryModule::GetWorkerPHTCacheHits() const {
return ret; return ret;
} }
void poet::ChemistryModule::computeSpeciesErrors(const std::vector<double> &reference_values,
const std::vector<double> &surrogate_values,
uint32_t size_per_prop,
uint32_t species_count,
SimulationErrorStats &species_error_stats) {
for (uint32_t i = 0; i < species_count; ++i) {
double err_sum = 0.0;
double sqr_err_sum = 0.0;
uint32_t base_idx = i * size_per_prop;
for (uint32_t j = 0; j < size_per_prop; ++j) {
const double ref_value = reference_values[base_idx + j];
const double sur_value = surrogate_values[base_idx + j];
if (ref_value == 0.0) {
if (sur_value != 0.0) {
err_sum += 1.0;
sqr_err_sum += 1.0;
}
// Both zero: skip
} else {
double alpha = 1.0 - (sur_value / ref_value);
err_sum += std::abs(alpha);
sqr_err_sum += alpha * alpha;
}
}
species_error_stats.mape[i] = 100.0 * (err_sum / size_per_prop);
species_error_stats.rrmse[i] =
(size_per_prop > 0) ? std::sqrt(sqr_err_sum / size_per_prop) : 0.0;
}
}
inline std::vector<int> shuffleVector(const std::vector<int> &in_vector, inline std::vector<int> shuffleVector(const std::vector<int> &in_vector,
uint32_t size_per_prop, uint32_t size_per_prop,
uint32_t wp_count) { uint32_t wp_count) {
@ -269,8 +235,8 @@ inline void printProgressbar(int count_pkgs, int n_wp, int barWidth = 70) {
inline void poet::ChemistryModule::MasterSendPkgs( inline void poet::ChemistryModule::MasterSendPkgs(
worker_list_t &w_list, workpointer_t &work_pointer, worker_list_t &w_list, workpointer_t &work_pointer,
workpointer_t &sur_pointer, int &pkg_to_send, int &count_pkgs, workpointer_t &sur_pointer, int &pkg_to_send, int &count_pkgs,
int &free_workers, double dt, uint32_t iteration, int &free_workers, double dt, uint32_t iteration, uint32_t control_interval,
uint32_t control_interval, const std::vector<uint32_t> &wp_sizes_vector) { const std::vector<uint32_t> &wp_sizes_vector) {
/* declare variables */ /* declare variables */
int local_work_package_size; int local_work_package_size;
@ -461,28 +427,9 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
/* start time measurement of broadcasting interpolation status */ /* start time measurement of broadcasting interpolation status */
ctrl_bcast_a = MPI_Wtime(); ctrl_bcast_a = MPI_Wtime();
ftype = CHEM_IP;
ftype = CHEM_INTERP;
PropagateFunctionType(ftype); PropagateFunctionType(ftype);
ctrl_module->BCastControlFlags();
int interp_flag = 0;
int dht_fill_flag = 0;
if(this->runtime_params->rollback_enabled){
this->interp_enabled = false;
this->dht_fill_during_rollback = true;
interp_flag = 0;
dht_fill_flag = 1;
}
else {
this->interp_enabled = true;
this->dht_fill_during_rollback = false;
interp_flag = 1;
dht_fill_flag = 0;
}
ChemBCast(&interp_flag, 1, MPI_INT);
ChemBCast(&dht_fill_flag, 1, MPI_INT);
/* end time measurement of broadcasting interpolation status */ /* end time measurement of broadcasting interpolation status */
ctrl_bcast_b = MPI_Wtime(); ctrl_bcast_b = MPI_Wtime();
this->bcast_ctrl_t += ctrl_bcast_b - ctrl_bcast_a; this->bcast_ctrl_t += ctrl_bcast_b - ctrl_bcast_a;
@ -494,11 +441,12 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
static uint32_t iteration = 0; static uint32_t iteration = 0;
uint32_t control_logic_enabled = this->runtime_params->control_interval_enabled ? 1 : 0; uint32_t control_logic_enabled =
ctrl_module->control_interval_enabled ? 1 : 0;
if (control_logic_enabled) { if (control_logic_enabled) {
sur_shuffled.clear(); ctrl_module->sur_shuffled.clear();
sur_shuffled.reserve(this->n_cells * this->prop_count); ctrl_module->sur_shuffled.reserve(this->n_cells * this->prop_count);
} }
/* start time measurement of sequential part */ /* start time measurement of sequential part */
@ -511,14 +459,14 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count, shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count,
wp_sizes_vector.size()); wp_sizes_vector.size());
this->sur_shuffled.resize(mpi_buffer.size()); ctrl_module->sur_shuffled.resize(mpi_buffer.size());
/* setup local variables */ /* setup local variables */
pkg_to_send = wp_sizes_vector.size(); pkg_to_send = wp_sizes_vector.size();
pkg_to_recv = wp_sizes_vector.size(); pkg_to_recv = wp_sizes_vector.size();
workpointer_t work_pointer = mpi_buffer.begin(); workpointer_t work_pointer = mpi_buffer.begin();
workpointer_t sur_pointer = sur_shuffled.begin(); workpointer_t sur_pointer = ctrl_module->sur_shuffled.begin();
worker_list_t worker_list(this->comm_size - 1); worker_list_t worker_list(this->comm_size - 1);
free_workers = this->comm_size - 1; free_workers = this->comm_size - 1;
@ -571,25 +519,19 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
/* start time measurement of control logic */ /* start time measurement of control logic */
ctrl_a = MPI_Wtime(); ctrl_a = MPI_Wtime();
if (control_logic_enabled && !this->runtime_params->rollback_enabled) { if (control_logic_enabled && !ctrl_module->rollback_enabled) {
std::cout << "[Master] Control logic enabled for this iteration." << std::endl;
std::vector<double> sur_unshuffled{sur_shuffled};; std::vector<double> sur_unshuffled{ctrl_module->sur_shuffled};
unshuffleField(ctrl_module->sur_shuffled, this->n_cells, this->prop_count,
unshuffleField(sur_shuffled, this->n_cells, this->prop_count,
wp_sizes_vector.size(), sur_unshuffled); wp_sizes_vector.size(), sur_unshuffled);
SimulationErrorStats stats(this->prop_count, this->runtime_params->global_iter, this->runtime_params->rollback_counter); ctrl_module->computeSpeciesErrors(out_vec, sur_unshuffled, this->n_cells);
computeSpeciesErrors(out_vec, sur_unshuffled, this->n_cells, this->prop_count, stats);
error_history.push_back(stats);
} }
/* end time measurement of control logic */ /* end time measurement of control logic */
ctrl_b = MPI_Wtime(); ctrl_b = MPI_Wtime();
this->ctrl_t += ctrl_b - ctrl_a; this->ctrl_t += ctrl_b - ctrl_a;
/* start time measurement of master chemistry */ /* start time measurement of master chemistry */
sim_e_chemistry = MPI_Wtime(); sim_e_chemistry = MPI_Wtime();

View File

@ -67,7 +67,7 @@ namespace poet
MPI_INT, 0, this->group_comm); MPI_INT, 0, this->group_comm);
break; break;
} }
case CHEM_INTERP: case CHEM_IP:
{ {
int interp_flag = 0; int interp_flag = 0;
int dht_fill_flag = 0; int dht_fill_flag = 0;

View File

@ -0,0 +1,131 @@
#include "ControlModule.hpp"
#include "IO/Datatypes.hpp"
#include "IO/HDF5Functions.hpp"
#include "IO/StatsIO.hpp"
#include <cmath>
/**
 * Determines whether the given iteration is a control iteration and caches
 * the answer in `control_interval_enabled`.
 *
 * An iteration counts as a control iteration when it is a multiple of
 * `control_interval`. An interval of 0 means "never": evaluating
 * `iter % 0` would be undefined behaviour, so that case is short-circuited.
 *
 * \param iter Iteration number to test.
 * \return True if `iter` is a control iteration, false otherwise.
 */
bool poet::ControlModule::isControlIteration(uint32_t iter) {
  control_interval_enabled =
      (control_interval != 0) && (iter % control_interval == 0);

  if (control_interval_enabled) {
    MSG("[Control] Control interval triggered at iteration " +
        std::to_string(iter));
  }

  return control_interval_enabled;
}
/**
 * Advances the rollback countdown at the start of an iteration.
 *
 * Does nothing unless `rollback_enabled` is set. While set, each call
 * decrements `sur_disabled_counter` and logs the new value; the call that
 * finds the counter already at zero clears `rollback_enabled` instead.
 */
void poet::ControlModule::beginIteration() {
  if (!rollback_enabled) {
    return;
  }

  if (sur_disabled_counter == 0) {
    // Countdown finished on an earlier call: leave rollback mode.
    rollback_enabled = false;
    return;
  }

  --sur_disabled_counter;
  MSG("Rollback counter: " + std::to_string(sur_disabled_counter));
}
void poet::ControlModule::endIteration(uint32_t iter) {
/* Writing a checkpointing */
if (checkpoint_interval > 0 && iter % checkpoint_interval == 0) {
MSG("Writing checkpoint of iteration " + std::to_string(iter));
write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5",
{.field = chem->getField(), .iteration = iter});
}
/* Control Logic*/
if (control_interval_enabled && !rollback_enabled) {
writeStatsToCSV(error_history, species_names, out_dir,
"stats_overview");
if (triggerRollbackIfExceeded(*chem, *params, iter)) {
rollback_enabled = true;
rollback_counter++;
sur_disabled_counter = control_interval;
MSG("Interpolation disabled for the next " +
std::to_string(control_interval) + ".");
}
}
}
/**
 * Broadcasts the two control flags through ChemistryModule::ChemBCast.
 *
 * The flags are mutually exclusive and derived from `rollback_enabled`:
 * during a rollback the interpolation flag is 0 and the DHT fill flag is 1,
 * otherwise the interpolation flag is 1 and the DHT fill flag is 0.
 * The interpolation flag is sent first, then the DHT fill flag.
 */
void poet::ControlModule::BCastControlFlags() {
  int interp_flag = 1;
  int dht_fill_flag = 0;

  if (rollback_enabled) {
    interp_flag = 0;
    dht_fill_flag = 1;
  }

  chem->ChemBCast(&interp_flag, 1, MPI_INT);
  chem->ChemBCast(&dht_fill_flag, 1, MPI_INT);
}
/**
 * Compares the most recent MAPE values against the configured thresholds and
 * restores the last checkpoint if any species exceeds its threshold.
 *
 * The error statistics are read from this module's own `error_history`
 * (the same container whose emptiness is checked first).
 *
 * \param chem   Chemistry module whose field is overwritten by the checkpoint
 *               and whose property names are used for the log message.
 * \param params Runtime parameters providing `mape_threshold`,
 *               `checkpoint_interval` and `out_dir`.
 * \param iter   Current iteration; on rollback it is overwritten with the
 *               iteration stored in the checkpoint that was read.
 * \return True if a rollback was performed, false otherwise.
 *
 * NOTE(review): only MAPE is evaluated here although the final log message
 * also mentions RRMSE - confirm whether the RRMSE check should be restored.
 */
bool poet::ControlModule::triggerRollbackIfExceeded(ChemistryModule &chem,
                                                    RuntimeParameters &params,
                                                    uint32_t &iter) {
  if (error_history.empty()) {
    MSG("No error history yet; skipping rollback check.");
    return false;
  }

  const auto &mape = error_history.back().mape;
  const auto &props = chem.getField().GetProps();

  // Never index past the shorter of the two vectors.
  const auto threshold_count = params.mape_threshold.size();
  const auto checked_count =
      threshold_count < mape.size() ? threshold_count : mape.size();

  for (std::size_t i = 0; i < checked_count; ++i) {
    // Skip invalid entries
    if (mape[i] == 0) {
      continue;
    }

    if (!(mape[i] > params.mape_threshold[i])) {
      continue;
    }

    if (params.checkpoint_interval == 0) {
      // Dividing by a zero interval below would be undefined behaviour, and
      // there is no interval to round down to.
      MSG("Checkpoint interval is 0; cannot roll back.");
      return false;
    }

    // Round the last completed iteration down to a multiple of the
    // checkpoint interval; guard against unsigned underflow for iter == 0.
    const uint32_t last_completed = iter > 0 ? iter - 1 : 0;
    const uint32_t rollback_iter =
        (last_completed / params.checkpoint_interval) *
        params.checkpoint_interval;

    MSG("[THRESHOLD EXCEEDED] " + props[i] +
        " has MAPE = " + std::to_string(mape[i]) +
        " exceeding threshold = " + std::to_string(params.mape_threshold[i]) +
        " → rolling back to iteration " + std::to_string(rollback_iter));

    Checkpoint_s checkpoint_read{.field = chem.getField()};
    read_checkpoint(params.out_dir,
                    "checkpoint" + std::to_string(rollback_iter) + ".hdf5",
                    checkpoint_read);
    iter = checkpoint_read.iteration;
    return true;
  }

  MSG("All species are within their MAPE and RRMSE thresholds.");
  return false;
}
/**
 * Computes per-species error metrics of the surrogate against the reference
 * and appends them to `error_history`.
 *
 * For every species `i` the values at indices
 * `[i * size_per_prop, (i + 1) * size_per_prop)` of both vectors are compared
 * using the relative deviation `1 - sur / ref`:
 *  - MAPE  = 100 * mean(|deviation|)
 *  - RRMSE = sqrt(mean(deviation^2))
 * A zero reference with a non-zero surrogate counts as a deviation of 1;
 * entries where both are zero contribute nothing.
 *
 * With `size_per_prop == 0` both metrics stay at 0 instead of producing NaN
 * from a 0/0 division.
 *
 * \param reference_values Reference values; must hold at least
 *                         `species_count * size_per_prop` entries.
 * \param surrogate_values Surrogate values with the same layout and size.
 * \param size_per_prop    Number of values per species.
 */
void poet::ControlModule::computeSpeciesErrors(
    const std::vector<double> &reference_values,
    const std::vector<double> &surrogate_values, uint32_t size_per_prop) {
  SimulationErrorStats species_error_stats(species_count, params->global_iter,
                                           rollback_counter);

  // The stats are zero-initialised, so an empty grid needs no further work.
  if (size_per_prop > 0) {
    for (uint32_t i = 0; i < species_count; ++i) {
      double err_sum = 0.0;
      double sqr_err_sum = 0.0;
      // Widen before multiplying so the offset cannot wrap in 32 bits.
      const std::size_t base_idx = static_cast<std::size_t>(i) * size_per_prop;

      for (uint32_t j = 0; j < size_per_prop; ++j) {
        const double ref_value = reference_values[base_idx + j];
        const double sur_value = surrogate_values[base_idx + j];

        if (ref_value == 0.0) {
          if (sur_value != 0.0) {
            err_sum += 1.0;
            sqr_err_sum += 1.0;
          }
          // Both zero: skip
        } else {
          const double alpha = 1.0 - (sur_value / ref_value);
          err_sum += std::abs(alpha);
          sqr_err_sum += alpha * alpha;
        }
      }

      species_error_stats.mape[i] = 100.0 * (err_sum / size_per_prop);
      species_error_stats.rrmse[i] = std::sqrt(sqr_err_sum / size_per_prop);
    }
  }

  error_history.push_back(species_error_stats);
}

View File

@ -0,0 +1,110 @@
#ifndef CONTROLMODULE_H_
#define CONTROLMODULE_H_
#include "Base/Macros.hpp"
#include "Chemistry/ChemistryModule.hpp"
#include "poet.hpp"
#include <cstdint>
#include <string>
#include <vector>
namespace poet {
class ChemistryModule;
/**
 * \brief Holds the control state (control interval, checkpointing, rollback
 * and surrogate error statistics) that accompanies a ChemistryModule.
 *
 * The module keeps non-owning pointers to the runtime parameters and the
 * chemistry module passed to its constructor.
 */
class ControlModule {
public:
  /**
   * \param run_params  Runtime parameters (not owned).
   * \param chem_module Chemistry module used for broadcasting and field
   *                    access (not owned).
   */
  ControlModule(RuntimeParameters *run_params, ChemistryModule *chem_module)
      : params(run_params), chem(chem_module) {}

  /* Control configuration*/
  std::vector<std::string> species_names;
  uint32_t species_count = 0;
  std::string out_dir;

  bool rollback_enabled = false;
  bool control_interval_enabled = false;

  std::uint32_t global_iter = 0;
  std::uint32_t sur_disabled_counter = 0;
  std::uint32_t rollback_counter = 0;
  std::uint32_t checkpoint_interval = 0;
  std::uint32_t control_interval = 0;

  std::vector<double> mape_threshold;
  std::vector<double> rrmse_threshold;

  /* Accumulated timings returned by the profiling getters below */
  double ctrl_t = 0.;
  double bcast_ctrl_t = 0.;
  double recv_ctrl_t = 0.;

  /* Buffer for shuffled surrogate data */
  std::vector<double> sur_shuffled;

  /** Updates and returns `control_interval_enabled` for iteration `iter`. */
  bool isControlIteration(uint32_t iter);

  /** Advances the rollback countdown at the start of an iteration. */
  void beginIteration();

  /** Writes checkpoints and runs the control logic after an iteration. */
  void endIteration(uint32_t iter);

  /** Broadcasts the interpolation and DHT fill flags via the chemistry module. */
  void BCastControlFlags();

  /**
   * Checks the latest error statistics against the thresholds and reads a
   * checkpoint if they are exceeded.
   *
   * \return True if a rollback was performed.
   */
  bool triggerRollbackIfExceeded(ChemistryModule &chem,
                                 RuntimeParameters &params, uint32_t &iter);

  /** Error metrics of one control iteration, one entry per species. */
  struct SimulationErrorStats {
    std::vector<double> mape;
    std::vector<double> rrmse;
    uint32_t iteration; // iterations in simulation after rollbacks
    uint32_t rollback_count;

    SimulationErrorStats(size_t species_count, uint32_t iter, uint32_t counter)
        : mape(species_count, 0.0), rrmse(species_count, 0.0), iteration(iter),
          rollback_count(counter) {}
  };

  /**
   * Computes per-species MAPE/RRMSE and appends them to `error_history`.
   *
   * Not static: the implementation reads `species_count`, `params` and
   * `rollback_counter` and writes to `error_history`.
   */
  void computeSpeciesErrors(const std::vector<double> &reference_values,
                            const std::vector<double> &surrogate_values,
                            uint32_t size_per_prop);

  std::vector<SimulationErrorStats> error_history;

  /** Bundle of settings applied by enableControlLogic(). */
  struct ControlSetup {
    std::string out_dir;
    std::uint32_t checkpoint_interval = 0;
    std::uint32_t control_interval = 0;
    std::uint32_t species_count = 0;

    std::vector<std::string> species_names;
    std::vector<double> mape_threshold;
    std::vector<double> rrmse_threshold;
  };

  /** Copies every field of `setup` into the corresponding member. */
  void enableControlLogic(const ControlSetup &setup) {
    out_dir = setup.out_dir;
    checkpoint_interval = setup.checkpoint_interval;
    control_interval = setup.control_interval;
    species_count = setup.species_count;

    species_names = setup.species_names;
    mape_threshold = setup.mape_threshold;
    rrmse_threshold = setup.rrmse_threshold;
  }

  /* Profiling getters */
  auto GetMasterCtrlLogicTime() const { return this->ctrl_t; }
  auto GetMasterCtrlBcastTime() const { return this->bcast_ctrl_t; }
  auto GetMasterRecvCtrlLogicTime() const { return this->recv_ctrl_t; }

private:
  RuntimeParameters *params; // not owned
  ChemistryModule *chem;     // not owned
};
} // namespace poet
#endif // CONTROLMODULE_H_

View File

@ -7,7 +7,7 @@
namespace poet namespace poet
{ {
void writeStatsToCSV(const std::vector<ChemistryModule::SimulationErrorStats> &all_stats, void writeStatsToCSV(const std::vector<ControlModule::SimulationErrorStats> &all_stats,
const std::vector<std::string> &species_names, const std::vector<std::string> &species_names,
const std::string &out_dir, const std::string &out_dir,
const std::string &filename) const std::string &filename)

View File

@ -1,9 +1,9 @@
#include <string> #include <string>
#include "Chemistry/ChemistryModule.hpp" #include "Control/ControlModule.hpp"
namespace poet namespace poet
{ {
void writeStatsToCSV(const std::vector<ChemistryModule::SimulationErrorStats> &all_stats, void writeStatsToCSV(const std::vector<ControlModule::SimulationErrorStats> &all_stats,
const std::vector<std::string> &species_names, const std::vector<std::string> &species_names,
const std::string &out_dir, const std::string &out_dir,
const std::string &filename); const std::string &filename);

View File

@ -25,10 +25,8 @@
#include "Base/RInsidePOET.hpp" #include "Base/RInsidePOET.hpp"
#include "CLI/CLI.hpp" #include "CLI/CLI.hpp"
#include "Chemistry/ChemistryModule.hpp" #include "Chemistry/ChemistryModule.hpp"
#include "Control/ControlManager.hpp"
#include "DataStructures/Field.hpp" #include "DataStructures/Field.hpp"
#include "IO/Datatypes.hpp"
#include "IO/HDF5Functions.hpp"
#include "IO/StatsIO.hpp"
#include "Init/InitialList.hpp" #include "Init/InitialList.hpp"
#include "Transport/DiffusionModule.hpp" #include "Transport/DiffusionModule.hpp"
@ -68,8 +66,7 @@ static poet::DEFunc ReadRObj_R;
static poet::DEFunc SaveRObj_R; static poet::DEFunc SaveRObj_R;
static poet::DEFunc source_R; static poet::DEFunc source_R;
static void init_global_functions(RInside &R) static void init_global_functions(RInside &R) {
{
R.parseEval(kin_r_library); R.parseEval(kin_r_library);
master_init_R = DEFunc("master_init"); master_init_R = DEFunc("master_init");
master_iteration_end_R = DEFunc("master_iteration_end"); master_iteration_end_R = DEFunc("master_iteration_end");
@ -92,15 +89,9 @@ static void init_global_functions(RInside &R)
// R.parseEval("mysetup$state_C <- TMP"); // R.parseEval("mysetup$state_C <- TMP");
// } // }
enum ParseRet enum ParseRet { PARSER_OK, PARSER_ERROR, PARSER_HELP };
{
PARSER_OK,
PARSER_ERROR,
PARSER_HELP
};
int parseInitValues(int argc, char **argv, RuntimeParameters &params) int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
{
CLI::App app{"POET - Potsdam rEactive Transport simulator"}; CLI::App app{"POET - Potsdam rEactive Transport simulator"};
@ -182,12 +173,9 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params)
"Output directory of the simulation") "Output directory of the simulation")
->required(); ->required();
try try {
{
app.parse(argc, argv); app.parse(argc, argv);
} } catch (const CLI::ParseError &e) {
catch (const CLI::ParseError &e)
{
app.exit(e); app.exit(e);
return -1; return -1;
} }
@ -199,16 +187,14 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params)
if (params.as_qs) if (params.as_qs)
params.out_ext = "qs"; params.out_ext = "qs";
if (MY_RANK == 0) if (MY_RANK == 0) {
{
// MSG("Complete results storage is " + BOOL_PRINT(simparams.store_result)); // MSG("Complete results storage is " + BOOL_PRINT(simparams.store_result));
MSG("Output format/extension is " + params.out_ext); MSG("Output format/extension is " + params.out_ext);
MSG("Work Package Size: " + std::to_string(params.work_package_size)); MSG("Work Package Size: " + std::to_string(params.work_package_size));
MSG("DHT is " + BOOL_PRINT(params.use_dht)); MSG("DHT is " + BOOL_PRINT(params.use_dht));
MSG("AI Surrogate is " + BOOL_PRINT(params.use_ai_surrogate)); MSG("AI Surrogate is " + BOOL_PRINT(params.use_ai_surrogate));
if (params.use_dht) if (params.use_dht) {
{
// MSG("DHT strategy is " + std::to_string(simparams.dht_strategy)); // MSG("DHT strategy is " + std::to_string(simparams.dht_strategy));
// MDL: these should be outdated (?) // MDL: these should be outdated (?)
// MSG("DHT key default digits (ignored if 'signif_vector' is " // MSG("DHT key default digits (ignored if 'signif_vector' is "
@ -222,8 +208,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params)
// MSG("DHT load file is " + chem_params.dht_file); // MSG("DHT load file is " + chem_params.dht_file);
} }
if (params.use_interp) if (params.use_interp) {
{
MSG("PHT interpolation enabled: " + BOOL_PRINT(params.use_interp)); MSG("PHT interpolation enabled: " + BOOL_PRINT(params.use_interp));
MSG("PHT interp-size = " + std::to_string(params.interp_size)); MSG("PHT interp-size = " + std::to_string(params.interp_size));
MSG("PHT interp-min = " + std::to_string(params.interp_min_entries)); MSG("PHT interp-min = " + std::to_string(params.interp_min_entries));
@ -251,8 +236,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params)
// // log before rounding? // // log before rounding?
// R["dht_log"] = simparams.dht_log; // R["dht_log"] = simparams.dht_log;
try try {
{
Rcpp::List init_params_(ReadRObj_R(init_file)); Rcpp::List init_params_(ReadRObj_R(init_file));
params.init_params = init_params_; params.init_params = init_params_;
@ -269,13 +253,11 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params)
Rcpp::as<uint32_t>(global_rt_setup->operator[]("control_interval")); Rcpp::as<uint32_t>(global_rt_setup->operator[]("control_interval"));
params.checkpoint_interval = params.checkpoint_interval =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval")); Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval"));
params.mape_threshold = params.mape_threshold = Rcpp::as<std::vector<double>>(
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("mape_threshold")); global_rt_setup->operator[]("mape_threshold"));
params.rrmse_threshold = params.rrmse_threshold = Rcpp::as<std::vector<double>>(
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("rrmse_threshold")); global_rt_setup->operator[]("rrmse_threshold"));
} } catch (const std::exception &e) {
catch (const std::exception &e)
{
ERRMSG("Error while parsing R scripts: " + std::string(e.what())); ERRMSG("Error while parsing R scripts: " + std::string(e.what()));
return ParseRet::PARSER_ERROR; return ParseRet::PARSER_ERROR;
} }
@ -285,8 +267,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params)
// HACK: this is a step back as the order and also the count of fields is // HACK: this is a step back as the order and also the count of fields is
// predefined, but it will change in the future // predefined, but it will change in the future
void call_master_iter_end(RInside &R, const Field &trans, const Field &chem) void call_master_iter_end(RInside &R, const Field &trans, const Field &chem) {
{
R["TMP"] = Rcpp::wrap(trans.AsVector()); R["TMP"] = Rcpp::wrap(trans.AsVector());
R["TMP_PROPS"] = Rcpp::wrap(trans.GetProps()); R["TMP_PROPS"] = Rcpp::wrap(trans.GetProps());
R.parseEval(std::string("state_T <- setNames(data.frame(matrix(TMP, nrow=" + R.parseEval(std::string("state_T <- setNames(data.frame(matrix(TMP, nrow=" +
@ -303,53 +284,15 @@ void call_master_iter_end(RInside &R, const Field &trans, const Field &chem)
*global_rt_setup = R["setup"]; *global_rt_setup = R["setup"];
} }
bool triggerRollbackIfExceeded(ChemistryModule &chem, RuntimeParameters &params, uint32_t &current_iteration)
{
const auto &mape = chem.error_history.back().mape;
const auto &rrmse = chem.error_history.back().rrmse;
const auto &props = chem.getField().GetProps();
for (uint32_t i = 0; i < params.mape_threshold.size(); ++i)
{
// Skip invalid entries
if ((mape[i] == 0 && rrmse[i] == 0))
continue;
bool mape_exceeded = mape[i] > params.mape_threshold[i];
bool rrmse_exceeded = rrmse[i] > params.rrmse_threshold[i];
if (mape_exceeded || rrmse_exceeded)
{
uint32_t rollback_iter = ((current_iteration - 1) / params.checkpoint_interval) * params.checkpoint_interval;
std::string metric = mape_exceeded ? "MAPE" : "RRMSE";
double value = mape_exceeded ? mape[i] : rrmse[i];
double threshold = mape_exceeded ? params.mape_threshold[i] : params.rrmse_threshold[i];
MSG("[THRESHOLD EXCEEDED] " + props[i] + " has " + metric + " = " +
std::to_string(value) + " exceeding threshold = " + std::to_string(threshold) +
" → rolling back to iteration " + std::to_string(rollback_iter));
Checkpoint_s checkpoint_read{.field = chem.getField()};
read_checkpoint(params.out_dir, "checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read);
current_iteration = checkpoint_read.iteration;
return true; // rollback happened
}
}
MSG("All species are within their MAPE and RRMSE thresholds.");
return false;
}
static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params, static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
DiffusionModule &diffusion, DiffusionModule &diffusion,
ChemistryModule &chem) ChemistryModule &chem, ControlModule &control) {
{
/* Iteration Count is dynamic, retrieving value from R (is only needed by /* Iteration Count is dynamic, retrieving value from R (is only needed by
* master for the following loop) */ * master for the following loop) */
uint32_t maxiter = params.timesteps.size(); uint32_t maxiter = params.timesteps.size();
if (params.print_progress) if (params.print_progress) {
{
chem.setProgressBarPrintout(true); chem.setProgressBarPrintout(true);
} }
R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps()); R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps());
@ -359,9 +302,10 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
double dSimTime{0}; double dSimTime{0};
double chkTime = 0.0; double chkTime = 0.0;
for (uint32_t iter = 1; iter < maxiter + 1; iter++) for (uint32_t iter = 1; iter < maxiter + 1; iter++) {
{
// Rollback countdown // Rollback countdown
/*
if (params.rollback_enabled) { if (params.rollback_enabled) {
if (params.sur_disabled_counter > 0) { if (params.sur_disabled_counter > 0) {
--params.sur_disabled_counter; --params.sur_disabled_counter;
@ -370,9 +314,12 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
params.rollback_enabled = false; params.rollback_enabled = false;
} }
} }
*/
control.beginIteration(iter);
params.global_iter = iter; // params.global_iter = iter;
params.control_interval_enabled = (iter % params.control_interval == 0); control.isControlIteration(iter);
// params.control_interval_enabled = (iter % params.control_interval == 0);
double start_t = MPI_Wtime(); double start_t = MPI_Wtime();
@ -389,13 +336,12 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
/* run transport */ /* run transport */
diffusion.simulate(dt); diffusion.simulate(dt);
chem.runtime_params = &params; // chem.runtime_params = &params;
chem.getField().update(diffusion.getField()); chem.getField().update(diffusion.getField());
// MSG("Chemistry start"); // MSG("Chemistry start");
if (params.use_ai_surrogate) if (params.use_ai_surrogate) {
{
double ai_start_t = MPI_Wtime(); double ai_start_t = MPI_Wtime();
// Save current values from the tug field as predictor for the ai step // Save current values from the tug field as predictor for the ai step
R["TMP"] = Rcpp::wrap(chem.getField().AsVector()); R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
@ -446,8 +392,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
chem.simulate(dt); chem.simulate(dt);
/* AI surrogate iterative training*/ /* AI surrogate iterative training*/
if (params.use_ai_surrogate) if (params.use_ai_surrogate) {
{
double ai_start_t = MPI_Wtime(); double ai_start_t = MPI_Wtime();
R["TMP"] = Rcpp::wrap(chem.getField().AsVector()); R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
@ -487,24 +432,31 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
std::to_string(maxiter)); std::to_string(maxiter));
double chk_start = MPI_Wtime(); double chk_start = MPI_Wtime();
control.endIteration(iter)
/*
if (iter % params.checkpoint_interval == 0) { if (iter % params.checkpoint_interval == 0) {
MSG("Writing checkpoint of iteration " + std::to_string(iter)); MSG("Writing checkpoint of iteration " + std::to_string(iter));
write_checkpoint(params.out_dir, "checkpoint" + std::to_string(iter) + ".hdf5", write_checkpoint(params.out_dir,
"checkpoint" + std::to_string(iter) + ".hdf5",
{.field = chem.getField(), .iteration = iter}); {.field = chem.getField(), .iteration = iter});
} }
if (params.control_interval_enabled && !params.rollback_enabled)
{ if (params.control_interval_enabled && !params.rollback_enabled) {
writeStatsToCSV(chem.error_history, chem.getField().GetProps(), params.out_dir,"stats_overview"); writeStatsToCSV(chem.error_history, chem.getField().GetProps(),
params.out_dir, "stats_overview");
if (triggerRollbackIfExceeded(chem, params, iter)) { if (triggerRollbackIfExceeded(chem, params, iter)) {
params.rollback_enabled = true; params.rollback_enabled = true;
params.rollback_counter++; params.rollback_counter++;
params.sur_disabled_counter = params.control_interval; params.sur_disabled_counter = params.control_interval;
MSG("Interpolation disabled for the next " + std::to_string(params.control_interval) + "."); MSG("Interpolation disabled for the next " +
std::to_string(params.control_interval) + ".");
} }
} }
*/
double chk_end = MPI_Wtime(); double chk_end = MPI_Wtime();
chkTime += chk_end - chk_start; chkTime += chk_end - chk_start;
@ -529,10 +481,10 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
ctrl_profiling["ctrl_logic_master"] = chem.GetMasterCtrlLogicTime(); ctrl_profiling["ctrl_logic_master"] = chem.GetMasterCtrlLogicTime();
ctrl_profiling["bcast_ctrl_logic_master"] = chem.GetMasterCtrlBcastTime(); ctrl_profiling["bcast_ctrl_logic_master"] = chem.GetMasterCtrlBcastTime();
ctrl_profiling["recv_ctrl_logic_maser"] = chem.GetMasterRecvCtrlLogicTime(); ctrl_profiling["recv_ctrl_logic_maser"] = chem.GetMasterRecvCtrlLogicTime();
ctrl_profiling["ctrl_logic_worker"] = Rcpp::wrap(chem.GetWorkerControlTimings()); ctrl_profiling["ctrl_logic_worker"] =
Rcpp::wrap(chem.GetWorkerControlTimings());
if (params.use_dht) if (params.use_dht) {
{
chem_profiling["dht_hits"] = Rcpp::wrap(chem.GetWorkerDHTHits()); chem_profiling["dht_hits"] = Rcpp::wrap(chem.GetWorkerDHTHits());
chem_profiling["dht_evictions"] = Rcpp::wrap(chem.GetWorkerDHTEvictions()); chem_profiling["dht_evictions"] = Rcpp::wrap(chem.GetWorkerDHTEvictions());
chem_profiling["dht_get_time"] = Rcpp::wrap(chem.GetWorkerDHTGetTimings()); chem_profiling["dht_get_time"] = Rcpp::wrap(chem.GetWorkerDHTGetTimings());
@ -540,8 +492,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
Rcpp::wrap(chem.GetWorkerDHTFillTimings()); Rcpp::wrap(chem.GetWorkerDHTFillTimings());
} }
if (params.use_interp) if (params.use_interp) {
{
chem_profiling["interp_w"] = chem_profiling["interp_w"] =
Rcpp::wrap(chem.GetWorkerInterpolationWriteTimings()); Rcpp::wrap(chem.GetWorkerInterpolationWriteTimings());
chem_profiling["interp_r"] = chem_profiling["interp_r"] =
@ -561,15 +512,13 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
profiling["diffusion"] = diffusion_profiling; profiling["diffusion"] = diffusion_profiling;
profiling["ctrl_logic"] = ctrl_profiling; profiling["ctrl_logic"] = ctrl_profiling;
chem.MasterLoopBreak(); chem.MasterLoopBreak();
return profiling; return profiling;
} }
std::vector<std::string> getSpeciesNames(const Field &&field, int root, std::vector<std::string> getSpeciesNames(const Field &&field, int root,
MPI_Comm comm) MPI_Comm comm) {
{
std::uint32_t n_elements; std::uint32_t n_elements;
std::uint32_t n_string_size; std::uint32_t n_string_size;
@ -579,13 +528,11 @@ std::vector<std::string> getSpeciesNames(const Field &&field, int root,
const bool is_master = root == rank; const bool is_master = root == rank;
// first, the master sends all the species names iteratively // first, the master sends all the species names iteratively
if (is_master) if (is_master) {
{
n_elements = field.GetProps().size(); n_elements = field.GetProps().size();
MPI_Bcast(&n_elements, 1, MPI_UINT32_T, root, MPI_COMM_WORLD); MPI_Bcast(&n_elements, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
for (std::uint32_t i = 0; i < n_elements; i++) for (std::uint32_t i = 0; i < n_elements; i++) {
{
n_string_size = field.GetProps()[i].size(); n_string_size = field.GetProps()[i].size();
MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD); MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
MPI_Bcast(const_cast<char *>(field.GetProps()[i].c_str()), n_string_size, MPI_Bcast(const_cast<char *>(field.GetProps()[i].c_str()), n_string_size,
@ -600,8 +547,7 @@ std::vector<std::string> getSpeciesNames(const Field &&field, int root,
std::vector<std::string> species_names_out(n_elements); std::vector<std::string> species_names_out(n_elements);
for (std::uint32_t i = 0; i < n_elements; i++) for (std::uint32_t i = 0; i < n_elements; i++) {
{
MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD); MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
char recv_buf[n_string_size]; char recv_buf[n_string_size];
@ -614,8 +560,7 @@ std::vector<std::string> getSpeciesNames(const Field &&field, int root,
return species_names_out; return species_names_out;
} }
std::array<double, 2> getBaseTotals(Field &&field, int root, MPI_Comm comm) std::array<double, 2> getBaseTotals(Field &&field, int root, MPI_Comm comm) {
{
std::array<double, 2> base_totals; std::array<double, 2> base_totals;
int rank; int rank;
@ -623,8 +568,7 @@ std::array<double, 2> getBaseTotals(Field &&field, int root, MPI_Comm comm)
const bool is_master = root == rank; const bool is_master = root == rank;
if (is_master) if (is_master) {
{
const auto h_col = field["H"]; const auto h_col = field["H"];
const auto o_col = field["O"]; const auto o_col = field["O"];
@ -639,8 +583,7 @@ std::array<double, 2> getBaseTotals(Field &&field, int root, MPI_Comm comm)
return base_totals; return base_totals;
} }
bool getHasID(Field &&field, int root, MPI_Comm comm) bool getHasID(Field &&field, int root, MPI_Comm comm) {
{
bool has_id; bool has_id;
int rank; int rank;
@ -648,8 +591,7 @@ bool getHasID(Field &&field, int root, MPI_Comm comm)
const bool is_master = root == rank; const bool is_master = root == rank;
if (is_master) if (is_master) {
{
const auto ID_field = field["ID"]; const auto ID_field = field["ID"];
std::set<double> unique_IDs(ID_field.begin(), ID_field.end()); std::set<double> unique_IDs(ID_field.begin(), ID_field.end());
@ -666,8 +608,7 @@ bool getHasID(Field &&field, int root, MPI_Comm comm)
return has_id; return has_id;
} }
int main(int argc, char *argv[]) int main(int argc, char *argv[]) {
{
int world_size; int world_size;
MPI_Init(&argc, &argv); MPI_Init(&argc, &argv);
@ -678,8 +619,7 @@ int main(int argc, char *argv[])
RInsidePOET &R = RInsidePOET::getInstance(); RInsidePOET &R = RInsidePOET::getInstance();
if (MY_RANK == 0) if (MY_RANK == 0) {
{
MSG("Running POET version " + std::string(poet_version)); MSG("Running POET version " + std::string(poet_version));
} }
@ -687,8 +627,7 @@ int main(int argc, char *argv[])
RuntimeParameters run_params; RuntimeParameters run_params;
if (parseInitValues(argc, argv, run_params) != 0) if (parseInitValues(argc, argv, run_params) != 0) {
{
MPI_Finalize(); MPI_Finalize();
return 0; return 0;
} }
@ -713,6 +652,7 @@ int main(int argc, char *argv[])
ChemistryModule chemistry(run_params.work_package_size, ChemistryModule chemistry(run_params.work_package_size,
init_list.getChemistryInit(), MPI_COMM_WORLD); init_list.getChemistryInit(), MPI_COMM_WORLD);
ControlModule control(&run_params, &chemistry);
const ChemistryModule::SurrogateSetup surr_setup = { const ChemistryModule::SurrogateSetup surr_setup = {
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
@ -730,12 +670,21 @@ int main(int argc, char *argv[])
chemistry.masterEnableSurrogates(surr_setup); chemistry.masterEnableSurrogates(surr_setup);
if (MY_RANK > 0) const ControlModule::ControlSetup ctrl_setup = {
{ run_params.out_dir, // added
run_params.checkpoint_interval,
run_params.control_interval,
run_params.species_count,
run_params.species_names,
run_params.mape_threshold,
run_params.rrmse_threshold};
control.enableControlLogic(ctrl_setup);
if (MY_RANK > 0) {
chemistry.WorkerLoop(); chemistry.WorkerLoop();
} } else {
else
{
// R.parseEvalQ("mysetup <- setup"); // R.parseEvalQ("mysetup <- setup");
// // if (MY_RANK == 0) { // get timestep vector from // // if (MY_RANK == 0) { // get timestep vector from
// // grid_init function ... // // // grid_init function ... //
@ -749,8 +698,7 @@ int main(int argc, char *argv[])
R["out_ext"] = run_params.out_ext; R["out_ext"] = run_params.out_ext;
R["out_dir"] = run_params.out_dir; R["out_dir"] = run_params.out_dir;
if (run_params.use_ai_surrogate) if (run_params.use_ai_surrogate) {
{
/* Incorporate ai surrogate from R */ /* Incorporate ai surrogate from R */
R.parseEvalQ(ai_surrogate_r_library); R.parseEvalQ(ai_surrogate_r_library);
/* Use dht species for model input and output */ /* Use dht species for model input and output */
@ -799,8 +747,7 @@ int main(int argc, char *argv[])
MPI_Finalize(); MPI_Finalize();
if (MY_RANK == 0) if (MY_RANK == 0) {
{
MSG("done, bye!"); MSG("done, bye!");
} }

View File

@ -38,8 +38,7 @@ static const inline std::string ai_surrogate_r_library =
R"(@R_AI_SURROGATE_LIB@)"; R"(@R_AI_SURROGATE_LIB@)";
static const inline std::string r_runtime_parameters = "mysetup"; static const inline std::string r_runtime_parameters = "mysetup";
struct RuntimeParameters struct RuntimeParameters {
{
std::string out_dir; std::string out_dir;
std::vector<double> timesteps; std::vector<double> timesteps;