From 71269166eac7874c5d9a7f373c29c4b2bfd51a72 Mon Sep 17 00:00:00 2001 From: rastogi Date: Sun, 19 Oct 2025 11:49:52 +0200 Subject: [PATCH] migrate: separate control logic from ChemistryModule into dedicated ControlModule --- src/CMakeLists.txt | 1 + src/Chemistry/ChemistryModule.hpp | 849 ++++++++++++++---------------- src/Chemistry/MasterFunctions.cpp | 128 ++--- src/Chemistry/WorkerFunctions.cpp | 2 +- src/Control/ControlModule.cpp | 131 +++++ src/Control/ControlModule.hpp | 110 ++++ src/IO/StatsIO.cpp | 2 +- src/IO/StatsIO.hpp | 4 +- src/poet.cpp | 227 +++----- src/poet.hpp.in | 3 +- 10 files changed, 771 insertions(+), 686 deletions(-) create mode 100644 src/Control/ControlModule.cpp create mode 100644 src/Control/ControlModule.hpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a9849a768..940848898 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -33,6 +33,7 @@ add_library(POETLib Chemistry/SurrogateModels/HashFunctions.cpp Chemistry/SurrogateModels/InterpolationModule.cpp Chemistry/SurrogateModels/ProximityHashTable.cpp + Control/ControlModule.cpp ) set(POET_TUG_APPROACH "Implicit" CACHE STRING "tug numerical approach to use") diff --git a/src/Chemistry/ChemistryModule.hpp b/src/Chemistry/ChemistryModule.hpp index c6f57bbec..73be52c60 100644 --- a/src/Chemistry/ChemistryModule.hpp +++ b/src/Chemistry/ChemistryModule.hpp @@ -4,17 +4,14 @@ #include "DataStructures/Field.hpp" #include "DataStructures/NamedVector.hpp" - #include "ChemistryDefs.hpp" - +#include "Control/ControlModule.hpp" #include "Init/InitialList.hpp" #include "NameDouble.h" #include "SurrogateModels/DHT_Wrapper.hpp" #include "SurrogateModels/Interpolation.hpp" - -#include "poet.hpp" - #include "PhreeqcRunner.hpp" + #include #include #include @@ -23,454 +20,412 @@ #include #include -namespace poet -{ +namespace poet { + class ControlModule; +/** + * \brief Wrapper around PhreeqcRM to provide POET specific parallelization with + * easy access. + */ +class ChemistryModule { +public: /** - * \brief Wrapper around PhreeqcRM to provide POET specific parallelization with - * easy access. + * Creates a new instance of Chemistry module with given grid cell count, work + * package size and communicator. + * + * This constructor shall only be called by the master. To create workers, see + * ChemistryModule::createWorker . + * + * When the use of parallelization is intended, the nxyz value shall be set to + * 1 to save memory and only one node is needed for initialization. + * + * \param nxyz Count of grid cells to allocate and initialize for each + * process. For parellel use set to 1 at the master. + * \param wp_size Count of grid cells to fill each work package at maximum. + * \param communicator MPI communicator to distribute work in. */ - class ChemistryModule - { - public: - /** - * Creates a new instance of Chemistry module with given grid cell count, work - * package size and communicator. - * - * This constructor shall only be called by the master. To create workers, see - * ChemistryModule::createWorker . - * - * When the use of parallelization is intended, the nxyz value shall be set to - * 1 to save memory and only one node is needed for initialization. - * - * \param nxyz Count of grid cells to allocate and initialize for each - * process. For parellel use set to 1 at the master. - * \param wp_size Count of grid cells to fill each work package at maximum. - * \param communicator MPI communicator to distribute work in. - */ - ChemistryModule(uint32_t wp_size, - const InitialList::ChemistryInit chem_params, - MPI_Comm communicator); - - /** - * Deconstructor, which frees DHT data structure if used. - */ - ~ChemistryModule(); - - void masterSetField(Field field); - /** - * Run the chemical simulation with parameters set. - */ - void simulate(double dt); - - /** - * Returns all known species names, including not only aqueous species, but - * also equilibrium, exchange, surface and kinetic reactants. - */ - // auto GetPropNames() const { return this->prop_names; } - - /** - * Return the accumulated runtime in seconds for chemical simulation. - */ - auto GetChemistryTime() const { return this->chem_t; } - - void setFilePadding(std::uint32_t maxiter) - { - this->file_pad = - static_cast(std::ceil(std::log10(maxiter + 1))); - } - - struct SurrogateSetup - { - std::vector prop_names; - std::array base_totals; - bool has_het_ids; - - bool dht_enabled; - std::uint32_t dht_size_mb; - int dht_snaps; - std::string dht_out_dir; - - bool interp_enabled; - std::uint32_t interp_bucket_size; - std::uint32_t interp_size_mb; - std::uint32_t interp_min_entries; - bool ai_surrogate_enabled; - }; - - void masterEnableSurrogates(const SurrogateSetup &setup) - { - // FIXME: This is a hack to get the prop_names and prop_count from the setup - this->prop_names = setup.prop_names; - this->prop_count = setup.prop_names.size(); - - this->dht_enabled = setup.dht_enabled; - this->interp_enabled = setup.interp_enabled; - this->ai_surrogate_enabled = setup.ai_surrogate_enabled; - - this->base_totals = setup.base_totals; - - if (this->dht_enabled || this->interp_enabled) - { - this->initializeDHT(setup.dht_size_mb, this->params.dht_species, - setup.has_het_ids); - - if (setup.dht_snaps != DHT_SNAPS_DISABLED) - { - this->setDHTSnapshots(setup.dht_snaps, setup.dht_out_dir); - } - } - - if (this->interp_enabled) - { - this->initializeInterp(setup.interp_bucket_size, setup.interp_size_mb, - setup.interp_min_entries, - this->params.interp_species); - } - } - - /** - * Intended to alias input parameters for grid initialization with a single - * value per species. - */ - using SingleCMap = std::unordered_map; - - /** - * Intended to alias input parameters for grid initialization with mutlitple - * values per species. - */ - using VectorCMap = std::unordered_map>; - - /** - * Enumerating DHT file options - */ - enum - { - DHT_SNAPS_DISABLED = 0, //!< disabled file output - DHT_SNAPS_SIMEND, //!< only output of snapshot after simulation - DHT_SNAPS_ITEREND //!< output snapshots after each iteration - }; - - /** - * **Only called by workers!** Start the worker listening loop. - */ - void WorkerLoop(); - - /** - * **Called by master** Advise the workers to break the loop. - */ - void MasterLoopBreak(); - - /** - * **Master only** Return count of grid cells. - */ - auto GetNCells() const { return this->n_cells; } - /** - * **Master only** Return work package size. - */ - auto GetWPSize() const { return this->wp_size; } - /** - * **Master only** Return the time in seconds the master spent waiting for any - * free worker. - */ - auto GetMasterIdleTime() const { return this->idle_t; } - /** - * **Master only** Return the time in seconds the master spent in sequential - * part of the simulation, including times for shuffling/unshuffling field - * etc. - */ - auto GetMasterSequentialTime() const { return this->seq_t; } - /** - * **Master only** Return the time in seconds the master spent in the - * send/receive loop. - */ - auto GetMasterLoopTime() const { return this->send_recv_t; } - - - auto GetMasterCtrlLogicTime() const { return this->ctrl_t; } - - auto GetMasterCtrlBcastTime() const { return this->bcast_ctrl_t; } - - auto GetMasterRecvCtrlLogicTime() const { return this->recv_ctrl_t; } - - /** - * **Master only** Collect and return all accumulated timings recorded by - * workers to run Phreeqc simulation. - * - * \return Vector of all accumulated Phreeqc timings. - */ - std::vector GetWorkerPhreeqcTimings() const; - /** - * **Master only** Collect and return all accumulated timings recorded by - * workers to get values from the DHT. - * - * \return Vector of all accumulated DHT get times. - */ - std::vector GetWorkerDHTGetTimings() const; - /** - * **Master only** Collect and return all accumulated timings recorded by - * workers to write values to the DHT. - * - * \return Vector of all accumulated DHT fill times. - */ - std::vector GetWorkerDHTFillTimings() const; - /** - * **Master only** Collect and return all accumulated timings recorded by - * workers waiting for work packages from the master. - * - * \return Vector of all accumulated waiting times. - */ - std::vector GetWorkerIdleTimings() const; - - std::vector GetWorkerControlTimings() const; - - /** - * **Master only** Collect and return DHT hits of all workers. - * - * \return Vector of all count of DHT hits. - */ - std::vector GetWorkerDHTHits() const; - - /** - * **Master only** Collect and return DHT evictions of all workers. - * - * \return Vector of all count of DHT evictions. - */ - std::vector GetWorkerDHTEvictions() const; - - /** - * **Master only** Returns the current state of the chemical field. - * - * \return Reference to the chemical field. - */ - Field &getField() { return this->chem_field; } - - /** - * **Master only** Enable/disable progress bar. - * - * \param enabled True if print progressbar, false if not. - */ - void setProgressBarPrintout(bool enabled) - { - this->print_progessbar = enabled; - }; - - /** - * **Master only** Set the ai surrogate validity vector from R - */ - void set_ai_surrogate_validity_vector(std::vector r_vector); - - std::vector GetWorkerInterpolationCalls() const; - - std::vector GetWorkerInterpolationWriteTimings() const; - std::vector GetWorkerInterpolationReadTimings() const; - std::vector GetWorkerInterpolationGatherTimings() const; - std::vector GetWorkerInterpolationFunctionCallTimings() const; - - std::vector GetWorkerPHTCacheHits() const; - - std::vector ai_surrogate_validity_vector; - - RuntimeParameters *runtime_params = nullptr; - - struct SimulationErrorStats - { - std::vector mape; - std::vector rrmse; - uint32_t iteration; // iterations in simulation after rollbacks - uint32_t rollback_count; - - SimulationErrorStats(size_t species_count, uint32_t iter, uint32_t counter) - : mape(species_count, 0.0), - rrmse(species_count, 0.0), - iteration(iter), - rollback_count(counter){} - }; - - std::vector error_history; - - static void computeSpeciesErrors(const std::vector &reference_values, - const std::vector &surrogate_values, - uint32_t size_per_prop, - uint32_t species_count, - SimulationErrorStats &species_error_stats); - - protected: - void initializeDHT(uint32_t size_mb, - const NamedVector &key_species, - bool has_het_ids); - void setDHTSnapshots(int type, const std::string &out_dir); - void setDHTReadFile(const std::string &input_file); - - void initializeInterp(std::uint32_t bucket_size, std::uint32_t size_mb, - std::uint32_t min_entries, - const NamedVector &key_species); - - enum - { - CHEM_FIELD_INIT, - CHEM_DHT_ENABLE, - CHEM_DHT_SIGNIF_VEC, - CHEM_DHT_SNAPS, - CHEM_DHT_READ_FILE, - CHEM_INTERP, - CHEM_IP_ENABLE, - CHEM_IP_MIN_ENTRIES, - CHEM_IP_SIGNIF_VEC, - CHEM_WORK_LOOP, - CHEM_PERF, - CHEM_BREAK_MAIN_LOOP, - CHEM_AI_BCAST_VALIDITY - }; - - enum - { - LOOP_WORK, - LOOP_END, - LOOP_CTRL - }; - - enum - { - WORKER_PHREEQC, - WORKER_CTRL_ITER, - WORKER_DHT_GET, - WORKER_DHT_FILL, - WORKER_IDLE, - WORKER_IP_WRITE, - WORKER_IP_READ, - WORKER_IP_GATHER, - WORKER_IP_FC, - WORKER_DHT_HITS, - WORKER_DHT_EVICTIONS, - WORKER_PHT_CACHE_HITS, - WORKER_IP_CALLS - }; - - std::vector interp_calls; - std::vector dht_hits; - std::vector dht_evictions; - - struct worker_s - { - double phreeqc_t = 0.; - double dht_get = 0.; - double dht_fill = 0.; - double idle_t = 0.; - double ctrl_t = 0.; - }; - - struct worker_info_s - { - char has_work = 0; - double *send_addr; - double *surrogate_addr; - }; - - using worker_list_t = std::vector; - using workpointer_t = std::vector::iterator; - - void MasterRunParallel(double dt); - void MasterRunSequential(); - - void MasterSendPkgs(worker_list_t &w_list, workpointer_t &work_pointer, workpointer_t &sur_pointer, - int &pkg_to_send, int &count_pkgs, int &free_workers, - double dt, uint32_t iteration, uint32_t control_iteration, - const std::vector &wp_sizes_vector); - void MasterRecvPkgs(worker_list_t &w_list, int &pkg_to_recv, bool to_send, - int &free_workers); - - std::vector MasterGatherWorkerTimings(int type) const; - std::vector MasterGatherWorkerMetrics(int type) const; - - void WorkerProcessPkgs(struct worker_s &timings, uint32_t &iteration); - - void WorkerDoWork(MPI_Status &probe_status, int double_count, - struct worker_s &timings); - void WorkerPostIter(MPI_Status &prope_status, uint32_t iteration); - void WorkerPostSim(uint32_t iteration); - - void WorkerWriteDHTDump(uint32_t iteration); - void WorkerReadDHTDump(const std::string &dht_input_file); - - void WorkerPerfToMaster(int type, const struct worker_s &timings); - void WorkerMetricsToMaster(int type); - - void WorkerRunWorkPackage(WorkPackage &work_package, double dSimTime, - double dTimestep); - - std::vector CalculateWPSizesVector(uint32_t n_cells, - uint32_t wp_size) const; - std::vector shuffleField(const std::vector &in_field, - uint32_t size_per_prop, uint32_t prop_count, - uint32_t wp_count); - void unshuffleField(const std::vector &in_buffer, - uint32_t size_per_prop, uint32_t prop_count, - uint32_t wp_count, std::vector &out_field); - std::vector - parseDHTSpeciesVec(const NamedVector &key_species, - const std::vector &to_compare) const; - - void BCastStringVec(std::vector &io); - - int comm_size, comm_rank; - MPI_Comm group_comm; - - bool is_sequential; - bool is_master; - - uint32_t wp_size; - bool dht_enabled{false}; - int dht_snaps_type{DHT_SNAPS_DISABLED}; - std::string dht_file_out_dir; - - poet::DHT_Wrapper *dht = nullptr; - - bool dht_fill_during_rollback{false}; - bool interp_enabled{false}; - std::unique_ptr interp; - - bool ai_surrogate_enabled{false}; - - static constexpr uint32_t BUFFER_OFFSET = 6; - - inline void ChemBCast(void *buf, int count, MPI_Datatype datatype) const - { - MPI_Bcast(buf, count, datatype, 0, this->group_comm); - } - - inline void PropagateFunctionType(int &type) const - { - ChemBCast(&type, 1, MPI_INT); - } - double simtime = 0.; - double idle_t = 0.; - double seq_t = 0.; - double send_recv_t = 0.; - - double ctrl_t = 0.; - double bcast_ctrl_t = 0.; - double recv_ctrl_t = 0.; - - std::array base_totals{0}; - - bool print_progessbar{false}; - - std::uint8_t file_pad{1}; - - double chem_t{0.}; - - uint32_t n_cells = 0; - uint32_t prop_count = 0; + ChemistryModule(uint32_t wp_size, + const InitialList::ChemistryInit chem_params, + MPI_Comm communicator); + + /** + * Deconstructor, which frees DHT data structure if used. + */ + ~ChemistryModule(); + + void masterSetField(Field field); + /** + * Run the chemical simulation with parameters set. + */ + void simulate(double dt); + + /** + * Returns all known species names, including not only aqueous species, but + * also equilibrium, exchange, surface and kinetic reactants. + */ + // auto GetPropNames() const { return this->prop_names; } + + /** + * Return the accumulated runtime in seconds for chemical simulation. + */ + auto GetChemistryTime() const { return this->chem_t; } + + void setFilePadding(std::uint32_t maxiter) { + this->file_pad = + static_cast(std::ceil(std::log10(maxiter + 1))); + } + + struct SurrogateSetup { std::vector prop_names; + std::array base_totals; + bool has_het_ids; - Field chem_field; + bool dht_enabled; + std::uint32_t dht_size_mb; + int dht_snaps; + std::string dht_out_dir; - const InitialList::ChemistryInit params; - - std::unique_ptr pqc_runner; - - std::vector sur_shuffled; + bool interp_enabled; + std::uint32_t interp_bucket_size; + std::uint32_t interp_size_mb; + std::uint32_t interp_min_entries; + bool ai_surrogate_enabled; }; + + void masterEnableSurrogates(const SurrogateSetup &setup) { + // FIXME: This is a hack to get the prop_names and prop_count from the setup + this->prop_names = setup.prop_names; + this->prop_count = setup.prop_names.size(); + + this->dht_enabled = setup.dht_enabled; + this->interp_enabled = setup.interp_enabled; + this->ai_surrogate_enabled = setup.ai_surrogate_enabled; + + this->base_totals = setup.base_totals; + + if (this->dht_enabled || this->interp_enabled) { + this->initializeDHT(setup.dht_size_mb, this->params.dht_species, + setup.has_het_ids); + + if (setup.dht_snaps != DHT_SNAPS_DISABLED) { + this->setDHTSnapshots(setup.dht_snaps, setup.dht_out_dir); + } + } + + if (this->interp_enabled) { + this->initializeInterp(setup.interp_bucket_size, setup.interp_size_mb, + setup.interp_min_entries, + this->params.interp_species); + } + } + + /** + * Intended to alias input parameters for grid initialization with a single + * value per species. + */ + using SingleCMap = std::unordered_map; + + /** + * Intended to alias input parameters for grid initialization with mutlitple + * values per species. + */ + using VectorCMap = std::unordered_map>; + + /** + * Enumerating DHT file options + */ + enum { + DHT_SNAPS_DISABLED = 0, //!< disabled file output + DHT_SNAPS_SIMEND, //!< only output of snapshot after simulation + DHT_SNAPS_ITEREND //!< output snapshots after each iteration + }; + + /** + * **Only called by workers!** Start the worker listening loop. + */ + void WorkerLoop(); + + /** + * **Called by master** Advise the workers to break the loop. + */ + void MasterLoopBreak(); + + /** + * **Master only** Return count of grid cells. + */ + auto GetNCells() const { return this->n_cells; } + /** + * **Master only** Return work package size. + */ + auto GetWPSize() const { return this->wp_size; } + /** + * **Master only** Return the time in seconds the master spent waiting for any + * free worker. + */ + auto GetMasterIdleTime() const { return this->idle_t; } + /** + * **Master only** Return the time in seconds the master spent in sequential + * part of the simulation, including times for shuffling/unshuffling field + * etc. + */ + auto GetMasterSequentialTime() const { return this->seq_t; } + /** + * **Master only** Return the time in seconds the master spent in the + * send/receive loop. + */ + auto GetMasterLoopTime() const { return this->send_recv_t; } + + auto GetMasterCtrlLogicTime() const { return this->ctrl_t; } + + auto GetMasterCtrlBcastTime() const { return this->bcast_ctrl_t; } + + auto GetMasterRecvCtrlLogicTime() const { return this->recv_ctrl_t; } + + /** + * **Master only** Collect and return all accumulated timings recorded by + * workers to run Phreeqc simulation. + * + * \return Vector of all accumulated Phreeqc timings. + */ + std::vector GetWorkerPhreeqcTimings() const; + /** + * **Master only** Collect and return all accumulated timings recorded by + * workers to get values from the DHT. + * + * \return Vector of all accumulated DHT get times. + */ + std::vector GetWorkerDHTGetTimings() const; + /** + * **Master only** Collect and return all accumulated timings recorded by + * workers to write values to the DHT. + * + * \return Vector of all accumulated DHT fill times. + */ + std::vector GetWorkerDHTFillTimings() const; + /** + * **Master only** Collect and return all accumulated timings recorded by + * workers waiting for work packages from the master. + * + * \return Vector of all accumulated waiting times. + */ + std::vector GetWorkerIdleTimings() const; + + std::vector GetWorkerControlTimings() const; + + /** + * **Master only** Collect and return DHT hits of all workers. + * + * \return Vector of all count of DHT hits. + */ + std::vector GetWorkerDHTHits() const; + + /** + * **Master only** Collect and return DHT evictions of all workers. + * + * \return Vector of all count of DHT evictions. + */ + std::vector GetWorkerDHTEvictions() const; + + /** + * **Master only** Returns the current state of the chemical field. + * + * \return Reference to the chemical field. + */ + Field &getField() { return this->chem_field; } + + /** + * **Master only** Enable/disable progress bar. + * + * \param enabled True if print progressbar, false if not. + */ + void setProgressBarPrintout(bool enabled) { + this->print_progessbar = enabled; + }; + + /** + * **Master only** Set the ai surrogate validity vector from R + */ + void set_ai_surrogate_validity_vector(std::vector r_vector); + + std::vector GetWorkerInterpolationCalls() const; + + std::vector GetWorkerInterpolationWriteTimings() const; + std::vector GetWorkerInterpolationReadTimings() const; + std::vector GetWorkerInterpolationGatherTimings() const; + std::vector GetWorkerInterpolationFunctionCallTimings() const; + + std::vector GetWorkerPHTCacheHits() const; + + std::vector ai_surrogate_validity_vector; + +protected: + void initializeDHT(uint32_t size_mb, + const NamedVector &key_species, + bool has_het_ids); + void setDHTSnapshots(int type, const std::string &out_dir); + void setDHTReadFile(const std::string &input_file); + + void initializeInterp(std::uint32_t bucket_size, std::uint32_t size_mb, + std::uint32_t min_entries, + const NamedVector &key_species); + + enum { + CHEM_FIELD_INIT, + CHEM_DHT_ENABLE, + CHEM_DHT_SIGNIF_VEC, + CHEM_DHT_SNAPS, + CHEM_DHT_READ_FILE, + CHEM_IP, // Control Flag + CHEM_IP_ENABLE, + CHEM_IP_MIN_ENTRIES, + CHEM_IP_SIGNIF_VEC, + CHEM_WORK_LOOP, + CHEM_PERF, + CHEM_BREAK_MAIN_LOOP, + CHEM_AI_BCAST_VALIDITY + }; + + enum { LOOP_WORK, LOOP_END, LOOP_CTRL }; + + enum { + WORKER_PHREEQC, + WORKER_CTRL_ITER, + WORKER_DHT_GET, + WORKER_DHT_FILL, + WORKER_IDLE, + WORKER_IP_WRITE, + WORKER_IP_READ, + WORKER_IP_GATHER, + WORKER_IP_FC, + WORKER_DHT_HITS, + WORKER_DHT_EVICTIONS, + WORKER_PHT_CACHE_HITS, + WORKER_IP_CALLS + }; + + std::vector interp_calls; + std::vector dht_hits; + std::vector dht_evictions; + + struct worker_s { + double phreeqc_t = 0.; + double dht_get = 0.; + double dht_fill = 0.; + double idle_t = 0.; + double ctrl_t = 0.; + }; + + struct worker_info_s { + char has_work = 0; + double *send_addr; + double *surrogate_addr; + }; + + using worker_list_t = std::vector; + using workpointer_t = std::vector::iterator; + + void MasterRunParallel(double dt); + void MasterRunSequential(); + + void MasterSendPkgs(worker_list_t &w_list, workpointer_t &work_pointer, + workpointer_t &sur_pointer, int &pkg_to_send, + int &count_pkgs, int &free_workers, double dt, + uint32_t iteration, uint32_t control_iteration, + const std::vector &wp_sizes_vector); + void MasterRecvPkgs(worker_list_t &w_list, int &pkg_to_recv, bool to_send, + int &free_workers); + + std::vector MasterGatherWorkerTimings(int type) const; + std::vector MasterGatherWorkerMetrics(int type) const; + + void WorkerProcessPkgs(struct worker_s &timings, uint32_t &iteration); + + void WorkerDoWork(MPI_Status &probe_status, int double_count, + struct worker_s &timings); + void WorkerPostIter(MPI_Status &prope_status, uint32_t iteration); + void WorkerPostSim(uint32_t iteration); + + void WorkerWriteDHTDump(uint32_t iteration); + void WorkerReadDHTDump(const std::string &dht_input_file); + + void WorkerPerfToMaster(int type, const struct worker_s &timings); + void WorkerMetricsToMaster(int type); + + void WorkerRunWorkPackage(WorkPackage &work_package, double dSimTime, + double dTimestep); + + std::vector CalculateWPSizesVector(uint32_t n_cells, + uint32_t wp_size) const; + std::vector shuffleField(const std::vector &in_field, + uint32_t size_per_prop, uint32_t prop_count, + uint32_t wp_count); + void unshuffleField(const std::vector &in_buffer, + uint32_t size_per_prop, uint32_t prop_count, + uint32_t wp_count, std::vector &out_field); + std::vector + parseDHTSpeciesVec(const NamedVector &key_species, + const std::vector &to_compare) const; + + void BCastStringVec(std::vector &io); + + int comm_size, comm_rank; + MPI_Comm group_comm; + + bool is_sequential; + bool is_master; + + uint32_t wp_size; + bool dht_enabled{false}; + int dht_snaps_type{DHT_SNAPS_DISABLED}; + std::string dht_file_out_dir; + + poet::DHT_Wrapper *dht = nullptr; + + bool dht_fill_during_rollback{false}; + bool interp_enabled{false}; + std::unique_ptr interp; + + bool ai_surrogate_enabled{false}; + + static constexpr uint32_t BUFFER_OFFSET = 6; + + inline void ChemBCast(void *buf, int count, MPI_Datatype datatype) const { + MPI_Bcast(buf, count, datatype, 0, this->group_comm); + } + + inline void PropagateFunctionType(int &type) const { + ChemBCast(&type, 1, MPI_INT); + } + double simtime = 0.; + double idle_t = 0.; + double seq_t = 0.; + double send_recv_t = 0.; + + double ctrl_t = 0.; + double bcast_ctrl_t = 0.; + double recv_ctrl_t = 0.; + + std::array base_totals{0}; + + bool print_progessbar{false}; + + std::uint8_t file_pad{1}; + + double chem_t{0.}; + + uint32_t n_cells = 0; + uint32_t prop_count = 0; + std::vector prop_names; + + Field chem_field; + + const InitialList::ChemistryInit params; + + std::unique_ptr pqc_runner; + + std::unique_ptr ctrl_module; + + //std::vector sur_shuffled; +}; } // namespace poet #endif // CHEMISTRYMODULE_H_ diff --git a/src/Chemistry/MasterFunctions.cpp b/src/Chemistry/MasterFunctions.cpp index c2710bf8b..4c75fb3cd 100644 --- a/src/Chemistry/MasterFunctions.cpp +++ b/src/Chemistry/MasterFunctions.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include #include @@ -166,39 +165,6 @@ std::vector poet::ChemistryModule::GetWorkerPHTCacheHits() const { return ret; } -void poet::ChemistryModule::computeSpeciesErrors(const std::vector &reference_values, - const std::vector &surrogate_values, - uint32_t size_per_prop, - uint32_t species_count, - SimulationErrorStats &species_error_stats) { - for (uint32_t i = 0; i < species_count; ++i) { - double err_sum = 0.0; - double sqr_err_sum = 0.0; - uint32_t base_idx = i * size_per_prop; - - for (uint32_t j = 0; j < size_per_prop; ++j) { - const double ref_value = reference_values[base_idx + j]; - const double sur_value = surrogate_values[base_idx + j]; - - if (ref_value == 0.0) { - if (sur_value != 0.0) { - err_sum += 1.0; - sqr_err_sum += 1.0; - } - // Both zero: skip - } else { - double alpha = 1.0 - (sur_value / ref_value); - err_sum += std::abs(alpha); - sqr_err_sum += alpha * alpha; - } - } - - species_error_stats.mape[i] = 100.0 * (err_sum / size_per_prop); - species_error_stats.rrmse[i] = - (size_per_prop > 0) ? std::sqrt(sqr_err_sum / size_per_prop) : 0.0; - } -} - inline std::vector shuffleVector(const std::vector &in_vector, uint32_t size_per_prop, uint32_t wp_count) { @@ -269,8 +235,8 @@ inline void printProgressbar(int count_pkgs, int n_wp, int barWidth = 70) { inline void poet::ChemistryModule::MasterSendPkgs( worker_list_t &w_list, workpointer_t &work_pointer, workpointer_t &sur_pointer, int &pkg_to_send, int &count_pkgs, - int &free_workers, double dt, uint32_t iteration, - uint32_t control_interval, const std::vector &wp_sizes_vector) { + int &free_workers, double dt, uint32_t iteration, uint32_t control_interval, + const std::vector &wp_sizes_vector) { /* declare variables */ int local_work_package_size; @@ -335,7 +301,7 @@ inline void poet::ChemistryModule::MasterRecvPkgs(worker_list_t &w_list, int need_to_receive = 1; double idle_a, idle_b; int p, size; - double recv_a, recv_b; + double recv_a, recv_b; MPI_Status probe_status; // master_recv_a = MPI_Wtime(); @@ -461,28 +427,9 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { /* start time measurement of broadcasting interpolation status */ ctrl_bcast_a = MPI_Wtime(); - - ftype = CHEM_INTERP; + ftype = CHEM_IP; PropagateFunctionType(ftype); - - int interp_flag = 0; - int dht_fill_flag = 0; - - if(this->runtime_params->rollback_enabled){ - this->interp_enabled = false; - this->dht_fill_during_rollback = true; - interp_flag = 0; - dht_fill_flag = 1; - } - else { - this->interp_enabled = true; - this->dht_fill_during_rollback = false; - interp_flag = 1; - dht_fill_flag = 0; - } - ChemBCast(&interp_flag, 1, MPI_INT); - ChemBCast(&dht_fill_flag, 1, MPI_INT); - + ctrl_module->BCastControlFlags(); /* end time measurement of broadcasting interpolation status */ ctrl_bcast_b = MPI_Wtime(); this->bcast_ctrl_t += ctrl_bcast_b - ctrl_bcast_a; @@ -494,11 +441,12 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { static uint32_t iteration = 0; - uint32_t control_logic_enabled = this->runtime_params->control_interval_enabled ? 1 : 0; + uint32_t control_logic_enabled = + ctrl_module->control_interval_enabled ? 1 : 0; if (control_logic_enabled) { - sur_shuffled.clear(); - sur_shuffled.reserve(this->n_cells * this->prop_count); + ctrl_module->sur_shuffled.clear(); + ctrl_module->sur_shuffled.reserve(this->n_cells * this->prop_count); } /* start time measurement of sequential part */ @@ -511,14 +459,14 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count, wp_sizes_vector.size()); - this->sur_shuffled.resize(mpi_buffer.size()); + ctrl_module->sur_shuffled.resize(mpi_buffer.size()); /* setup local variables */ pkg_to_send = wp_sizes_vector.size(); pkg_to_recv = wp_sizes_vector.size(); workpointer_t work_pointer = mpi_buffer.begin(); - workpointer_t sur_pointer = sur_shuffled.begin(); + workpointer_t sur_pointer = ctrl_module->sur_shuffled.begin(); worker_list_t worker_list(this->comm_size - 1); free_workers = this->comm_size - 1; @@ -552,43 +500,37 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { // Just to complete the progressbar std::cout << std::endl; - /* stop time measurement of chemistry time needed for send/recv loop */ - worker_chemistry_b = MPI_Wtime(); - this->send_recv_t += worker_chemistry_b - worker_chemistry_a; + /* stop time measurement of chemistry time needed for send/recv loop */ + worker_chemistry_b = MPI_Wtime(); + this->send_recv_t += worker_chemistry_b - worker_chemistry_a; - /* start time measurement of sequential part */ - seq_c = MPI_Wtime(); + /* start time measurement of sequential part */ + seq_c = MPI_Wtime(); - /* unshuffle grid */ - // grid.importAndUnshuffle(mpi_buffer); - std::vector out_vec{mpi_buffer}; - unshuffleField(mpi_buffer, this->n_cells, this->prop_count, - wp_sizes_vector.size(), out_vec); - chem_field = out_vec; + /* unshuffle grid */ + // grid.importAndUnshuffle(mpi_buffer); + std::vector out_vec{mpi_buffer}; + unshuffleField(mpi_buffer, this->n_cells, this->prop_count, + wp_sizes_vector.size(), out_vec); + chem_field = out_vec; - /* do master stuff */ + /* do master stuff */ - /* start time measurement of control logic */ - ctrl_a = MPI_Wtime(); + /* start time measurement of control logic */ + ctrl_a = MPI_Wtime(); - if (control_logic_enabled && !this->runtime_params->rollback_enabled) { + if (control_logic_enabled && !ctrl_module->rollback_enabled) { + std::cout << "[Master] Control logic enabled for this iteration." << std::endl; + std::vector sur_unshuffled{ctrl_module->sur_shuffled}; + unshuffleField(ctrl_module->sur_shuffled, this->n_cells, this->prop_count, + wp_sizes_vector.size(), sur_unshuffled); - std::vector sur_unshuffled{sur_shuffled};; - - unshuffleField(sur_shuffled, this->n_cells, this->prop_count, - wp_sizes_vector.size(), sur_unshuffled); - - SimulationErrorStats stats(this->prop_count, this->runtime_params->global_iter, this->runtime_params->rollback_counter); - - computeSpeciesErrors(out_vec, sur_unshuffled, this->n_cells, this->prop_count, stats); - - error_history.push_back(stats); - } - - /* end time measurement of control logic */ - ctrl_b = MPI_Wtime(); - this->ctrl_t += ctrl_b - ctrl_a; + ctrl_module->computeSpeciesErrors(out_vec, sur_unshuffled, this->n_cells); + } + /* end time measurement of control logic */ + ctrl_b = MPI_Wtime(); + this->ctrl_t += ctrl_b - ctrl_a; /* start time measurement of master chemistry */ sim_e_chemistry = MPI_Wtime(); diff --git a/src/Chemistry/WorkerFunctions.cpp b/src/Chemistry/WorkerFunctions.cpp index 4406ec65d..8cf15fe92 100644 --- a/src/Chemistry/WorkerFunctions.cpp +++ b/src/Chemistry/WorkerFunctions.cpp @@ -67,7 +67,7 @@ namespace poet MPI_INT, 0, this->group_comm); break; } - case CHEM_INTERP: + case CHEM_IP: { int interp_flag = 0; int dht_fill_flag = 0; diff --git a/src/Control/ControlModule.cpp b/src/Control/ControlModule.cpp new file mode 100644 index 000000000..a5a71d577 --- /dev/null +++ b/src/Control/ControlModule.cpp @@ -0,0 +1,131 @@ +#include "ControlModule.hpp" +#include "IO/Datatypes.hpp" +#include "IO/HDF5Functions.hpp" +#include "IO/StatsIO.hpp" +#include + +bool poet::ControlModule::isControlIteration(uint32_t iter) { + control_interval_enabled = (iter % control_interval == 0); + if (control_interval_enabled) { + MSG("[Control] Control interval triggered at iteration " + + std::to_string(iter)); + } + return control_interval_enabled; +} + +void poet::ControlModule::beginIteration() { + if (rollback_enabled) { + if (sur_disabled_counter > 0) { + sur_disabled_counter--; + MSG("Rollback counter: " + std::to_string(sur_disabled_counter)); + } else { + rollback_enabled = false; + } + } +} + +void poet::ControlModule::endIteration(uint32_t iter) { + /* Writing a checkpointing */ + if (checkpoint_interval > 0 && iter % checkpoint_interval == 0) { + MSG("Writing checkpoint of iteration " + std::to_string(iter)); + write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5", + {.field = chem->getField(), .iteration = iter}); + } + + /* Control Logic*/ + if (control_interval_enabled && !rollback_enabled) { + writeStatsToCSV(error_history, species_names, out_dir, + "stats_overview"); + + if (triggerRollbackIfExceeded(*chem, *params, iter)) { + rollback_enabled = true; + rollback_counter++; + sur_disabled_counter = control_interval; + MSG("Interpolation disabled for the next " + + std::to_string(control_interval) + "."); + } + } +} + +void poet::ControlModule::BCastControlFlags() { + int interp_flag = rollback_enabled ? 0 : 1; + int dht_fill_flag = rollback_enabled ? 1 : 0; + chem->ChemBCast(&interp_flag, 1, MPI_INT); + chem->ChemBCast(&dht_fill_flag, 1, MPI_INT); +} + +bool poet::ControlModule::triggerRollbackIfExceeded(ChemistryModule &chem, + RuntimeParameters ¶ms, + uint32_t &iter) { + + if (error_history.empty()) { + MSG("No error history yet; skipping rollback check."); + return false; + } + + const auto &mape = chem.error_history.back().mape; + const auto &props = chem.getField().GetProps(); + + for (uint32_t i = 0; i < params.mape_threshold.size(); ++i) { + // Skip invalid entries + if (mape[i] == 0) { + continue; + } + bool mape_exceeded = mape[i] > params.mape_threshold[i]; + + if (mape_exceeded) { + uint32_t rollback_iter = ((iter - 1) / params.checkpoint_interval) * + params.checkpoint_interval; + + MSG("[THRESHOLD EXCEEDED] " + props[i] + + " has MAPE = " + std::to_string(mape[i]) + + " exceeding threshold = " + std::to_string(params.mape_threshold[i]) + + " → rolling back to iteration " + std::to_string(rollback_iter)); + + Checkpoint_s checkpoint_read{.field = chem.getField()}; + read_checkpoint(params.out_dir, + "checkpoint" + std::to_string(rollback_iter) + ".hdf5", + checkpoint_read); + iter = checkpoint_read.iteration; + return true; + } + } + MSG("All species are within their MAPE and RRMSE thresholds."); + return false; +} + +void poet::ControlModule::computeSpeciesErrors( + const std::vector &reference_values, + const std::vector &surrogate_values, uint32_t size_per_prop) { + + SimulationErrorStats species_error_stats(species_count, params->global_iter, + rollback_counter); + + for (uint32_t i = 0; i < species_count; ++i) { + double err_sum = 0.0; + double sqr_err_sum = 0.0; + uint32_t base_idx = i * size_per_prop; + + for (uint32_t j = 0; j < size_per_prop; ++j) { + const double ref_value = reference_values[base_idx + j]; + const double sur_value = surrogate_values[base_idx + j]; + + if (ref_value == 0.0) { + if (sur_value != 0.0) { + err_sum += 1.0; + sqr_err_sum += 1.0; + } + // Both zero: skip + } else { + double alpha = 1.0 - (sur_value / ref_value); + err_sum += std::abs(alpha); + sqr_err_sum += alpha * alpha; + } + } + + species_error_stats.mape[i] = 100.0 * (err_sum / size_per_prop); + species_error_stats.rrmse[i] = + (size_per_prop > 0) ? std::sqrt(sqr_err_sum / size_per_prop) : 0.0; + } + error_history.push_back(species_error_stats); +} diff --git a/src/Control/ControlModule.hpp b/src/Control/ControlModule.hpp new file mode 100644 index 000000000..6bd848d07 --- /dev/null +++ b/src/Control/ControlModule.hpp @@ -0,0 +1,110 @@ +#ifndef CONTROLMODULE_H_ +#define CONTROLMODULE_H_ + +#include "Base/Macros.hpp" +#include "Chemistry/ChemistryModule.hpp" +#include "poet.hpp" + +#include +#include +#include + +namespace poet { + +class ChemistryModule; + +class ControlModule { + +public: + ControlModule(RuntimeParameters *run_params, ChemistryModule *chem_module) + : params(run_params), chem(chem_module) {}; + + /* Control configuration*/ + std::vector species_names; + uint32_t species_count = 0; + std::string out_dir; + + bool rollback_enabled = false; + bool control_interval_enabled = false; + + std::uint32_t global_iter = 0; + std::uint32_t sur_disabled_counter = 0; + std::uint32_t rollback_counter = 0; + std::uint32_t checkpoint_interval = 0; + std::uint32_t control_interval = 0; + + std::vector mape_threshold; + std::vector rrmse_threshold; + + double ctrl_t = 0.; + double bcast_ctrl_t = 0.; + double recv_ctrl_t = 0.; + + /* Buffer for shuffled surrogate data */ + std::vector sur_shuffled; + + bool isControlIteration(uint32_t iter); + + void beginIteration(); + + void endIteration(uint32_t iter); + + void BCastControlFlags(); + + bool triggerRollbackIfExceeded(ChemistryModule &chem, + RuntimeParameters ¶ms, uint32_t &iter); + + struct SimulationErrorStats { + std::vector mape; + std::vector rrmse; + uint32_t iteration; // iterations in simulation after rollbacks + uint32_t rollback_count; + + SimulationErrorStats(size_t species_count, uint32_t iter, uint32_t counter) + : mape(species_count, 0.0), rrmse(species_count, 0.0), iteration(iter), + rollback_count(counter) {} + }; + + static void computeSpeciesErrors(const std::vector &reference_values, + const std::vector &surrogate_values, + uint32_t size_per_prop); + + std::vector error_history; + + struct ControlSetup { + std::string out_dir; + std::uint32_t checkpoint_interval; + std::uint32_t control_interval; + std::uint32_t species_count; + + std::vector species_names; + std::vector mape_threshold; + std::vector rrmse_threshold; + }; + + void enableControlLogic(const ControlSetup &setup) { + out_dir = setup.out_dir; + checkpoint_interval = setup.checkpoint_interval; + control_interval = setup.control_interval; + species_count = setup.species_count; + + species_names = setup.species_names; + mape_threshold = setup.mape_threshold; + rrmse_threshold = setup.rrmse_threshold; + } + + /* Profiling getters */ + auto GetMasterCtrlLogicTime() const { return this->ctrl_t; } + + auto GetMasterCtrlBcastTime() const { return this->bcast_ctrl_t; } + + auto GetMasterRecvCtrlLogicTime() const { return this->recv_ctrl_t; } + +private: + RuntimeParameters *params; + ChemistryModule *chem; +}; + +} // namespace poet + +#endif // CONTROLMODULE_H_ \ No newline at end of file diff --git a/src/IO/StatsIO.cpp b/src/IO/StatsIO.cpp index 8e3c2978c..0b82a191a 100644 --- a/src/IO/StatsIO.cpp +++ b/src/IO/StatsIO.cpp @@ -7,7 +7,7 @@ namespace poet { - void writeStatsToCSV(const std::vector &all_stats, + void writeStatsToCSV(const std::vector &all_stats, const std::vector &species_names, const std::string &out_dir, const std::string &filename) diff --git a/src/IO/StatsIO.hpp b/src/IO/StatsIO.hpp index cb432f939..e208d4bbb 100644 --- a/src/IO/StatsIO.hpp +++ b/src/IO/StatsIO.hpp @@ -1,9 +1,9 @@ #include -#include "Chemistry/ChemistryModule.hpp" +#include "Control/ControlModule.hpp" namespace poet { - void writeStatsToCSV(const std::vector &all_stats, + void writeStatsToCSV(const std::vector &all_stats, const std::vector &species_names, const std::string &out_dir, const std::string &filename); diff --git a/src/poet.cpp b/src/poet.cpp index 4b920aa02..48260f3c7 100644 --- a/src/poet.cpp +++ b/src/poet.cpp @@ -25,10 +25,8 @@ #include "Base/RInsidePOET.hpp" #include "CLI/CLI.hpp" #include "Chemistry/ChemistryModule.hpp" +#include "Control/ControlManager.hpp" #include "DataStructures/Field.hpp" -#include "IO/Datatypes.hpp" -#include "IO/HDF5Functions.hpp" -#include "IO/StatsIO.hpp" #include "Init/InitialList.hpp" #include "Transport/DiffusionModule.hpp" @@ -68,8 +66,7 @@ static poet::DEFunc ReadRObj_R; static poet::DEFunc SaveRObj_R; static poet::DEFunc source_R; -static void init_global_functions(RInside &R) -{ +static void init_global_functions(RInside &R) { R.parseEval(kin_r_library); master_init_R = DEFunc("master_init"); master_iteration_end_R = DEFunc("master_iteration_end"); @@ -92,15 +89,9 @@ static void init_global_functions(RInside &R) // R.parseEval("mysetup$state_C <- TMP"); // } -enum ParseRet -{ - PARSER_OK, - PARSER_ERROR, - PARSER_HELP -}; +enum ParseRet { PARSER_OK, PARSER_ERROR, PARSER_HELP }; -int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) -{ +int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) { CLI::App app{"POET - Potsdam rEactive Transport simulator"}; @@ -182,12 +173,9 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) "Output directory of the simulation") ->required(); - try - { + try { app.parse(argc, argv); - } - catch (const CLI::ParseError &e) - { + } catch (const CLI::ParseError &e) { app.exit(e); return -1; } @@ -199,16 +187,14 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) if (params.as_qs) params.out_ext = "qs"; - if (MY_RANK == 0) - { + if (MY_RANK == 0) { // MSG("Complete results storage is " + BOOL_PRINT(simparams.store_result)); MSG("Output format/extension is " + params.out_ext); MSG("Work Package Size: " + std::to_string(params.work_package_size)); MSG("DHT is " + BOOL_PRINT(params.use_dht)); MSG("AI Surrogate is " + BOOL_PRINT(params.use_ai_surrogate)); - if (params.use_dht) - { + if (params.use_dht) { // MSG("DHT strategy is " + std::to_string(simparams.dht_strategy)); // MDL: these should be outdated (?) // MSG("DHT key default digits (ignored if 'signif_vector' is " @@ -222,8 +208,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) // MSG("DHT load file is " + chem_params.dht_file); } - if (params.use_interp) - { + if (params.use_interp) { MSG("PHT interpolation enabled: " + BOOL_PRINT(params.use_interp)); MSG("PHT interp-size = " + std::to_string(params.interp_size)); MSG("PHT interp-min = " + std::to_string(params.interp_min_entries)); @@ -251,8 +236,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) // // log before rounding? // R["dht_log"] = simparams.dht_log; - try - { + try { Rcpp::List init_params_(ReadRObj_R(init_file)); params.init_params = init_params_; @@ -269,13 +253,11 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) Rcpp::as(global_rt_setup->operator[]("control_interval")); params.checkpoint_interval = Rcpp::as(global_rt_setup->operator[]("checkpoint_interval")); - params.mape_threshold = - Rcpp::as>(global_rt_setup->operator[]("mape_threshold")); - params.rrmse_threshold = - Rcpp::as>(global_rt_setup->operator[]("rrmse_threshold")); - } - catch (const std::exception &e) - { + params.mape_threshold = Rcpp::as>( + global_rt_setup->operator[]("mape_threshold")); + params.rrmse_threshold = Rcpp::as>( + global_rt_setup->operator[]("rrmse_threshold")); + } catch (const std::exception &e) { ERRMSG("Error while parsing R scripts: " + std::string(e.what())); return ParseRet::PARSER_ERROR; } @@ -285,8 +267,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) // HACK: this is a step back as the order and also the count of fields is // predefined, but it will change in the future -void call_master_iter_end(RInside &R, const Field &trans, const Field &chem) -{ +void call_master_iter_end(RInside &R, const Field &trans, const Field &chem) { R["TMP"] = Rcpp::wrap(trans.AsVector()); R["TMP_PROPS"] = Rcpp::wrap(trans.GetProps()); R.parseEval(std::string("state_T <- setNames(data.frame(matrix(TMP, nrow=" + @@ -303,53 +284,15 @@ void call_master_iter_end(RInside &R, const Field &trans, const Field &chem) *global_rt_setup = R["setup"]; } -bool triggerRollbackIfExceeded(ChemistryModule &chem, RuntimeParameters ¶ms, uint32_t ¤t_iteration) -{ - const auto &mape = chem.error_history.back().mape; - const auto &rrmse = chem.error_history.back().rrmse; - const auto &props = chem.getField().GetProps(); - - for (uint32_t i = 0; i < params.mape_threshold.size(); ++i) - { - // Skip invalid entries - if ((mape[i] == 0 && rrmse[i] == 0)) - continue; - - bool mape_exceeded = mape[i] > params.mape_threshold[i]; - bool rrmse_exceeded = rrmse[i] > params.rrmse_threshold[i]; - - if (mape_exceeded || rrmse_exceeded) - { - uint32_t rollback_iter = ((current_iteration - 1) / params.checkpoint_interval) * params.checkpoint_interval; - std::string metric = mape_exceeded ? "MAPE" : "RRMSE"; - double value = mape_exceeded ? mape[i] : rrmse[i]; - double threshold = mape_exceeded ? params.mape_threshold[i] : params.rrmse_threshold[i]; - - MSG("[THRESHOLD EXCEEDED] " + props[i] + " has " + metric + " = " + - std::to_string(value) + " exceeding threshold = " + std::to_string(threshold) + - " → rolling back to iteration " + std::to_string(rollback_iter)); - - Checkpoint_s checkpoint_read{.field = chem.getField()}; - read_checkpoint(params.out_dir, "checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read); - current_iteration = checkpoint_read.iteration; - return true; // rollback happened - } - } - MSG("All species are within their MAPE and RRMSE thresholds."); - return false; -} - static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, DiffusionModule &diffusion, - ChemistryModule &chem) -{ + ChemistryModule &chem, ControlModule &control) { /* Iteration Count is dynamic, retrieving value from R (is only needed by * master for the following loop) */ uint32_t maxiter = params.timesteps.size(); - if (params.print_progress) - { + if (params.print_progress) { chem.setProgressBarPrintout(true); } R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps()); @@ -359,20 +302,24 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, double dSimTime{0}; double chkTime = 0.0; - for (uint32_t iter = 1; iter < maxiter + 1; iter++) - { + for (uint32_t iter = 1; iter < maxiter + 1; iter++) { // Rollback countdowm - if (params.rollback_enabled) { - if (params.sur_disabled_counter > 0) { + + /* + if (params.rollback_enabled) { + if (params.sur_disabled_counter > 0) { --params.sur_disabled_counter; MSG("Rollback counter: " + std::to_string(params.sur_disabled_counter)); - } else { + } else { params.rollback_enabled = false; } } + */ + control.beginIteration(iter); - params.global_iter = iter; - params.control_interval_enabled = (iter % params.control_interval == 0); + // params.global_iter = iter; + control.isControlIteration(iter); + // params.control_interval_enabled = (iter % params.control_interval == 0); double start_t = MPI_Wtime(); @@ -389,13 +336,12 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, /* run transport */ diffusion.simulate(dt); - chem.runtime_params = ¶ms; + // chem.runtime_params = ¶ms; chem.getField().update(diffusion.getField()); // MSG("Chemistry start"); - if (params.use_ai_surrogate) - { + if (params.use_ai_surrogate) { double ai_start_t = MPI_Wtime(); // Save current values from the tug field as predictor for the ai step R["TMP"] = Rcpp::wrap(chem.getField().AsVector()); @@ -446,8 +392,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, chem.simulate(dt); /* AI surrogate iterative training*/ - if (params.use_ai_surrogate) - { + if (params.use_ai_surrogate) { double ai_start_t = MPI_Wtime(); R["TMP"] = Rcpp::wrap(chem.getField().AsVector()); @@ -487,25 +432,32 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, std::to_string(maxiter)); double chk_start = MPI_Wtime(); + control.endIteration(iter) + /* + if (iter % params.checkpoint_interval == 0) { + MSG("Writing checkpoint of iteration " + std::to_string(iter)); + write_checkpoint(params.out_dir, + "checkpoint" + std::to_string(iter) + ".hdf5", + {.field = chem.getField(), .iteration = iter}); + } - if(iter % params.checkpoint_interval == 0){ - MSG("Writing checkpoint of iteration " + std::to_string(iter)); - write_checkpoint(params.out_dir, "checkpoint" + std::to_string(iter) + ".hdf5", - {.field = chem.getField(), .iteration = iter}); - } - if (params.control_interval_enabled && !params.rollback_enabled) - { - writeStatsToCSV(chem.error_history, chem.getField().GetProps(), params.out_dir,"stats_overview"); + if (params.control_interval_enabled && !params.rollback_enabled) { + writeStatsToCSV(chem.error_history, chem.getField().GetProps(), + params.out_dir, "stats_overview"); - if(triggerRollbackIfExceeded(chem, params, iter)){ - params.rollback_enabled = true; - params.rollback_counter ++; - params.sur_disabled_counter = params.control_interval; - MSG("Interpolation disabled for the next " + std::to_string(params.control_interval) + "."); - } - } - double chk_end = MPI_Wtime(); + if (triggerRollbackIfExceeded(chem, params, iter)) { + params.rollback_enabled = true; + params.rollback_counter++; + params.sur_disabled_counter = params.control_interval; + MSG("Interpolation disabled for the next " + + std::to_string(params.control_interval) + "."); + } + } + + */ + + double chk_end = MPI_Wtime(); chkTime += chk_end - chk_start; // MSG(); @@ -529,10 +481,10 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, ctrl_profiling["ctrl_logic_master"] = chem.GetMasterCtrlLogicTime(); ctrl_profiling["bcast_ctrl_logic_master"] = chem.GetMasterCtrlBcastTime(); ctrl_profiling["recv_ctrl_logic_maser"] = chem.GetMasterRecvCtrlLogicTime(); - ctrl_profiling["ctrl_logic_worker"] = Rcpp::wrap(chem.GetWorkerControlTimings()); + ctrl_profiling["ctrl_logic_worker"] = + Rcpp::wrap(chem.GetWorkerControlTimings()); - if (params.use_dht) - { + if (params.use_dht) { chem_profiling["dht_hits"] = Rcpp::wrap(chem.GetWorkerDHTHits()); chem_profiling["dht_evictions"] = Rcpp::wrap(chem.GetWorkerDHTEvictions()); chem_profiling["dht_get_time"] = Rcpp::wrap(chem.GetWorkerDHTGetTimings()); @@ -540,8 +492,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, Rcpp::wrap(chem.GetWorkerDHTFillTimings()); } - if (params.use_interp) - { + if (params.use_interp) { chem_profiling["interp_w"] = Rcpp::wrap(chem.GetWorkerInterpolationWriteTimings()); chem_profiling["interp_r"] = @@ -561,15 +512,13 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, profiling["diffusion"] = diffusion_profiling; profiling["ctrl_logic"] = ctrl_profiling; - chem.MasterLoopBreak(); return profiling; } std::vector getSpeciesNames(const Field &&field, int root, - MPI_Comm comm) -{ + MPI_Comm comm) { std::uint32_t n_elements; std::uint32_t n_string_size; @@ -579,13 +528,11 @@ std::vector getSpeciesNames(const Field &&field, int root, const bool is_master = root == rank; // first, the master sends all the species names iterative - if (is_master) - { + if (is_master) { n_elements = field.GetProps().size(); MPI_Bcast(&n_elements, 1, MPI_UINT32_T, root, MPI_COMM_WORLD); - for (std::uint32_t i = 0; i < n_elements; i++) - { + for (std::uint32_t i = 0; i < n_elements; i++) { n_string_size = field.GetProps()[i].size(); MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD); MPI_Bcast(const_cast(field.GetProps()[i].c_str()), n_string_size, @@ -600,8 +547,7 @@ std::vector getSpeciesNames(const Field &&field, int root, std::vector species_names_out(n_elements); - for (std::uint32_t i = 0; i < n_elements; i++) - { + for (std::uint32_t i = 0; i < n_elements; i++) { MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD); char recv_buf[n_string_size]; @@ -614,8 +560,7 @@ std::vector getSpeciesNames(const Field &&field, int root, return species_names_out; } -std::array getBaseTotals(Field &&field, int root, MPI_Comm comm) -{ +std::array getBaseTotals(Field &&field, int root, MPI_Comm comm) { std::array base_totals; int rank; @@ -623,8 +568,7 @@ std::array getBaseTotals(Field &&field, int root, MPI_Comm comm) const bool is_master = root == rank; - if (is_master) - { + if (is_master) { const auto h_col = field["H"]; const auto o_col = field["O"]; @@ -639,8 +583,7 @@ std::array getBaseTotals(Field &&field, int root, MPI_Comm comm) return base_totals; } -bool getHasID(Field &&field, int root, MPI_Comm comm) -{ +bool getHasID(Field &&field, int root, MPI_Comm comm) { bool has_id; int rank; @@ -648,8 +591,7 @@ bool getHasID(Field &&field, int root, MPI_Comm comm) const bool is_master = root == rank; - if (is_master) - { + if (is_master) { const auto ID_field = field["ID"]; std::set unique_IDs(ID_field.begin(), ID_field.end()); @@ -666,8 +608,7 @@ bool getHasID(Field &&field, int root, MPI_Comm comm) return has_id; } -int main(int argc, char *argv[]) -{ +int main(int argc, char *argv[]) { int world_size; MPI_Init(&argc, &argv); @@ -678,8 +619,7 @@ int main(int argc, char *argv[]) RInsidePOET &R = RInsidePOET::getInstance(); - if (MY_RANK == 0) - { + if (MY_RANK == 0) { MSG("Running POET version " + std::string(poet_version)); } @@ -687,8 +627,7 @@ int main(int argc, char *argv[]) RuntimeParameters run_params; - if (parseInitValues(argc, argv, run_params) != 0) - { + if (parseInitValues(argc, argv, run_params) != 0) { MPI_Finalize(); return 0; } @@ -713,6 +652,7 @@ int main(int argc, char *argv[]) ChemistryModule chemistry(run_params.work_package_size, init_list.getChemistryInit(), MPI_COMM_WORLD); + ControlModule control(&run_params, &chemistry); const ChemistryModule::SurrogateSetup surr_setup = { getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), @@ -730,12 +670,21 @@ int main(int argc, char *argv[]) chemistry.masterEnableSurrogates(surr_setup); - if (MY_RANK > 0) - { + const ControlModule::ControlSetup ctrl_setup = { + run_params.out_dir, // added + run_params.checkpoint_interval, + run_params.control_interval, + run_params.species_count, + run_params.species_names, + run_params.mape_threshold, + run_params.rrmse_threshold}; + + control.enableControlLogic(ctrl_setup); + + + if (MY_RANK > 0) { chemistry.WorkerLoop(); - } - else - { + } else { // R.parseEvalQ("mysetup <- setup"); // // if (MY_RANK == 0) { // get timestep vector from // // grid_init function ... // @@ -749,8 +698,7 @@ int main(int argc, char *argv[]) R["out_ext"] = run_params.out_ext; R["out_dir"] = run_params.out_dir; - if (run_params.use_ai_surrogate) - { + if (run_params.use_ai_surrogate) { /* Incorporate ai surrogate from R */ R.parseEvalQ(ai_surrogate_r_library); /* Use dht species for model input and output */ @@ -799,8 +747,7 @@ int main(int argc, char *argv[]) MPI_Finalize(); - if (MY_RANK == 0) - { + if (MY_RANK == 0) { MSG("done, bye!"); } diff --git a/src/poet.hpp.in b/src/poet.hpp.in index 6f9f0fabf..aea51966e 100644 --- a/src/poet.hpp.in +++ b/src/poet.hpp.in @@ -38,8 +38,7 @@ static const inline std::string ai_surrogate_r_library = R"(@R_AI_SURROGATE_LIB@)"; static const inline std::string r_runtime_parameters = "mysetup"; -struct RuntimeParameters -{ +struct RuntimeParameters { std::string out_dir; std::vector timesteps;