mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-16 12:54:50 +01:00
migrate: separate control logic from ChemistryModule into dedicated ControlModule
This commit is contained in:
parent
354ce2e1bb
commit
71269166ea
@ -33,6 +33,7 @@ add_library(POETLib
|
|||||||
Chemistry/SurrogateModels/HashFunctions.cpp
|
Chemistry/SurrogateModels/HashFunctions.cpp
|
||||||
Chemistry/SurrogateModels/InterpolationModule.cpp
|
Chemistry/SurrogateModels/InterpolationModule.cpp
|
||||||
Chemistry/SurrogateModels/ProximityHashTable.cpp
|
Chemistry/SurrogateModels/ProximityHashTable.cpp
|
||||||
|
Control/ControlModule.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
set(POET_TUG_APPROACH "Implicit" CACHE STRING "tug numerical approach to use")
|
set(POET_TUG_APPROACH "Implicit" CACHE STRING "tug numerical approach to use")
|
||||||
|
|||||||
@ -4,17 +4,14 @@
|
|||||||
|
|
||||||
#include "DataStructures/Field.hpp"
|
#include "DataStructures/Field.hpp"
|
||||||
#include "DataStructures/NamedVector.hpp"
|
#include "DataStructures/NamedVector.hpp"
|
||||||
|
|
||||||
#include "ChemistryDefs.hpp"
|
#include "ChemistryDefs.hpp"
|
||||||
|
#include "Control/ControlModule.hpp"
|
||||||
#include "Init/InitialList.hpp"
|
#include "Init/InitialList.hpp"
|
||||||
#include "NameDouble.h"
|
#include "NameDouble.h"
|
||||||
#include "SurrogateModels/DHT_Wrapper.hpp"
|
#include "SurrogateModels/DHT_Wrapper.hpp"
|
||||||
#include "SurrogateModels/Interpolation.hpp"
|
#include "SurrogateModels/Interpolation.hpp"
|
||||||
|
|
||||||
#include "poet.hpp"
|
|
||||||
|
|
||||||
#include "PhreeqcRunner.hpp"
|
#include "PhreeqcRunner.hpp"
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <map>
|
#include <map>
|
||||||
@ -23,15 +20,14 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace poet
|
namespace poet {
|
||||||
{
|
class ControlModule;
|
||||||
/**
|
/**
|
||||||
* \brief Wrapper around PhreeqcRM to provide POET specific parallelization with
|
* \brief Wrapper around PhreeqcRM to provide POET specific parallelization with
|
||||||
* easy access.
|
* easy access.
|
||||||
*/
|
*/
|
||||||
class ChemistryModule
|
class ChemistryModule {
|
||||||
{
|
public:
|
||||||
public:
|
|
||||||
/**
|
/**
|
||||||
* Creates a new instance of Chemistry module with given grid cell count, work
|
* Creates a new instance of Chemistry module with given grid cell count, work
|
||||||
* package size and communicator.
|
* package size and communicator.
|
||||||
@ -73,14 +69,12 @@ namespace poet
|
|||||||
*/
|
*/
|
||||||
auto GetChemistryTime() const { return this->chem_t; }
|
auto GetChemistryTime() const { return this->chem_t; }
|
||||||
|
|
||||||
void setFilePadding(std::uint32_t maxiter)
|
void setFilePadding(std::uint32_t maxiter) {
|
||||||
{
|
|
||||||
this->file_pad =
|
this->file_pad =
|
||||||
static_cast<std::uint8_t>(std::ceil(std::log10(maxiter + 1)));
|
static_cast<std::uint8_t>(std::ceil(std::log10(maxiter + 1)));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SurrogateSetup
|
struct SurrogateSetup {
|
||||||
{
|
|
||||||
std::vector<std::string> prop_names;
|
std::vector<std::string> prop_names;
|
||||||
std::array<double, 2> base_totals;
|
std::array<double, 2> base_totals;
|
||||||
bool has_het_ids;
|
bool has_het_ids;
|
||||||
@ -97,8 +91,7 @@ namespace poet
|
|||||||
bool ai_surrogate_enabled;
|
bool ai_surrogate_enabled;
|
||||||
};
|
};
|
||||||
|
|
||||||
void masterEnableSurrogates(const SurrogateSetup &setup)
|
void masterEnableSurrogates(const SurrogateSetup &setup) {
|
||||||
{
|
|
||||||
// FIXME: This is a hack to get the prop_names and prop_count from the setup
|
// FIXME: This is a hack to get the prop_names and prop_count from the setup
|
||||||
this->prop_names = setup.prop_names;
|
this->prop_names = setup.prop_names;
|
||||||
this->prop_count = setup.prop_names.size();
|
this->prop_count = setup.prop_names.size();
|
||||||
@ -109,19 +102,16 @@ namespace poet
|
|||||||
|
|
||||||
this->base_totals = setup.base_totals;
|
this->base_totals = setup.base_totals;
|
||||||
|
|
||||||
if (this->dht_enabled || this->interp_enabled)
|
if (this->dht_enabled || this->interp_enabled) {
|
||||||
{
|
|
||||||
this->initializeDHT(setup.dht_size_mb, this->params.dht_species,
|
this->initializeDHT(setup.dht_size_mb, this->params.dht_species,
|
||||||
setup.has_het_ids);
|
setup.has_het_ids);
|
||||||
|
|
||||||
if (setup.dht_snaps != DHT_SNAPS_DISABLED)
|
if (setup.dht_snaps != DHT_SNAPS_DISABLED) {
|
||||||
{
|
|
||||||
this->setDHTSnapshots(setup.dht_snaps, setup.dht_out_dir);
|
this->setDHTSnapshots(setup.dht_snaps, setup.dht_out_dir);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (this->interp_enabled)
|
if (this->interp_enabled) {
|
||||||
{
|
|
||||||
this->initializeInterp(setup.interp_bucket_size, setup.interp_size_mb,
|
this->initializeInterp(setup.interp_bucket_size, setup.interp_size_mb,
|
||||||
setup.interp_min_entries,
|
setup.interp_min_entries,
|
||||||
this->params.interp_species);
|
this->params.interp_species);
|
||||||
@ -143,8 +133,7 @@ namespace poet
|
|||||||
/**
|
/**
|
||||||
* Enumerating DHT file options
|
* Enumerating DHT file options
|
||||||
*/
|
*/
|
||||||
enum
|
enum {
|
||||||
{
|
|
||||||
DHT_SNAPS_DISABLED = 0, //!< disabled file output
|
DHT_SNAPS_DISABLED = 0, //!< disabled file output
|
||||||
DHT_SNAPS_SIMEND, //!< only output of snapshot after simulation
|
DHT_SNAPS_SIMEND, //!< only output of snapshot after simulation
|
||||||
DHT_SNAPS_ITEREND //!< output snapshots after each iteration
|
DHT_SNAPS_ITEREND //!< output snapshots after each iteration
|
||||||
@ -185,7 +174,6 @@ namespace poet
|
|||||||
*/
|
*/
|
||||||
auto GetMasterLoopTime() const { return this->send_recv_t; }
|
auto GetMasterLoopTime() const { return this->send_recv_t; }
|
||||||
|
|
||||||
|
|
||||||
auto GetMasterCtrlLogicTime() const { return this->ctrl_t; }
|
auto GetMasterCtrlLogicTime() const { return this->ctrl_t; }
|
||||||
|
|
||||||
auto GetMasterCtrlBcastTime() const { return this->bcast_ctrl_t; }
|
auto GetMasterCtrlBcastTime() const { return this->bcast_ctrl_t; }
|
||||||
@ -249,8 +237,7 @@ namespace poet
|
|||||||
*
|
*
|
||||||
* \param enabled True if print progressbar, false if not.
|
* \param enabled True if print progressbar, false if not.
|
||||||
*/
|
*/
|
||||||
void setProgressBarPrintout(bool enabled)
|
void setProgressBarPrintout(bool enabled) {
|
||||||
{
|
|
||||||
this->print_progessbar = enabled;
|
this->print_progessbar = enabled;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -270,31 +257,7 @@ namespace poet
|
|||||||
|
|
||||||
std::vector<int> ai_surrogate_validity_vector;
|
std::vector<int> ai_surrogate_validity_vector;
|
||||||
|
|
||||||
RuntimeParameters *runtime_params = nullptr;
|
protected:
|
||||||
|
|
||||||
struct SimulationErrorStats
|
|
||||||
{
|
|
||||||
std::vector<double> mape;
|
|
||||||
std::vector<double> rrmse;
|
|
||||||
uint32_t iteration; // iterations in simulation after rollbacks
|
|
||||||
uint32_t rollback_count;
|
|
||||||
|
|
||||||
SimulationErrorStats(size_t species_count, uint32_t iter, uint32_t counter)
|
|
||||||
: mape(species_count, 0.0),
|
|
||||||
rrmse(species_count, 0.0),
|
|
||||||
iteration(iter),
|
|
||||||
rollback_count(counter){}
|
|
||||||
};
|
|
||||||
|
|
||||||
std::vector<SimulationErrorStats> error_history;
|
|
||||||
|
|
||||||
static void computeSpeciesErrors(const std::vector<double> &reference_values,
|
|
||||||
const std::vector<double> &surrogate_values,
|
|
||||||
uint32_t size_per_prop,
|
|
||||||
uint32_t species_count,
|
|
||||||
SimulationErrorStats &species_error_stats);
|
|
||||||
|
|
||||||
protected:
|
|
||||||
void initializeDHT(uint32_t size_mb,
|
void initializeDHT(uint32_t size_mb,
|
||||||
const NamedVector<std::uint32_t> &key_species,
|
const NamedVector<std::uint32_t> &key_species,
|
||||||
bool has_het_ids);
|
bool has_het_ids);
|
||||||
@ -305,14 +268,13 @@ namespace poet
|
|||||||
std::uint32_t min_entries,
|
std::uint32_t min_entries,
|
||||||
const NamedVector<std::uint32_t> &key_species);
|
const NamedVector<std::uint32_t> &key_species);
|
||||||
|
|
||||||
enum
|
enum {
|
||||||
{
|
|
||||||
CHEM_FIELD_INIT,
|
CHEM_FIELD_INIT,
|
||||||
CHEM_DHT_ENABLE,
|
CHEM_DHT_ENABLE,
|
||||||
CHEM_DHT_SIGNIF_VEC,
|
CHEM_DHT_SIGNIF_VEC,
|
||||||
CHEM_DHT_SNAPS,
|
CHEM_DHT_SNAPS,
|
||||||
CHEM_DHT_READ_FILE,
|
CHEM_DHT_READ_FILE,
|
||||||
CHEM_INTERP,
|
CHEM_IP, // Control Flag
|
||||||
CHEM_IP_ENABLE,
|
CHEM_IP_ENABLE,
|
||||||
CHEM_IP_MIN_ENTRIES,
|
CHEM_IP_MIN_ENTRIES,
|
||||||
CHEM_IP_SIGNIF_VEC,
|
CHEM_IP_SIGNIF_VEC,
|
||||||
@ -322,15 +284,9 @@ namespace poet
|
|||||||
CHEM_AI_BCAST_VALIDITY
|
CHEM_AI_BCAST_VALIDITY
|
||||||
};
|
};
|
||||||
|
|
||||||
enum
|
enum { LOOP_WORK, LOOP_END, LOOP_CTRL };
|
||||||
{
|
|
||||||
LOOP_WORK,
|
|
||||||
LOOP_END,
|
|
||||||
LOOP_CTRL
|
|
||||||
};
|
|
||||||
|
|
||||||
enum
|
enum {
|
||||||
{
|
|
||||||
WORKER_PHREEQC,
|
WORKER_PHREEQC,
|
||||||
WORKER_CTRL_ITER,
|
WORKER_CTRL_ITER,
|
||||||
WORKER_DHT_GET,
|
WORKER_DHT_GET,
|
||||||
@ -350,8 +306,7 @@ namespace poet
|
|||||||
std::vector<uint32_t> dht_hits;
|
std::vector<uint32_t> dht_hits;
|
||||||
std::vector<uint32_t> dht_evictions;
|
std::vector<uint32_t> dht_evictions;
|
||||||
|
|
||||||
struct worker_s
|
struct worker_s {
|
||||||
{
|
|
||||||
double phreeqc_t = 0.;
|
double phreeqc_t = 0.;
|
||||||
double dht_get = 0.;
|
double dht_get = 0.;
|
||||||
double dht_fill = 0.;
|
double dht_fill = 0.;
|
||||||
@ -359,8 +314,7 @@ namespace poet
|
|||||||
double ctrl_t = 0.;
|
double ctrl_t = 0.;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct worker_info_s
|
struct worker_info_s {
|
||||||
{
|
|
||||||
char has_work = 0;
|
char has_work = 0;
|
||||||
double *send_addr;
|
double *send_addr;
|
||||||
double *surrogate_addr;
|
double *surrogate_addr;
|
||||||
@ -372,9 +326,10 @@ namespace poet
|
|||||||
void MasterRunParallel(double dt);
|
void MasterRunParallel(double dt);
|
||||||
void MasterRunSequential();
|
void MasterRunSequential();
|
||||||
|
|
||||||
void MasterSendPkgs(worker_list_t &w_list, workpointer_t &work_pointer, workpointer_t &sur_pointer,
|
void MasterSendPkgs(worker_list_t &w_list, workpointer_t &work_pointer,
|
||||||
int &pkg_to_send, int &count_pkgs, int &free_workers,
|
workpointer_t &sur_pointer, int &pkg_to_send,
|
||||||
double dt, uint32_t iteration, uint32_t control_iteration,
|
int &count_pkgs, int &free_workers, double dt,
|
||||||
|
uint32_t iteration, uint32_t control_iteration,
|
||||||
const std::vector<uint32_t> &wp_sizes_vector);
|
const std::vector<uint32_t> &wp_sizes_vector);
|
||||||
void MasterRecvPkgs(worker_list_t &w_list, int &pkg_to_recv, bool to_send,
|
void MasterRecvPkgs(worker_list_t &w_list, int &pkg_to_recv, bool to_send,
|
||||||
int &free_workers);
|
int &free_workers);
|
||||||
@ -433,13 +388,11 @@ namespace poet
|
|||||||
|
|
||||||
static constexpr uint32_t BUFFER_OFFSET = 6;
|
static constexpr uint32_t BUFFER_OFFSET = 6;
|
||||||
|
|
||||||
inline void ChemBCast(void *buf, int count, MPI_Datatype datatype) const
|
inline void ChemBCast(void *buf, int count, MPI_Datatype datatype) const {
|
||||||
{
|
|
||||||
MPI_Bcast(buf, count, datatype, 0, this->group_comm);
|
MPI_Bcast(buf, count, datatype, 0, this->group_comm);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void PropagateFunctionType(int &type) const
|
inline void PropagateFunctionType(int &type) const {
|
||||||
{
|
|
||||||
ChemBCast(&type, 1, MPI_INT);
|
ChemBCast(&type, 1, MPI_INT);
|
||||||
}
|
}
|
||||||
double simtime = 0.;
|
double simtime = 0.;
|
||||||
@ -469,8 +422,10 @@ namespace poet
|
|||||||
|
|
||||||
std::unique_ptr<PhreeqcRunner> pqc_runner;
|
std::unique_ptr<PhreeqcRunner> pqc_runner;
|
||||||
|
|
||||||
std::vector<double> sur_shuffled;
|
std::unique_ptr<poet::ControlModule> ctrl_module;
|
||||||
};
|
|
||||||
|
//std::vector<double> sur_shuffled;
|
||||||
|
};
|
||||||
} // namespace poet
|
} // namespace poet
|
||||||
|
|
||||||
#endif // CHEMISTRYMODULE_H_
|
#endif // CHEMISTRYMODULE_H_
|
||||||
|
|||||||
@ -3,7 +3,6 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <iomanip>
|
|
||||||
#include <mpi.h>
|
#include <mpi.h>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -166,39 +165,6 @@ std::vector<uint32_t> poet::ChemistryModule::GetWorkerPHTCacheHits() const {
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
void poet::ChemistryModule::computeSpeciesErrors(const std::vector<double> &reference_values,
|
|
||||||
const std::vector<double> &surrogate_values,
|
|
||||||
uint32_t size_per_prop,
|
|
||||||
uint32_t species_count,
|
|
||||||
SimulationErrorStats &species_error_stats) {
|
|
||||||
for (uint32_t i = 0; i < species_count; ++i) {
|
|
||||||
double err_sum = 0.0;
|
|
||||||
double sqr_err_sum = 0.0;
|
|
||||||
uint32_t base_idx = i * size_per_prop;
|
|
||||||
|
|
||||||
for (uint32_t j = 0; j < size_per_prop; ++j) {
|
|
||||||
const double ref_value = reference_values[base_idx + j];
|
|
||||||
const double sur_value = surrogate_values[base_idx + j];
|
|
||||||
|
|
||||||
if (ref_value == 0.0) {
|
|
||||||
if (sur_value != 0.0) {
|
|
||||||
err_sum += 1.0;
|
|
||||||
sqr_err_sum += 1.0;
|
|
||||||
}
|
|
||||||
// Both zero: skip
|
|
||||||
} else {
|
|
||||||
double alpha = 1.0 - (sur_value / ref_value);
|
|
||||||
err_sum += std::abs(alpha);
|
|
||||||
sqr_err_sum += alpha * alpha;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
species_error_stats.mape[i] = 100.0 * (err_sum / size_per_prop);
|
|
||||||
species_error_stats.rrmse[i] =
|
|
||||||
(size_per_prop > 0) ? std::sqrt(sqr_err_sum / size_per_prop) : 0.0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inline std::vector<int> shuffleVector(const std::vector<int> &in_vector,
|
inline std::vector<int> shuffleVector(const std::vector<int> &in_vector,
|
||||||
uint32_t size_per_prop,
|
uint32_t size_per_prop,
|
||||||
uint32_t wp_count) {
|
uint32_t wp_count) {
|
||||||
@ -269,8 +235,8 @@ inline void printProgressbar(int count_pkgs, int n_wp, int barWidth = 70) {
|
|||||||
inline void poet::ChemistryModule::MasterSendPkgs(
|
inline void poet::ChemistryModule::MasterSendPkgs(
|
||||||
worker_list_t &w_list, workpointer_t &work_pointer,
|
worker_list_t &w_list, workpointer_t &work_pointer,
|
||||||
workpointer_t &sur_pointer, int &pkg_to_send, int &count_pkgs,
|
workpointer_t &sur_pointer, int &pkg_to_send, int &count_pkgs,
|
||||||
int &free_workers, double dt, uint32_t iteration,
|
int &free_workers, double dt, uint32_t iteration, uint32_t control_interval,
|
||||||
uint32_t control_interval, const std::vector<uint32_t> &wp_sizes_vector) {
|
const std::vector<uint32_t> &wp_sizes_vector) {
|
||||||
/* declare variables */
|
/* declare variables */
|
||||||
int local_work_package_size;
|
int local_work_package_size;
|
||||||
|
|
||||||
@ -461,28 +427,9 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
|||||||
|
|
||||||
/* start time measurement of broadcasting interpolation status */
|
/* start time measurement of broadcasting interpolation status */
|
||||||
ctrl_bcast_a = MPI_Wtime();
|
ctrl_bcast_a = MPI_Wtime();
|
||||||
|
ftype = CHEM_IP;
|
||||||
ftype = CHEM_INTERP;
|
|
||||||
PropagateFunctionType(ftype);
|
PropagateFunctionType(ftype);
|
||||||
|
ctrl_module->BCastControlFlags();
|
||||||
int interp_flag = 0;
|
|
||||||
int dht_fill_flag = 0;
|
|
||||||
|
|
||||||
if(this->runtime_params->rollback_enabled){
|
|
||||||
this->interp_enabled = false;
|
|
||||||
this->dht_fill_during_rollback = true;
|
|
||||||
interp_flag = 0;
|
|
||||||
dht_fill_flag = 1;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
this->interp_enabled = true;
|
|
||||||
this->dht_fill_during_rollback = false;
|
|
||||||
interp_flag = 1;
|
|
||||||
dht_fill_flag = 0;
|
|
||||||
}
|
|
||||||
ChemBCast(&interp_flag, 1, MPI_INT);
|
|
||||||
ChemBCast(&dht_fill_flag, 1, MPI_INT);
|
|
||||||
|
|
||||||
/* end time measurement of broadcasting interpolation status */
|
/* end time measurement of broadcasting interpolation status */
|
||||||
ctrl_bcast_b = MPI_Wtime();
|
ctrl_bcast_b = MPI_Wtime();
|
||||||
this->bcast_ctrl_t += ctrl_bcast_b - ctrl_bcast_a;
|
this->bcast_ctrl_t += ctrl_bcast_b - ctrl_bcast_a;
|
||||||
@ -494,11 +441,12 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
|||||||
|
|
||||||
static uint32_t iteration = 0;
|
static uint32_t iteration = 0;
|
||||||
|
|
||||||
uint32_t control_logic_enabled = this->runtime_params->control_interval_enabled ? 1 : 0;
|
uint32_t control_logic_enabled =
|
||||||
|
ctrl_module->control_interval_enabled ? 1 : 0;
|
||||||
|
|
||||||
if (control_logic_enabled) {
|
if (control_logic_enabled) {
|
||||||
sur_shuffled.clear();
|
ctrl_module->sur_shuffled.clear();
|
||||||
sur_shuffled.reserve(this->n_cells * this->prop_count);
|
ctrl_module->sur_shuffled.reserve(this->n_cells * this->prop_count);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* start time measurement of sequential part */
|
/* start time measurement of sequential part */
|
||||||
@ -511,14 +459,14 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
|||||||
shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count,
|
shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count,
|
||||||
wp_sizes_vector.size());
|
wp_sizes_vector.size());
|
||||||
|
|
||||||
this->sur_shuffled.resize(mpi_buffer.size());
|
ctrl_module->sur_shuffled.resize(mpi_buffer.size());
|
||||||
|
|
||||||
/* setup local variables */
|
/* setup local variables */
|
||||||
pkg_to_send = wp_sizes_vector.size();
|
pkg_to_send = wp_sizes_vector.size();
|
||||||
pkg_to_recv = wp_sizes_vector.size();
|
pkg_to_recv = wp_sizes_vector.size();
|
||||||
|
|
||||||
workpointer_t work_pointer = mpi_buffer.begin();
|
workpointer_t work_pointer = mpi_buffer.begin();
|
||||||
workpointer_t sur_pointer = sur_shuffled.begin();
|
workpointer_t sur_pointer = ctrl_module->sur_shuffled.begin();
|
||||||
worker_list_t worker_list(this->comm_size - 1);
|
worker_list_t worker_list(this->comm_size - 1);
|
||||||
|
|
||||||
free_workers = this->comm_size - 1;
|
free_workers = this->comm_size - 1;
|
||||||
@ -571,25 +519,19 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
|||||||
/* start time measurement of control logic */
|
/* start time measurement of control logic */
|
||||||
ctrl_a = MPI_Wtime();
|
ctrl_a = MPI_Wtime();
|
||||||
|
|
||||||
if (control_logic_enabled && !this->runtime_params->rollback_enabled) {
|
if (control_logic_enabled && !ctrl_module->rollback_enabled) {
|
||||||
|
std::cout << "[Master] Control logic enabled for this iteration." << std::endl;
|
||||||
std::vector<double> sur_unshuffled{sur_shuffled};;
|
std::vector<double> sur_unshuffled{ctrl_module->sur_shuffled};
|
||||||
|
unshuffleField(ctrl_module->sur_shuffled, this->n_cells, this->prop_count,
|
||||||
unshuffleField(sur_shuffled, this->n_cells, this->prop_count,
|
|
||||||
wp_sizes_vector.size(), sur_unshuffled);
|
wp_sizes_vector.size(), sur_unshuffled);
|
||||||
|
|
||||||
SimulationErrorStats stats(this->prop_count, this->runtime_params->global_iter, this->runtime_params->rollback_counter);
|
ctrl_module->computeSpeciesErrors(out_vec, sur_unshuffled, this->n_cells);
|
||||||
|
|
||||||
computeSpeciesErrors(out_vec, sur_unshuffled, this->n_cells, this->prop_count, stats);
|
|
||||||
|
|
||||||
error_history.push_back(stats);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* end time measurement of control logic */
|
/* end time measurement of control logic */
|
||||||
ctrl_b = MPI_Wtime();
|
ctrl_b = MPI_Wtime();
|
||||||
this->ctrl_t += ctrl_b - ctrl_a;
|
this->ctrl_t += ctrl_b - ctrl_a;
|
||||||
|
|
||||||
|
|
||||||
/* start time measurement of master chemistry */
|
/* start time measurement of master chemistry */
|
||||||
sim_e_chemistry = MPI_Wtime();
|
sim_e_chemistry = MPI_Wtime();
|
||||||
|
|
||||||
|
|||||||
@ -67,7 +67,7 @@ namespace poet
|
|||||||
MPI_INT, 0, this->group_comm);
|
MPI_INT, 0, this->group_comm);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case CHEM_INTERP:
|
case CHEM_IP:
|
||||||
{
|
{
|
||||||
int interp_flag = 0;
|
int interp_flag = 0;
|
||||||
int dht_fill_flag = 0;
|
int dht_fill_flag = 0;
|
||||||
|
|||||||
131
src/Control/ControlModule.cpp
Normal file
131
src/Control/ControlModule.cpp
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
#include "ControlModule.hpp"
|
||||||
|
#include "IO/Datatypes.hpp"
|
||||||
|
#include "IO/HDF5Functions.hpp"
|
||||||
|
#include "IO/StatsIO.hpp"
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
bool poet::ControlModule::isControlIteration(uint32_t iter) {
|
||||||
|
control_interval_enabled = (iter % control_interval == 0);
|
||||||
|
if (control_interval_enabled) {
|
||||||
|
MSG("[Control] Control interval triggered at iteration " +
|
||||||
|
std::to_string(iter));
|
||||||
|
}
|
||||||
|
return control_interval_enabled;
|
||||||
|
}
|
||||||
|
|
||||||
|
void poet::ControlModule::beginIteration() {
|
||||||
|
if (rollback_enabled) {
|
||||||
|
if (sur_disabled_counter > 0) {
|
||||||
|
sur_disabled_counter--;
|
||||||
|
MSG("Rollback counter: " + std::to_string(sur_disabled_counter));
|
||||||
|
} else {
|
||||||
|
rollback_enabled = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void poet::ControlModule::endIteration(uint32_t iter) {
|
||||||
|
/* Writing a checkpointing */
|
||||||
|
if (checkpoint_interval > 0 && iter % checkpoint_interval == 0) {
|
||||||
|
MSG("Writing checkpoint of iteration " + std::to_string(iter));
|
||||||
|
write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5",
|
||||||
|
{.field = chem->getField(), .iteration = iter});
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Control Logic*/
|
||||||
|
if (control_interval_enabled && !rollback_enabled) {
|
||||||
|
writeStatsToCSV(error_history, species_names, out_dir,
|
||||||
|
"stats_overview");
|
||||||
|
|
||||||
|
if (triggerRollbackIfExceeded(*chem, *params, iter)) {
|
||||||
|
rollback_enabled = true;
|
||||||
|
rollback_counter++;
|
||||||
|
sur_disabled_counter = control_interval;
|
||||||
|
MSG("Interpolation disabled for the next " +
|
||||||
|
std::to_string(control_interval) + ".");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void poet::ControlModule::BCastControlFlags() {
|
||||||
|
int interp_flag = rollback_enabled ? 0 : 1;
|
||||||
|
int dht_fill_flag = rollback_enabled ? 1 : 0;
|
||||||
|
chem->ChemBCast(&interp_flag, 1, MPI_INT);
|
||||||
|
chem->ChemBCast(&dht_fill_flag, 1, MPI_INT);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool poet::ControlModule::triggerRollbackIfExceeded(ChemistryModule &chem,
|
||||||
|
RuntimeParameters ¶ms,
|
||||||
|
uint32_t &iter) {
|
||||||
|
|
||||||
|
if (error_history.empty()) {
|
||||||
|
MSG("No error history yet; skipping rollback check.");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto &mape = chem.error_history.back().mape;
|
||||||
|
const auto &props = chem.getField().GetProps();
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < params.mape_threshold.size(); ++i) {
|
||||||
|
// Skip invalid entries
|
||||||
|
if (mape[i] == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
bool mape_exceeded = mape[i] > params.mape_threshold[i];
|
||||||
|
|
||||||
|
if (mape_exceeded) {
|
||||||
|
uint32_t rollback_iter = ((iter - 1) / params.checkpoint_interval) *
|
||||||
|
params.checkpoint_interval;
|
||||||
|
|
||||||
|
MSG("[THRESHOLD EXCEEDED] " + props[i] +
|
||||||
|
" has MAPE = " + std::to_string(mape[i]) +
|
||||||
|
" exceeding threshold = " + std::to_string(params.mape_threshold[i]) +
|
||||||
|
" → rolling back to iteration " + std::to_string(rollback_iter));
|
||||||
|
|
||||||
|
Checkpoint_s checkpoint_read{.field = chem.getField()};
|
||||||
|
read_checkpoint(params.out_dir,
|
||||||
|
"checkpoint" + std::to_string(rollback_iter) + ".hdf5",
|
||||||
|
checkpoint_read);
|
||||||
|
iter = checkpoint_read.iteration;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MSG("All species are within their MAPE and RRMSE thresholds.");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void poet::ControlModule::computeSpeciesErrors(
|
||||||
|
const std::vector<double> &reference_values,
|
||||||
|
const std::vector<double> &surrogate_values, uint32_t size_per_prop) {
|
||||||
|
|
||||||
|
SimulationErrorStats species_error_stats(species_count, params->global_iter,
|
||||||
|
rollback_counter);
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < species_count; ++i) {
|
||||||
|
double err_sum = 0.0;
|
||||||
|
double sqr_err_sum = 0.0;
|
||||||
|
uint32_t base_idx = i * size_per_prop;
|
||||||
|
|
||||||
|
for (uint32_t j = 0; j < size_per_prop; ++j) {
|
||||||
|
const double ref_value = reference_values[base_idx + j];
|
||||||
|
const double sur_value = surrogate_values[base_idx + j];
|
||||||
|
|
||||||
|
if (ref_value == 0.0) {
|
||||||
|
if (sur_value != 0.0) {
|
||||||
|
err_sum += 1.0;
|
||||||
|
sqr_err_sum += 1.0;
|
||||||
|
}
|
||||||
|
// Both zero: skip
|
||||||
|
} else {
|
||||||
|
double alpha = 1.0 - (sur_value / ref_value);
|
||||||
|
err_sum += std::abs(alpha);
|
||||||
|
sqr_err_sum += alpha * alpha;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
species_error_stats.mape[i] = 100.0 * (err_sum / size_per_prop);
|
||||||
|
species_error_stats.rrmse[i] =
|
||||||
|
(size_per_prop > 0) ? std::sqrt(sqr_err_sum / size_per_prop) : 0.0;
|
||||||
|
}
|
||||||
|
error_history.push_back(species_error_stats);
|
||||||
|
}
|
||||||
110
src/Control/ControlModule.hpp
Normal file
110
src/Control/ControlModule.hpp
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
#ifndef CONTROLMODULE_H_
|
||||||
|
#define CONTROLMODULE_H_
|
||||||
|
|
||||||
|
#include "Base/Macros.hpp"
|
||||||
|
#include "Chemistry/ChemistryModule.hpp"
|
||||||
|
#include "poet.hpp"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace poet {
|
||||||
|
|
||||||
|
class ChemistryModule;
|
||||||
|
|
||||||
|
class ControlModule {
|
||||||
|
|
||||||
|
public:
|
||||||
|
ControlModule(RuntimeParameters *run_params, ChemistryModule *chem_module)
|
||||||
|
: params(run_params), chem(chem_module) {};
|
||||||
|
|
||||||
|
/* Control configuration*/
|
||||||
|
std::vector<std::string> species_names;
|
||||||
|
uint32_t species_count = 0;
|
||||||
|
std::string out_dir;
|
||||||
|
|
||||||
|
bool rollback_enabled = false;
|
||||||
|
bool control_interval_enabled = false;
|
||||||
|
|
||||||
|
std::uint32_t global_iter = 0;
|
||||||
|
std::uint32_t sur_disabled_counter = 0;
|
||||||
|
std::uint32_t rollback_counter = 0;
|
||||||
|
std::uint32_t checkpoint_interval = 0;
|
||||||
|
std::uint32_t control_interval = 0;
|
||||||
|
|
||||||
|
std::vector<double> mape_threshold;
|
||||||
|
std::vector<double> rrmse_threshold;
|
||||||
|
|
||||||
|
double ctrl_t = 0.;
|
||||||
|
double bcast_ctrl_t = 0.;
|
||||||
|
double recv_ctrl_t = 0.;
|
||||||
|
|
||||||
|
/* Buffer for shuffled surrogate data */
|
||||||
|
std::vector<double> sur_shuffled;
|
||||||
|
|
||||||
|
bool isControlIteration(uint32_t iter);
|
||||||
|
|
||||||
|
void beginIteration();
|
||||||
|
|
||||||
|
void endIteration(uint32_t iter);
|
||||||
|
|
||||||
|
void BCastControlFlags();
|
||||||
|
|
||||||
|
bool triggerRollbackIfExceeded(ChemistryModule &chem,
|
||||||
|
RuntimeParameters ¶ms, uint32_t &iter);
|
||||||
|
|
||||||
|
struct SimulationErrorStats {
|
||||||
|
std::vector<double> mape;
|
||||||
|
std::vector<double> rrmse;
|
||||||
|
uint32_t iteration; // iterations in simulation after rollbacks
|
||||||
|
uint32_t rollback_count;
|
||||||
|
|
||||||
|
SimulationErrorStats(size_t species_count, uint32_t iter, uint32_t counter)
|
||||||
|
: mape(species_count, 0.0), rrmse(species_count, 0.0), iteration(iter),
|
||||||
|
rollback_count(counter) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
static void computeSpeciesErrors(const std::vector<double> &reference_values,
|
||||||
|
const std::vector<double> &surrogate_values,
|
||||||
|
uint32_t size_per_prop);
|
||||||
|
|
||||||
|
std::vector<SimulationErrorStats> error_history;
|
||||||
|
|
||||||
|
struct ControlSetup {
|
||||||
|
std::string out_dir;
|
||||||
|
std::uint32_t checkpoint_interval;
|
||||||
|
std::uint32_t control_interval;
|
||||||
|
std::uint32_t species_count;
|
||||||
|
|
||||||
|
std::vector<std::string> species_names;
|
||||||
|
std::vector<double> mape_threshold;
|
||||||
|
std::vector<double> rrmse_threshold;
|
||||||
|
};
|
||||||
|
|
||||||
|
void enableControlLogic(const ControlSetup &setup) {
|
||||||
|
out_dir = setup.out_dir;
|
||||||
|
checkpoint_interval = setup.checkpoint_interval;
|
||||||
|
control_interval = setup.control_interval;
|
||||||
|
species_count = setup.species_count;
|
||||||
|
|
||||||
|
species_names = setup.species_names;
|
||||||
|
mape_threshold = setup.mape_threshold;
|
||||||
|
rrmse_threshold = setup.rrmse_threshold;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Profiling getters */
|
||||||
|
auto GetMasterCtrlLogicTime() const { return this->ctrl_t; }
|
||||||
|
|
||||||
|
auto GetMasterCtrlBcastTime() const { return this->bcast_ctrl_t; }
|
||||||
|
|
||||||
|
auto GetMasterRecvCtrlLogicTime() const { return this->recv_ctrl_t; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
RuntimeParameters *params;
|
||||||
|
ChemistryModule *chem;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace poet
|
||||||
|
|
||||||
|
#endif // CONTROLMODULE_H_
|
||||||
@ -7,7 +7,7 @@
|
|||||||
|
|
||||||
namespace poet
|
namespace poet
|
||||||
{
|
{
|
||||||
void writeStatsToCSV(const std::vector<ChemistryModule::SimulationErrorStats> &all_stats,
|
void writeStatsToCSV(const std::vector<ControlModule::SimulationErrorStats> &all_stats,
|
||||||
const std::vector<std::string> &species_names,
|
const std::vector<std::string> &species_names,
|
||||||
const std::string &out_dir,
|
const std::string &out_dir,
|
||||||
const std::string &filename)
|
const std::string &filename)
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include "Chemistry/ChemistryModule.hpp"
|
#include "Control/ControlModule.hpp"
|
||||||
|
|
||||||
namespace poet
|
namespace poet
|
||||||
{
|
{
|
||||||
void writeStatsToCSV(const std::vector<ChemistryModule::SimulationErrorStats> &all_stats,
|
void writeStatsToCSV(const std::vector<ControlModule::SimulationErrorStats> &all_stats,
|
||||||
const std::vector<std::string> &species_names,
|
const std::vector<std::string> &species_names,
|
||||||
const std::string &out_dir,
|
const std::string &out_dir,
|
||||||
const std::string &filename);
|
const std::string &filename);
|
||||||
|
|||||||
207
src/poet.cpp
207
src/poet.cpp
@ -25,10 +25,8 @@
|
|||||||
#include "Base/RInsidePOET.hpp"
|
#include "Base/RInsidePOET.hpp"
|
||||||
#include "CLI/CLI.hpp"
|
#include "CLI/CLI.hpp"
|
||||||
#include "Chemistry/ChemistryModule.hpp"
|
#include "Chemistry/ChemistryModule.hpp"
|
||||||
|
#include "Control/ControlManager.hpp"
|
||||||
#include "DataStructures/Field.hpp"
|
#include "DataStructures/Field.hpp"
|
||||||
#include "IO/Datatypes.hpp"
|
|
||||||
#include "IO/HDF5Functions.hpp"
|
|
||||||
#include "IO/StatsIO.hpp"
|
|
||||||
#include "Init/InitialList.hpp"
|
#include "Init/InitialList.hpp"
|
||||||
#include "Transport/DiffusionModule.hpp"
|
#include "Transport/DiffusionModule.hpp"
|
||||||
|
|
||||||
@ -68,8 +66,7 @@ static poet::DEFunc ReadRObj_R;
|
|||||||
static poet::DEFunc SaveRObj_R;
|
static poet::DEFunc SaveRObj_R;
|
||||||
static poet::DEFunc source_R;
|
static poet::DEFunc source_R;
|
||||||
|
|
||||||
static void init_global_functions(RInside &R)
|
static void init_global_functions(RInside &R) {
|
||||||
{
|
|
||||||
R.parseEval(kin_r_library);
|
R.parseEval(kin_r_library);
|
||||||
master_init_R = DEFunc("master_init");
|
master_init_R = DEFunc("master_init");
|
||||||
master_iteration_end_R = DEFunc("master_iteration_end");
|
master_iteration_end_R = DEFunc("master_iteration_end");
|
||||||
@ -92,15 +89,9 @@ static void init_global_functions(RInside &R)
|
|||||||
// R.parseEval("mysetup$state_C <- TMP");
|
// R.parseEval("mysetup$state_C <- TMP");
|
||||||
// }
|
// }
|
||||||
|
|
||||||
enum ParseRet
|
enum ParseRet { PARSER_OK, PARSER_ERROR, PARSER_HELP };
|
||||||
{
|
|
||||||
PARSER_OK,
|
|
||||||
PARSER_ERROR,
|
|
||||||
PARSER_HELP
|
|
||||||
};
|
|
||||||
|
|
||||||
int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms)
|
int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) {
|
||||||
{
|
|
||||||
|
|
||||||
CLI::App app{"POET - Potsdam rEactive Transport simulator"};
|
CLI::App app{"POET - Potsdam rEactive Transport simulator"};
|
||||||
|
|
||||||
@ -182,12 +173,9 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms)
|
|||||||
"Output directory of the simulation")
|
"Output directory of the simulation")
|
||||||
->required();
|
->required();
|
||||||
|
|
||||||
try
|
try {
|
||||||
{
|
|
||||||
app.parse(argc, argv);
|
app.parse(argc, argv);
|
||||||
}
|
} catch (const CLI::ParseError &e) {
|
||||||
catch (const CLI::ParseError &e)
|
|
||||||
{
|
|
||||||
app.exit(e);
|
app.exit(e);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
@ -199,16 +187,14 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms)
|
|||||||
if (params.as_qs)
|
if (params.as_qs)
|
||||||
params.out_ext = "qs";
|
params.out_ext = "qs";
|
||||||
|
|
||||||
if (MY_RANK == 0)
|
if (MY_RANK == 0) {
|
||||||
{
|
|
||||||
// MSG("Complete results storage is " + BOOL_PRINT(simparams.store_result));
|
// MSG("Complete results storage is " + BOOL_PRINT(simparams.store_result));
|
||||||
MSG("Output format/extension is " + params.out_ext);
|
MSG("Output format/extension is " + params.out_ext);
|
||||||
MSG("Work Package Size: " + std::to_string(params.work_package_size));
|
MSG("Work Package Size: " + std::to_string(params.work_package_size));
|
||||||
MSG("DHT is " + BOOL_PRINT(params.use_dht));
|
MSG("DHT is " + BOOL_PRINT(params.use_dht));
|
||||||
MSG("AI Surrogate is " + BOOL_PRINT(params.use_ai_surrogate));
|
MSG("AI Surrogate is " + BOOL_PRINT(params.use_ai_surrogate));
|
||||||
|
|
||||||
if (params.use_dht)
|
if (params.use_dht) {
|
||||||
{
|
|
||||||
// MSG("DHT strategy is " + std::to_string(simparams.dht_strategy));
|
// MSG("DHT strategy is " + std::to_string(simparams.dht_strategy));
|
||||||
// MDL: these should be outdated (?)
|
// MDL: these should be outdated (?)
|
||||||
// MSG("DHT key default digits (ignored if 'signif_vector' is "
|
// MSG("DHT key default digits (ignored if 'signif_vector' is "
|
||||||
@ -222,8 +208,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms)
|
|||||||
// MSG("DHT load file is " + chem_params.dht_file);
|
// MSG("DHT load file is " + chem_params.dht_file);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.use_interp)
|
if (params.use_interp) {
|
||||||
{
|
|
||||||
MSG("PHT interpolation enabled: " + BOOL_PRINT(params.use_interp));
|
MSG("PHT interpolation enabled: " + BOOL_PRINT(params.use_interp));
|
||||||
MSG("PHT interp-size = " + std::to_string(params.interp_size));
|
MSG("PHT interp-size = " + std::to_string(params.interp_size));
|
||||||
MSG("PHT interp-min = " + std::to_string(params.interp_min_entries));
|
MSG("PHT interp-min = " + std::to_string(params.interp_min_entries));
|
||||||
@ -251,8 +236,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms)
|
|||||||
// // log before rounding?
|
// // log before rounding?
|
||||||
// R["dht_log"] = simparams.dht_log;
|
// R["dht_log"] = simparams.dht_log;
|
||||||
|
|
||||||
try
|
try {
|
||||||
{
|
|
||||||
Rcpp::List init_params_(ReadRObj_R(init_file));
|
Rcpp::List init_params_(ReadRObj_R(init_file));
|
||||||
params.init_params = init_params_;
|
params.init_params = init_params_;
|
||||||
|
|
||||||
@ -269,13 +253,11 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms)
|
|||||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("control_interval"));
|
Rcpp::as<uint32_t>(global_rt_setup->operator[]("control_interval"));
|
||||||
params.checkpoint_interval =
|
params.checkpoint_interval =
|
||||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval"));
|
Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval"));
|
||||||
params.mape_threshold =
|
params.mape_threshold = Rcpp::as<std::vector<double>>(
|
||||||
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("mape_threshold"));
|
global_rt_setup->operator[]("mape_threshold"));
|
||||||
params.rrmse_threshold =
|
params.rrmse_threshold = Rcpp::as<std::vector<double>>(
|
||||||
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("rrmse_threshold"));
|
global_rt_setup->operator[]("rrmse_threshold"));
|
||||||
}
|
} catch (const std::exception &e) {
|
||||||
catch (const std::exception &e)
|
|
||||||
{
|
|
||||||
ERRMSG("Error while parsing R scripts: " + std::string(e.what()));
|
ERRMSG("Error while parsing R scripts: " + std::string(e.what()));
|
||||||
return ParseRet::PARSER_ERROR;
|
return ParseRet::PARSER_ERROR;
|
||||||
}
|
}
|
||||||
@ -285,8 +267,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms)
|
|||||||
|
|
||||||
// HACK: this is a step back as the order and also the count of fields is
|
// HACK: this is a step back as the order and also the count of fields is
|
||||||
// predefined, but it will change in the future
|
// predefined, but it will change in the future
|
||||||
void call_master_iter_end(RInside &R, const Field &trans, const Field &chem)
|
void call_master_iter_end(RInside &R, const Field &trans, const Field &chem) {
|
||||||
{
|
|
||||||
R["TMP"] = Rcpp::wrap(trans.AsVector());
|
R["TMP"] = Rcpp::wrap(trans.AsVector());
|
||||||
R["TMP_PROPS"] = Rcpp::wrap(trans.GetProps());
|
R["TMP_PROPS"] = Rcpp::wrap(trans.GetProps());
|
||||||
R.parseEval(std::string("state_T <- setNames(data.frame(matrix(TMP, nrow=" +
|
R.parseEval(std::string("state_T <- setNames(data.frame(matrix(TMP, nrow=" +
|
||||||
@ -303,53 +284,15 @@ void call_master_iter_end(RInside &R, const Field &trans, const Field &chem)
|
|||||||
*global_rt_setup = R["setup"];
|
*global_rt_setup = R["setup"];
|
||||||
}
|
}
|
||||||
|
|
||||||
bool triggerRollbackIfExceeded(ChemistryModule &chem, RuntimeParameters ¶ms, uint32_t ¤t_iteration)
|
|
||||||
{
|
|
||||||
const auto &mape = chem.error_history.back().mape;
|
|
||||||
const auto &rrmse = chem.error_history.back().rrmse;
|
|
||||||
const auto &props = chem.getField().GetProps();
|
|
||||||
|
|
||||||
for (uint32_t i = 0; i < params.mape_threshold.size(); ++i)
|
|
||||||
{
|
|
||||||
// Skip invalid entries
|
|
||||||
if ((mape[i] == 0 && rrmse[i] == 0))
|
|
||||||
continue;
|
|
||||||
|
|
||||||
bool mape_exceeded = mape[i] > params.mape_threshold[i];
|
|
||||||
bool rrmse_exceeded = rrmse[i] > params.rrmse_threshold[i];
|
|
||||||
|
|
||||||
if (mape_exceeded || rrmse_exceeded)
|
|
||||||
{
|
|
||||||
uint32_t rollback_iter = ((current_iteration - 1) / params.checkpoint_interval) * params.checkpoint_interval;
|
|
||||||
std::string metric = mape_exceeded ? "MAPE" : "RRMSE";
|
|
||||||
double value = mape_exceeded ? mape[i] : rrmse[i];
|
|
||||||
double threshold = mape_exceeded ? params.mape_threshold[i] : params.rrmse_threshold[i];
|
|
||||||
|
|
||||||
MSG("[THRESHOLD EXCEEDED] " + props[i] + " has " + metric + " = " +
|
|
||||||
std::to_string(value) + " exceeding threshold = " + std::to_string(threshold) +
|
|
||||||
" → rolling back to iteration " + std::to_string(rollback_iter));
|
|
||||||
|
|
||||||
Checkpoint_s checkpoint_read{.field = chem.getField()};
|
|
||||||
read_checkpoint(params.out_dir, "checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read);
|
|
||||||
current_iteration = checkpoint_read.iteration;
|
|
||||||
return true; // rollback happened
|
|
||||||
}
|
|
||||||
}
|
|
||||||
MSG("All species are within their MAPE and RRMSE thresholds.");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||||
DiffusionModule &diffusion,
|
DiffusionModule &diffusion,
|
||||||
ChemistryModule &chem)
|
ChemistryModule &chem, ControlModule &control) {
|
||||||
{
|
|
||||||
|
|
||||||
/* Iteration Count is dynamic, retrieving value from R (is only needed by
|
/* Iteration Count is dynamic, retrieving value from R (is only needed by
|
||||||
* master for the following loop) */
|
* master for the following loop) */
|
||||||
uint32_t maxiter = params.timesteps.size();
|
uint32_t maxiter = params.timesteps.size();
|
||||||
|
|
||||||
if (params.print_progress)
|
if (params.print_progress) {
|
||||||
{
|
|
||||||
chem.setProgressBarPrintout(true);
|
chem.setProgressBarPrintout(true);
|
||||||
}
|
}
|
||||||
R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps());
|
R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps());
|
||||||
@ -359,9 +302,10 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
|||||||
double dSimTime{0};
|
double dSimTime{0};
|
||||||
double chkTime = 0.0;
|
double chkTime = 0.0;
|
||||||
|
|
||||||
for (uint32_t iter = 1; iter < maxiter + 1; iter++)
|
for (uint32_t iter = 1; iter < maxiter + 1; iter++) {
|
||||||
{
|
|
||||||
// Rollback countdowm
|
// Rollback countdowm
|
||||||
|
|
||||||
|
/*
|
||||||
if (params.rollback_enabled) {
|
if (params.rollback_enabled) {
|
||||||
if (params.sur_disabled_counter > 0) {
|
if (params.sur_disabled_counter > 0) {
|
||||||
--params.sur_disabled_counter;
|
--params.sur_disabled_counter;
|
||||||
@ -370,9 +314,12 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
|||||||
params.rollback_enabled = false;
|
params.rollback_enabled = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
control.beginIteration(iter);
|
||||||
|
|
||||||
params.global_iter = iter;
|
// params.global_iter = iter;
|
||||||
params.control_interval_enabled = (iter % params.control_interval == 0);
|
control.isControlIteration(iter);
|
||||||
|
// params.control_interval_enabled = (iter % params.control_interval == 0);
|
||||||
|
|
||||||
double start_t = MPI_Wtime();
|
double start_t = MPI_Wtime();
|
||||||
|
|
||||||
@ -389,13 +336,12 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
|||||||
/* run transport */
|
/* run transport */
|
||||||
diffusion.simulate(dt);
|
diffusion.simulate(dt);
|
||||||
|
|
||||||
chem.runtime_params = ¶ms;
|
// chem.runtime_params = ¶ms;
|
||||||
|
|
||||||
chem.getField().update(diffusion.getField());
|
chem.getField().update(diffusion.getField());
|
||||||
|
|
||||||
// MSG("Chemistry start");
|
// MSG("Chemistry start");
|
||||||
if (params.use_ai_surrogate)
|
if (params.use_ai_surrogate) {
|
||||||
{
|
|
||||||
double ai_start_t = MPI_Wtime();
|
double ai_start_t = MPI_Wtime();
|
||||||
// Save current values from the tug field as predictor for the ai step
|
// Save current values from the tug field as predictor for the ai step
|
||||||
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
|
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
|
||||||
@ -446,8 +392,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
|||||||
chem.simulate(dt);
|
chem.simulate(dt);
|
||||||
|
|
||||||
/* AI surrogate iterative training*/
|
/* AI surrogate iterative training*/
|
||||||
if (params.use_ai_surrogate)
|
if (params.use_ai_surrogate) {
|
||||||
{
|
|
||||||
double ai_start_t = MPI_Wtime();
|
double ai_start_t = MPI_Wtime();
|
||||||
|
|
||||||
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
|
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
|
||||||
@ -487,24 +432,31 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
|||||||
std::to_string(maxiter));
|
std::to_string(maxiter));
|
||||||
|
|
||||||
double chk_start = MPI_Wtime();
|
double chk_start = MPI_Wtime();
|
||||||
|
control.endIteration(iter)
|
||||||
if(iter % params.checkpoint_interval == 0){
|
/*
|
||||||
|
if (iter % params.checkpoint_interval == 0) {
|
||||||
MSG("Writing checkpoint of iteration " + std::to_string(iter));
|
MSG("Writing checkpoint of iteration " + std::to_string(iter));
|
||||||
write_checkpoint(params.out_dir, "checkpoint" + std::to_string(iter) + ".hdf5",
|
write_checkpoint(params.out_dir,
|
||||||
|
"checkpoint" + std::to_string(iter) + ".hdf5",
|
||||||
{.field = chem.getField(), .iteration = iter});
|
{.field = chem.getField(), .iteration = iter});
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.control_interval_enabled && !params.rollback_enabled)
|
|
||||||
{
|
|
||||||
writeStatsToCSV(chem.error_history, chem.getField().GetProps(), params.out_dir,"stats_overview");
|
|
||||||
|
|
||||||
if(triggerRollbackIfExceeded(chem, params, iter)){
|
if (params.control_interval_enabled && !params.rollback_enabled) {
|
||||||
|
writeStatsToCSV(chem.error_history, chem.getField().GetProps(),
|
||||||
|
params.out_dir, "stats_overview");
|
||||||
|
|
||||||
|
if (triggerRollbackIfExceeded(chem, params, iter)) {
|
||||||
params.rollback_enabled = true;
|
params.rollback_enabled = true;
|
||||||
params.rollback_counter ++;
|
params.rollback_counter++;
|
||||||
params.sur_disabled_counter = params.control_interval;
|
params.sur_disabled_counter = params.control_interval;
|
||||||
MSG("Interpolation disabled for the next " + std::to_string(params.control_interval) + ".");
|
MSG("Interpolation disabled for the next " +
|
||||||
|
std::to_string(params.control_interval) + ".");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
double chk_end = MPI_Wtime();
|
double chk_end = MPI_Wtime();
|
||||||
chkTime += chk_end - chk_start;
|
chkTime += chk_end - chk_start;
|
||||||
|
|
||||||
@ -529,10 +481,10 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
|||||||
ctrl_profiling["ctrl_logic_master"] = chem.GetMasterCtrlLogicTime();
|
ctrl_profiling["ctrl_logic_master"] = chem.GetMasterCtrlLogicTime();
|
||||||
ctrl_profiling["bcast_ctrl_logic_master"] = chem.GetMasterCtrlBcastTime();
|
ctrl_profiling["bcast_ctrl_logic_master"] = chem.GetMasterCtrlBcastTime();
|
||||||
ctrl_profiling["recv_ctrl_logic_maser"] = chem.GetMasterRecvCtrlLogicTime();
|
ctrl_profiling["recv_ctrl_logic_maser"] = chem.GetMasterRecvCtrlLogicTime();
|
||||||
ctrl_profiling["ctrl_logic_worker"] = Rcpp::wrap(chem.GetWorkerControlTimings());
|
ctrl_profiling["ctrl_logic_worker"] =
|
||||||
|
Rcpp::wrap(chem.GetWorkerControlTimings());
|
||||||
|
|
||||||
if (params.use_dht)
|
if (params.use_dht) {
|
||||||
{
|
|
||||||
chem_profiling["dht_hits"] = Rcpp::wrap(chem.GetWorkerDHTHits());
|
chem_profiling["dht_hits"] = Rcpp::wrap(chem.GetWorkerDHTHits());
|
||||||
chem_profiling["dht_evictions"] = Rcpp::wrap(chem.GetWorkerDHTEvictions());
|
chem_profiling["dht_evictions"] = Rcpp::wrap(chem.GetWorkerDHTEvictions());
|
||||||
chem_profiling["dht_get_time"] = Rcpp::wrap(chem.GetWorkerDHTGetTimings());
|
chem_profiling["dht_get_time"] = Rcpp::wrap(chem.GetWorkerDHTGetTimings());
|
||||||
@ -540,8 +492,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
|||||||
Rcpp::wrap(chem.GetWorkerDHTFillTimings());
|
Rcpp::wrap(chem.GetWorkerDHTFillTimings());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.use_interp)
|
if (params.use_interp) {
|
||||||
{
|
|
||||||
chem_profiling["interp_w"] =
|
chem_profiling["interp_w"] =
|
||||||
Rcpp::wrap(chem.GetWorkerInterpolationWriteTimings());
|
Rcpp::wrap(chem.GetWorkerInterpolationWriteTimings());
|
||||||
chem_profiling["interp_r"] =
|
chem_profiling["interp_r"] =
|
||||||
@ -561,15 +512,13 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
|||||||
profiling["diffusion"] = diffusion_profiling;
|
profiling["diffusion"] = diffusion_profiling;
|
||||||
profiling["ctrl_logic"] = ctrl_profiling;
|
profiling["ctrl_logic"] = ctrl_profiling;
|
||||||
|
|
||||||
|
|
||||||
chem.MasterLoopBreak();
|
chem.MasterLoopBreak();
|
||||||
|
|
||||||
return profiling;
|
return profiling;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> getSpeciesNames(const Field &&field, int root,
|
std::vector<std::string> getSpeciesNames(const Field &&field, int root,
|
||||||
MPI_Comm comm)
|
MPI_Comm comm) {
|
||||||
{
|
|
||||||
std::uint32_t n_elements;
|
std::uint32_t n_elements;
|
||||||
std::uint32_t n_string_size;
|
std::uint32_t n_string_size;
|
||||||
|
|
||||||
@ -579,13 +528,11 @@ std::vector<std::string> getSpeciesNames(const Field &&field, int root,
|
|||||||
const bool is_master = root == rank;
|
const bool is_master = root == rank;
|
||||||
|
|
||||||
// first, the master sends all the species names iterative
|
// first, the master sends all the species names iterative
|
||||||
if (is_master)
|
if (is_master) {
|
||||||
{
|
|
||||||
n_elements = field.GetProps().size();
|
n_elements = field.GetProps().size();
|
||||||
MPI_Bcast(&n_elements, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
|
MPI_Bcast(&n_elements, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
|
||||||
|
|
||||||
for (std::uint32_t i = 0; i < n_elements; i++)
|
for (std::uint32_t i = 0; i < n_elements; i++) {
|
||||||
{
|
|
||||||
n_string_size = field.GetProps()[i].size();
|
n_string_size = field.GetProps()[i].size();
|
||||||
MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
|
MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
|
||||||
MPI_Bcast(const_cast<char *>(field.GetProps()[i].c_str()), n_string_size,
|
MPI_Bcast(const_cast<char *>(field.GetProps()[i].c_str()), n_string_size,
|
||||||
@ -600,8 +547,7 @@ std::vector<std::string> getSpeciesNames(const Field &&field, int root,
|
|||||||
|
|
||||||
std::vector<std::string> species_names_out(n_elements);
|
std::vector<std::string> species_names_out(n_elements);
|
||||||
|
|
||||||
for (std::uint32_t i = 0; i < n_elements; i++)
|
for (std::uint32_t i = 0; i < n_elements; i++) {
|
||||||
{
|
|
||||||
MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
|
MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
|
||||||
|
|
||||||
char recv_buf[n_string_size];
|
char recv_buf[n_string_size];
|
||||||
@ -614,8 +560,7 @@ std::vector<std::string> getSpeciesNames(const Field &&field, int root,
|
|||||||
return species_names_out;
|
return species_names_out;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::array<double, 2> getBaseTotals(Field &&field, int root, MPI_Comm comm)
|
std::array<double, 2> getBaseTotals(Field &&field, int root, MPI_Comm comm) {
|
||||||
{
|
|
||||||
std::array<double, 2> base_totals;
|
std::array<double, 2> base_totals;
|
||||||
|
|
||||||
int rank;
|
int rank;
|
||||||
@ -623,8 +568,7 @@ std::array<double, 2> getBaseTotals(Field &&field, int root, MPI_Comm comm)
|
|||||||
|
|
||||||
const bool is_master = root == rank;
|
const bool is_master = root == rank;
|
||||||
|
|
||||||
if (is_master)
|
if (is_master) {
|
||||||
{
|
|
||||||
const auto h_col = field["H"];
|
const auto h_col = field["H"];
|
||||||
const auto o_col = field["O"];
|
const auto o_col = field["O"];
|
||||||
|
|
||||||
@ -639,8 +583,7 @@ std::array<double, 2> getBaseTotals(Field &&field, int root, MPI_Comm comm)
|
|||||||
return base_totals;
|
return base_totals;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool getHasID(Field &&field, int root, MPI_Comm comm)
|
bool getHasID(Field &&field, int root, MPI_Comm comm) {
|
||||||
{
|
|
||||||
bool has_id;
|
bool has_id;
|
||||||
|
|
||||||
int rank;
|
int rank;
|
||||||
@ -648,8 +591,7 @@ bool getHasID(Field &&field, int root, MPI_Comm comm)
|
|||||||
|
|
||||||
const bool is_master = root == rank;
|
const bool is_master = root == rank;
|
||||||
|
|
||||||
if (is_master)
|
if (is_master) {
|
||||||
{
|
|
||||||
const auto ID_field = field["ID"];
|
const auto ID_field = field["ID"];
|
||||||
|
|
||||||
std::set<double> unique_IDs(ID_field.begin(), ID_field.end());
|
std::set<double> unique_IDs(ID_field.begin(), ID_field.end());
|
||||||
@ -666,8 +608,7 @@ bool getHasID(Field &&field, int root, MPI_Comm comm)
|
|||||||
return has_id;
|
return has_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(int argc, char *argv[])
|
int main(int argc, char *argv[]) {
|
||||||
{
|
|
||||||
int world_size;
|
int world_size;
|
||||||
|
|
||||||
MPI_Init(&argc, &argv);
|
MPI_Init(&argc, &argv);
|
||||||
@ -678,8 +619,7 @@ int main(int argc, char *argv[])
|
|||||||
|
|
||||||
RInsidePOET &R = RInsidePOET::getInstance();
|
RInsidePOET &R = RInsidePOET::getInstance();
|
||||||
|
|
||||||
if (MY_RANK == 0)
|
if (MY_RANK == 0) {
|
||||||
{
|
|
||||||
MSG("Running POET version " + std::string(poet_version));
|
MSG("Running POET version " + std::string(poet_version));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -687,8 +627,7 @@ int main(int argc, char *argv[])
|
|||||||
|
|
||||||
RuntimeParameters run_params;
|
RuntimeParameters run_params;
|
||||||
|
|
||||||
if (parseInitValues(argc, argv, run_params) != 0)
|
if (parseInitValues(argc, argv, run_params) != 0) {
|
||||||
{
|
|
||||||
MPI_Finalize();
|
MPI_Finalize();
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
@ -713,6 +652,7 @@ int main(int argc, char *argv[])
|
|||||||
|
|
||||||
ChemistryModule chemistry(run_params.work_package_size,
|
ChemistryModule chemistry(run_params.work_package_size,
|
||||||
init_list.getChemistryInit(), MPI_COMM_WORLD);
|
init_list.getChemistryInit(), MPI_COMM_WORLD);
|
||||||
|
ControlModule control(&run_params, &chemistry);
|
||||||
|
|
||||||
const ChemistryModule::SurrogateSetup surr_setup = {
|
const ChemistryModule::SurrogateSetup surr_setup = {
|
||||||
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
||||||
@ -730,12 +670,21 @@ int main(int argc, char *argv[])
|
|||||||
|
|
||||||
chemistry.masterEnableSurrogates(surr_setup);
|
chemistry.masterEnableSurrogates(surr_setup);
|
||||||
|
|
||||||
if (MY_RANK > 0)
|
const ControlModule::ControlSetup ctrl_setup = {
|
||||||
{
|
run_params.out_dir, // added
|
||||||
|
run_params.checkpoint_interval,
|
||||||
|
run_params.control_interval,
|
||||||
|
run_params.species_count,
|
||||||
|
run_params.species_names,
|
||||||
|
run_params.mape_threshold,
|
||||||
|
run_params.rrmse_threshold};
|
||||||
|
|
||||||
|
control.enableControlLogic(ctrl_setup);
|
||||||
|
|
||||||
|
|
||||||
|
if (MY_RANK > 0) {
|
||||||
chemistry.WorkerLoop();
|
chemistry.WorkerLoop();
|
||||||
}
|
} else {
|
||||||
else
|
|
||||||
{
|
|
||||||
// R.parseEvalQ("mysetup <- setup");
|
// R.parseEvalQ("mysetup <- setup");
|
||||||
// // if (MY_RANK == 0) { // get timestep vector from
|
// // if (MY_RANK == 0) { // get timestep vector from
|
||||||
// // grid_init function ... //
|
// // grid_init function ... //
|
||||||
@ -749,8 +698,7 @@ int main(int argc, char *argv[])
|
|||||||
R["out_ext"] = run_params.out_ext;
|
R["out_ext"] = run_params.out_ext;
|
||||||
R["out_dir"] = run_params.out_dir;
|
R["out_dir"] = run_params.out_dir;
|
||||||
|
|
||||||
if (run_params.use_ai_surrogate)
|
if (run_params.use_ai_surrogate) {
|
||||||
{
|
|
||||||
/* Incorporate ai surrogate from R */
|
/* Incorporate ai surrogate from R */
|
||||||
R.parseEvalQ(ai_surrogate_r_library);
|
R.parseEvalQ(ai_surrogate_r_library);
|
||||||
/* Use dht species for model input and output */
|
/* Use dht species for model input and output */
|
||||||
@ -799,8 +747,7 @@ int main(int argc, char *argv[])
|
|||||||
|
|
||||||
MPI_Finalize();
|
MPI_Finalize();
|
||||||
|
|
||||||
if (MY_RANK == 0)
|
if (MY_RANK == 0) {
|
||||||
{
|
|
||||||
MSG("done, bye!");
|
MSG("done, bye!");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -38,8 +38,7 @@ static const inline std::string ai_surrogate_r_library =
|
|||||||
R"(@R_AI_SURROGATE_LIB@)";
|
R"(@R_AI_SURROGATE_LIB@)";
|
||||||
static const inline std::string r_runtime_parameters = "mysetup";
|
static const inline std::string r_runtime_parameters = "mysetup";
|
||||||
|
|
||||||
struct RuntimeParameters
|
struct RuntimeParameters {
|
||||||
{
|
|
||||||
std::string out_dir;
|
std::string out_dir;
|
||||||
std::vector<double> timesteps;
|
std::vector<double> timesteps;
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user