From 8c388e3c6ce86bd2775d2420104c981fc2b8e1d0 Mon Sep 17 00:00:00 2001 From: rastogi Date: Thu, 30 Oct 2025 16:15:27 +0100 Subject: [PATCH] Restructuring prototyp 3 --- src/Chemistry/ChemistryModule.hpp | 838 +++++++++++++++-------------- src/Chemistry/MasterFunctions.cpp | 151 +++--- src/Control/ControlModule.cpp | 193 +++++++ src/Control/ControlModule.hpp | 118 +++++ src/IO/StatsIO.cpp | 36 +- src/IO/StatsIO.hpp | 5 +- src/poet.cpp | 854 +++++++++++++----------------- 7 files changed, 1205 insertions(+), 990 deletions(-) create mode 100644 src/Control/ControlModule.cpp create mode 100644 src/Control/ControlModule.hpp diff --git a/src/Chemistry/ChemistryModule.hpp b/src/Chemistry/ChemistryModule.hpp index 836b0f237..0948e52a4 100644 --- a/src/Chemistry/ChemistryModule.hpp +++ b/src/Chemistry/ChemistryModule.hpp @@ -23,434 +23,428 @@ #include #include -namespace poet -{ +namespace poet { +/** + * \brief Wrapper around PhreeqcRM to provide POET specific parallelization with + * easy access. + */ +class ChemistryModule { +public: /** - * \brief Wrapper around PhreeqcRM to provide POET specific parallelization with - * easy access. + * Creates a new instance of Chemistry module with given grid cell count, work + * package size and communicator. + * + * This constructor shall only be called by the master. To create workers, see + * ChemistryModule::createWorker . + * + * When the use of parallelization is intended, the nxyz value shall be set to + * 1 to save memory and only one node is needed for initialization. + * + * \param nxyz Count of grid cells to allocate and initialize for each + * process. For parellel use set to 1 at the master. + * \param wp_size Count of grid cells to fill each work package at maximum. + * \param communicator MPI communicator to distribute work in. */ - class ChemistryModule - { - public: - /** - * Creates a new instance of Chemistry module with given grid cell count, work - * package size and communicator. - * - * This constructor shall only be called by the master. To create workers, see - * ChemistryModule::createWorker . - * - * When the use of parallelization is intended, the nxyz value shall be set to - * 1 to save memory and only one node is needed for initialization. - * - * \param nxyz Count of grid cells to allocate and initialize for each - * process. For parellel use set to 1 at the master. - * \param wp_size Count of grid cells to fill each work package at maximum. - * \param communicator MPI communicator to distribute work in. - */ - ChemistryModule(uint32_t wp_size, - const InitialList::ChemistryInit chem_params, - MPI_Comm communicator); - - /** - * Deconstructor, which frees DHT data structure if used. - */ - ~ChemistryModule(); - - void masterSetField(Field field); - /** - * Run the chemical simulation with parameters set. - */ - void simulate(double dt); - - /** - * Returns all known species names, including not only aqueous species, but - * also equilibrium, exchange, surface and kinetic reactants. - */ - // auto GetPropNames() const { return this->prop_names; } - - /** - * Return the accumulated runtime in seconds for chemical simulation. - */ - auto GetChemistryTime() const { return this->chem_t; } - - void setFilePadding(std::uint32_t maxiter) - { - this->file_pad = - static_cast(std::ceil(std::log10(maxiter + 1))); - } - - struct SurrogateSetup - { - std::vector prop_names; - std::array base_totals; - bool has_het_ids; - - bool dht_enabled; - std::uint32_t dht_size_mb; - int dht_snaps; - std::string dht_out_dir; - - bool interp_enabled; - std::uint32_t interp_bucket_size; - std::uint32_t interp_size_mb; - std::uint32_t interp_min_entries; - bool ai_surrogate_enabled; - }; - - void masterEnableSurrogates(const SurrogateSetup &setup) - { - // FIXME: This is a hack to get the prop_names and prop_count from the setup - this->prop_names = setup.prop_names; - this->prop_count = setup.prop_names.size(); - - this->dht_enabled = setup.dht_enabled; - this->interp_enabled = setup.interp_enabled; - this->ai_surrogate_enabled = setup.ai_surrogate_enabled; - - this->base_totals = setup.base_totals; - - if (this->dht_enabled || this->interp_enabled) - { - this->initializeDHT(setup.dht_size_mb, this->params.dht_species, - setup.has_het_ids); - - if (setup.dht_snaps != DHT_SNAPS_DISABLED) - { - this->setDHTSnapshots(setup.dht_snaps, setup.dht_out_dir); - } - } - - if (this->interp_enabled) - { - this->initializeInterp(setup.interp_bucket_size, setup.interp_size_mb, - setup.interp_min_entries, - this->params.interp_species); - } - } - - /** - * Intended to alias input parameters for grid initialization with a single - * value per species. - */ - using SingleCMap = std::unordered_map; - - /** - * Intended to alias input parameters for grid initialization with mutlitple - * values per species. - */ - using VectorCMap = std::unordered_map>; - - /** - * Enumerating DHT file options - */ - enum - { - DHT_SNAPS_DISABLED = 0, //!< disabled file output - DHT_SNAPS_SIMEND, //!< only output of snapshot after simulation - DHT_SNAPS_ITEREND //!< output snapshots after each iteration - }; - - /** - * **Only called by workers!** Start the worker listening loop. - */ - void WorkerLoop(); - - /** - * **Called by master** Advise the workers to break the loop. - */ - void MasterLoopBreak(); - - /** - * **Master only** Return count of grid cells. - */ - auto GetNCells() const { return this->n_cells; } - /** - * **Master only** Return work package size. - */ - auto GetWPSize() const { return this->wp_size; } - /** - * **Master only** Return the time in seconds the master spent waiting for any - * free worker. - */ - auto GetMasterIdleTime() const { return this->idle_t; } - /** - * **Master only** Return the time in seconds the master spent in sequential - * part of the simulation, including times for shuffling/unshuffling field - * etc. - */ - auto GetMasterSequentialTime() const { return this->seq_t; } - /** - * **Master only** Return the time in seconds the master spent in the - * send/receive loop. - */ - auto GetMasterLoopTime() const { return this->send_recv_t; } - - /** - * **Master only** Collect and return all accumulated timings recorded by - * workers to run Phreeqc simulation. - * - * \return Vector of all accumulated Phreeqc timings. - */ - std::vector GetWorkerPhreeqcTimings() const; - /** - * **Master only** Collect and return all accumulated timings recorded by - * workers to get values from the DHT. - * - * \return Vector of all accumulated DHT get times. - */ - std::vector GetWorkerDHTGetTimings() const; - /** - * **Master only** Collect and return all accumulated timings recorded by - * workers to write values to the DHT. - * - * \return Vector of all accumulated DHT fill times. - */ - std::vector GetWorkerDHTFillTimings() const; - /** - * **Master only** Collect and return all accumulated timings recorded by - * workers waiting for work packages from the master. - * - * \return Vector of all accumulated waiting times. - */ - std::vector GetWorkerIdleTimings() const; - - /** - * **Master only** Collect and return DHT hits of all workers. - * - * \return Vector of all count of DHT hits. - */ - std::vector GetWorkerDHTHits() const; - - /** - * **Master only** Collect and return DHT evictions of all workers. - * - * \return Vector of all count of DHT evictions. - */ - std::vector GetWorkerDHTEvictions() const; - - /** - * **Master only** Returns the current state of the chemical field. - * - * \return Reference to the chemical field. - */ - Field &getField() { return this->chem_field; } - - /** - * **Master only** Enable/disable progress bar. - * - * \param enabled True if print progressbar, false if not. - */ - void setProgressBarPrintout(bool enabled) - { - this->print_progessbar = enabled; - }; - - /** - * **Master only** Set the ai surrogate validity vector from R - */ - void set_ai_surrogate_validity_vector(std::vector r_vector); - - std::vector GetWorkerInterpolationCalls() const; - - std::vector GetWorkerInterpolationWriteTimings() const; - std::vector GetWorkerInterpolationReadTimings() const; - std::vector GetWorkerInterpolationGatherTimings() const; - std::vector GetWorkerInterpolationFunctionCallTimings() const; - - std::vector GetWorkerPHTCacheHits() const; - - std::vector ai_surrogate_validity_vector; - - RuntimeParameters *runtime_params = nullptr; - uint32_t control_iteration_counter = 0; - - struct error_stats - { - std::vector mape; - std::vector rrsme; - uint32_t iteration; - - error_stats(size_t species_count, size_t iter) - : mape(species_count, 0.0), rrsme(species_count, 0.0), iteration(iter) {} - }; - - std::vector error_stats_history; - - static void computeStats(const std::vector &pqc_vector, - const std::vector &sur_vector, - uint32_t size_per_prop, uint32_t species_count, - error_stats &stats); - - protected: - void initializeDHT(uint32_t size_mb, - const NamedVector &key_species, - bool has_het_ids); - void setDHTSnapshots(int type, const std::string &out_dir); - void setDHTReadFile(const std::string &input_file); - - void initializeInterp(std::uint32_t bucket_size, std::uint32_t size_mb, - std::uint32_t min_entries, - const NamedVector &key_species); - - enum - { - CHEM_FIELD_INIT, - CHEM_DHT_ENABLE, - CHEM_DHT_SIGNIF_VEC, - CHEM_DHT_SNAPS, - CHEM_DHT_READ_FILE, - CHEM_INTERP, - CHEM_IP_ENABLE, - CHEM_IP_MIN_ENTRIES, - CHEM_IP_SIGNIF_VEC, - CHEM_WORK_LOOP, - CHEM_PERF, - CHEM_BREAK_MAIN_LOOP, - CHEM_AI_BCAST_VALIDITY - }; - - enum - { - LOOP_WORK, - LOOP_END, - LOOP_CTRL - }; - - enum - { - WORKER_PHREEQC, - WORKER_DHT_GET, - WORKER_DHT_FILL, - WORKER_IDLE, - WORKER_IP_WRITE, - WORKER_IP_READ, - WORKER_IP_GATHER, - WORKER_IP_FC, - WORKER_DHT_HITS, - WORKER_DHT_EVICTIONS, - WORKER_PHT_CACHE_HITS, - WORKER_IP_CALLS - }; - - std::vector interp_calls; - std::vector dht_hits; - std::vector dht_evictions; - - struct worker_s - { - double phreeqc_t = 0.; - double dht_get = 0.; - double dht_fill = 0.; - double idle_t = 0.; - }; - - struct worker_info_s - { - char has_work = 0; - double *send_addr; - double *surrogate_addr; - }; - - using worker_list_t = std::vector; - using workpointer_t = std::vector::iterator; - - void MasterRunParallel(double dt); - void MasterRunSequential(); - - void MasterSendPkgs(worker_list_t &w_list, workpointer_t &work_pointer, workpointer_t &sur_pointer, - int &pkg_to_send, int &count_pkgs, int &free_workers, - double dt, uint32_t iteration, uint32_t control_iteration, - const std::vector &wp_sizes_vector); - void MasterRecvPkgs(worker_list_t &w_list, int &pkg_to_recv, bool to_send, - int &free_workers); - - std::vector MasterGatherWorkerTimings(int type) const; - std::vector MasterGatherWorkerMetrics(int type) const; - - void WorkerProcessPkgs(struct worker_s &timings, uint32_t &iteration); - - void WorkerDoWork(MPI_Status &probe_status, int double_count, - struct worker_s &timings); - void WorkerPostIter(MPI_Status &prope_status, uint32_t iteration); - void WorkerPostSim(uint32_t iteration); - - void WorkerWriteDHTDump(uint32_t iteration); - void WorkerReadDHTDump(const std::string &dht_input_file); - - void WorkerPerfToMaster(int type, const struct worker_s &timings); - void WorkerMetricsToMaster(int type); - - void WorkerRunWorkPackage(WorkPackage &work_package, double dSimTime, - double dTimestep); - - std::vector CalculateWPSizesVector(uint32_t n_cells, - uint32_t wp_size) const; - std::vector shuffleField(const std::vector &in_field, - uint32_t size_per_prop, uint32_t prop_count, - uint32_t wp_count); - void unshuffleField(const std::vector &in_buffer, - uint32_t size_per_prop, uint32_t prop_count, - uint32_t wp_count, std::vector &out_field); - std::vector - parseDHTSpeciesVec(const NamedVector &key_species, - const std::vector &to_compare) const; - - void BCastStringVec(std::vector &io); - - int comm_size, comm_rank; - MPI_Comm group_comm; - - bool is_sequential; - bool is_master; - - uint32_t wp_size; - bool dht_enabled{false}; - int dht_snaps_type{DHT_SNAPS_DISABLED}; - std::string dht_file_out_dir; - - poet::DHT_Wrapper *dht = nullptr; - - bool interp_enabled{false}; - std::unique_ptr interp; - - bool ai_surrogate_enabled{false}; - - static constexpr uint32_t BUFFER_OFFSET = 6; - - inline void ChemBCast(void *buf, int count, MPI_Datatype datatype) const - { - MPI_Bcast(buf, count, datatype, 0, this->group_comm); - } - - inline void PropagateFunctionType(int &type) const - { - ChemBCast(&type, 1, MPI_INT); - } - double simtime = 0.; - double idle_t = 0.; - double seq_t = 0.; - double send_recv_t = 0.; - - std::array base_totals{0}; - - bool print_progessbar{false}; - - std::uint8_t file_pad{1}; - - double chem_t{0.}; - - uint32_t n_cells = 0; - uint32_t prop_count = 0; + ChemistryModule(uint32_t wp_size, + const InitialList::ChemistryInit chem_params, + MPI_Comm communicator); + + /** + * Deconstructor, which frees DHT data structure if used. + */ + ~ChemistryModule(); + + void masterSetField(Field field); + /** + * Run the chemical simulation with parameters set. + */ + void simulate(double dt); + + /** + * Returns all known species names, including not only aqueous species, but + * also equilibrium, exchange, surface and kinetic reactants. + */ + // auto GetPropNames() const { return this->prop_names; } + + /** + * Return the accumulated runtime in seconds for chemical simulation. + */ + auto GetChemistryTime() const { return this->chem_t; } + + void setFilePadding(std::uint32_t maxiter) { + this->file_pad = + static_cast(std::ceil(std::log10(maxiter + 1))); + } + + struct SurrogateSetup { std::vector prop_names; + std::array base_totals; + bool has_het_ids; - Field chem_field; + bool dht_enabled; + std::uint32_t dht_size_mb; + int dht_snaps; + std::string dht_out_dir; - const InitialList::ChemistryInit params; - - std::unique_ptr pqc_runner; - - std::vector sur_shuffled; + bool interp_enabled; + std::uint32_t interp_bucket_size; + std::uint32_t interp_size_mb; + std::uint32_t interp_min_entries; + bool ai_surrogate_enabled; }; + + void masterEnableSurrogates(const SurrogateSetup &setup) { + // FIXME: This is a hack to get the prop_names and prop_count from the setup + this->prop_names = setup.prop_names; + this->prop_count = setup.prop_names.size(); + + this->dht_enabled = setup.dht_enabled; + this->interp_enabled = setup.interp_enabled; + this->ai_surrogate_enabled = setup.ai_surrogate_enabled; + + this->base_totals = setup.base_totals; + + if (this->dht_enabled || this->interp_enabled) { + this->initializeDHT(setup.dht_size_mb, this->params.dht_species, + setup.has_het_ids); + + if (setup.dht_snaps != DHT_SNAPS_DISABLED) { + this->setDHTSnapshots(setup.dht_snaps, setup.dht_out_dir); + } + } + + if (this->interp_enabled) { + this->initializeInterp(setup.interp_bucket_size, setup.interp_size_mb, + setup.interp_min_entries, + this->params.interp_species); + } + } + + /** + * Intended to alias input parameters for grid initialization with a single + * value per species. + */ + using SingleCMap = std::unordered_map; + + /** + * Intended to alias input parameters for grid initialization with mutlitple + * values per species. + */ + using VectorCMap = std::unordered_map>; + + /** + * Enumerating DHT file options + */ + enum { + DHT_SNAPS_DISABLED = 0, //!< disabled file output + DHT_SNAPS_SIMEND, //!< only output of snapshot after simulation + DHT_SNAPS_ITEREND //!< output snapshots after each iteration + }; + + /** + * **Only called by workers!** Start the worker listening loop. + */ + void WorkerLoop(); + + /** + * **Called by master** Advise the workers to break the loop. + */ + void MasterLoopBreak(); + + /** + * **Master only** Return count of grid cells. + */ + auto GetNCells() const { return this->n_cells; } + /** + * **Master only** Return work package size. + */ + auto GetWPSize() const { return this->wp_size; } + /** + * **Master only** Return the time in seconds the master spent waiting for any + * free worker. + */ + auto GetMasterIdleTime() const { return this->idle_t; } + /** + * **Master only** Return the time in seconds the master spent in sequential + * part of the simulation, including times for shuffling/unshuffling field + * etc. + */ + auto GetMasterSequentialTime() const { return this->seq_t; } + /** + * **Master only** Return the time in seconds the master spent in the + * send/receive loop. + */ + auto GetMasterLoopTime() const { return this->send_recv_t; } + + auto GetMasterRecvCtrlDataTime() const { return this->recv_ctrl_t; } + + auto GetMasterUnshuffleTime() const { return this->shuf_t; } + + auto GetMasterCtrlMetricsTime() const { return this->metrics_t; } + + /** + * **Master only** Collect and return all accumulated timings recorded by + * workers to run Phreeqc simulation. + * + * \return Vector of all accumulated Phreeqc timings. + */ + std::vector GetWorkerPhreeqcTimings() const; + /** + * **Master only** Collect and return all accumulated timings recorded by + * workers to get values from the DHT. + * + * \return Vector of all accumulated DHT get times. + */ + std::vector GetWorkerDHTGetTimings() const; + /** + * **Master only** Collect and return all accumulated timings recorded by + * workers to write values to the DHT. + * + * \return Vector of all accumulated DHT fill times. + */ + std::vector GetWorkerDHTFillTimings() const; + /** + * **Master only** Collect and return all accumulated timings recorded by + * workers waiting for work packages from the master. + * + * \return Vector of all accumulated waiting times. + */ + std::vector GetWorkerIdleTimings() const; + + /** + * **Master only** Collect and return DHT hits of all workers. + * + * \return Vector of all count of DHT hits. + */ + std::vector GetWorkerDHTHits() const; + + /** + * **Master only** Collect and return DHT evictions of all workers. + * + * \return Vector of all count of DHT evictions. + */ + std::vector GetWorkerDHTEvictions() const; + + /** + * **Master only** Returns the current state of the chemical field. + * + * \return Reference to the chemical field. + */ + Field &getField() { return this->chem_field; } + + /** + * **Master only** Enable/disable progress bar. + * + * \param enabled True if print progressbar, false if not. + */ + void setProgressBarPrintout(bool enabled) { + this->print_progessbar = enabled; + }; + + /** + * **Master only** Set the ai surrogate validity vector from R + */ + void set_ai_surrogate_validity_vector(std::vector r_vector); + + std::vector GetWorkerInterpolationCalls() const; + + std::vector GetWorkerInterpolationWriteTimings() const; + std::vector GetWorkerInterpolationReadTimings() const; + std::vector GetWorkerInterpolationGatherTimings() const; + std::vector GetWorkerInterpolationFunctionCallTimings() const; + + std::vector GetWorkerPHTCacheHits() const; + + std::vector ai_surrogate_validity_vector; + + RuntimeParameters *runtime_params = nullptr; + uint32_t control_iteration_counter = 0; + + struct error_stats { + std::vector mape; + std::vector rrsme; + uint32_t iteration; + + error_stats(size_t species_count, size_t iter) + : mape(species_count, 0.0), rrsme(species_count, 0.0), iteration(iter) { + } + }; + + std::vector error_stats_history; + + static void computeStats(const std::vector &pqc_vector, + const std::vector &sur_vector, + uint32_t size_per_prop, uint32_t species_count, + error_stats &stats); + +protected: + void initializeDHT(uint32_t size_mb, + const NamedVector &key_species, + bool has_het_ids); + void setDHTSnapshots(int type, const std::string &out_dir); + void setDHTReadFile(const std::string &input_file); + + void initializeInterp(std::uint32_t bucket_size, std::uint32_t size_mb, + std::uint32_t min_entries, + const NamedVector &key_species); + + enum { + CHEM_FIELD_INIT, + CHEM_DHT_ENABLE, + CHEM_DHT_SIGNIF_VEC, + CHEM_DHT_SNAPS, + CHEM_DHT_READ_FILE, + CHEM_INTERP, + CHEM_IP_ENABLE, + CHEM_IP_MIN_ENTRIES, + CHEM_IP_SIGNIF_VEC, + CHEM_WORK_LOOP, + CHEM_PERF, + CHEM_BREAK_MAIN_LOOP, + CHEM_AI_BCAST_VALIDITY + }; + + enum { LOOP_WORK, LOOP_END, LOOP_CTRL }; + + enum { + WORKER_PHREEQC, + WORKER_DHT_GET, + WORKER_DHT_FILL, + WORKER_IDLE, + WORKER_IP_WRITE, + WORKER_IP_READ, + WORKER_IP_GATHER, + WORKER_IP_FC, + WORKER_DHT_HITS, + WORKER_DHT_EVICTIONS, + WORKER_PHT_CACHE_HITS, + WORKER_IP_CALLS + }; + + std::vector interp_calls; + std::vector dht_hits; + std::vector dht_evictions; + + struct worker_s { + double phreeqc_t = 0.; + double dht_get = 0.; + double dht_fill = 0.; + double idle_t = 0.; + }; + + struct worker_info_s { + char has_work = 0; + double *send_addr; + double *surrogate_addr; + }; + + using worker_list_t = std::vector; + using workpointer_t = std::vector::iterator; + + void MasterRunParallel(double dt); + void MasterRunSequential(); + + void MasterSendPkgs(worker_list_t &w_list, workpointer_t &work_pointer, + workpointer_t &sur_pointer, int &pkg_to_send, + int &count_pkgs, int &free_workers, double dt, + uint32_t iteration, uint32_t control_iteration, + const std::vector &wp_sizes_vector); + void MasterRecvPkgs(worker_list_t &w_list, int &pkg_to_recv, bool to_send, + int &free_workers); + + std::vector MasterGatherWorkerTimings(int type) const; + std::vector MasterGatherWorkerMetrics(int type) const; + + void WorkerProcessPkgs(struct worker_s &timings, uint32_t &iteration); + + void WorkerDoWork(MPI_Status &probe_status, int double_count, + struct worker_s &timings); + void WorkerPostIter(MPI_Status &prope_status, uint32_t iteration); + void WorkerPostSim(uint32_t iteration); + + void WorkerWriteDHTDump(uint32_t iteration); + void WorkerReadDHTDump(const std::string &dht_input_file); + + void WorkerPerfToMaster(int type, const struct worker_s &timings); + void WorkerMetricsToMaster(int type); + + void WorkerRunWorkPackage(WorkPackage &work_package, double dSimTime, + double dTimestep); + + std::vector CalculateWPSizesVector(uint32_t n_cells, + uint32_t wp_size) const; + std::vector shuffleField(const std::vector &in_field, + uint32_t size_per_prop, uint32_t prop_count, + uint32_t wp_count); + void unshuffleField(const std::vector &in_buffer, + uint32_t size_per_prop, uint32_t prop_count, + uint32_t wp_count, std::vector &out_field); + std::vector + parseDHTSpeciesVec(const NamedVector &key_species, + const std::vector &to_compare) const; + + void BCastStringVec(std::vector &io); + + int comm_size, comm_rank; + MPI_Comm group_comm; + + bool is_sequential; + bool is_master; + + uint32_t wp_size; + bool dht_enabled{false}; + int dht_snaps_type{DHT_SNAPS_DISABLED}; + std::string dht_file_out_dir; + + poet::DHT_Wrapper *dht = nullptr; + + bool interp_enabled{false}; + std::unique_ptr interp; + + bool ai_surrogate_enabled{false}; + + static constexpr uint32_t BUFFER_OFFSET = 6; + + inline void ChemBCast(void *buf, int count, MPI_Datatype datatype) const { + MPI_Bcast(buf, count, datatype, 0, this->group_comm); + } + + inline void PropagateFunctionType(int &type) const { + ChemBCast(&type, 1, MPI_INT); + } + double simtime = 0.; + double idle_t = 0.; + double seq_t = 0.; + double send_recv_t = 0.; + + double recv_ctrl_t = 0.; + double shuf_t = 0.; + double metrics_t = 0. + + std::array + base_totals{0}; + + bool print_progessbar{false}; + + std::uint8_t file_pad{1}; + + double chem_t{0.}; + + uint32_t n_cells = 0; + uint32_t prop_count = 0; + std::vector prop_names; + + Field chem_field; + + const InitialList::ChemistryInit params; + + std::unique_ptr pqc_runner; + + poet::ControlModule *control_module = nullptr; + + bool control_enabled{false}; + bool warmup_enabled{false}; +}; } // namespace poet #endif // CHEMISTRYMODULE_H_ diff --git a/src/Chemistry/MasterFunctions.cpp b/src/Chemistry/MasterFunctions.cpp index 683985134..06e5e9c02 100644 --- a/src/Chemistry/MasterFunctions.cpp +++ b/src/Chemistry/MasterFunctions.cpp @@ -160,39 +160,6 @@ std::vector poet::ChemistryModule::GetWorkerPHTCacheHits() const { return ret; } -void poet::ChemistryModule::computeStats(const std::vector &pqc_vector, - const std::vector &sur_vector, - uint32_t size_per_prop, - uint32_t species_count, - error_stats &stats) { - for (uint32_t i = 0; i < species_count; ++i) { - double err_sum = 0.0; - double sqr_err_sum = 0.0; - uint32_t base_idx = i * size_per_prop; - - for (uint32_t j = 0; j < size_per_prop; ++j) { - const double pqc_value = pqc_vector[base_idx + j]; - const double sur_value = sur_vector[base_idx + j]; - - if (pqc_value == 0.0) { - if (sur_value != 0.0) { - err_sum += 1.0; - sqr_err_sum += 1.0; - } - // Both zero: skip - } else { - double alpha = 1.0 - (sur_value / pqc_value); - err_sum += std::abs(alpha); - sqr_err_sum += alpha * alpha; - } - } - - stats.mape[i] = 100.0 * (err_sum / size_per_prop); - stats.rrsme[i] = - (size_per_prop > 0) ? std::sqrt(sqr_err_sum / size_per_prop) : 0.0; - } -} - inline std::vector shuffleVector(const std::vector &in_vector, uint32_t size_per_prop, uint32_t wp_count) { @@ -264,7 +231,7 @@ inline void poet::ChemistryModule::MasterSendPkgs( worker_list_t &w_list, workpointer_t &work_pointer, workpointer_t &sur_pointer, int &pkg_to_send, int &count_pkgs, int &free_workers, double dt, uint32_t iteration, - uint32_t control_iteration, const std::vector &wp_sizes_vector) { + const std::vector &wp_sizes_vector) { /* declare variables */ int local_work_package_size; @@ -278,6 +245,10 @@ inline void poet::ChemistryModule::MasterSendPkgs( local_work_package_size = (int)wp_sizes_vector[count_pkgs]; count_pkgs++; + uint32_t wp_start_index = + std::accumulate(wp_sizes_vector.begin(), + std::next(wp_sizes_vector.begin(), count_pkgs), 0); + /* note current processed work package in workerlist */ w_list[p].send_addr = work_pointer.base(); w_list[p].surrogate_addr = sur_pointer.base(); @@ -300,12 +271,12 @@ inline void poet::ChemistryModule::MasterSendPkgs( // current time of simulation (age) in seconds send_buffer[end_of_wp + 3] = this->simtime; // current work package start location in field - uint32_t wp_start_index = - std::accumulate(wp_sizes_vector.begin(), - std::next(wp_sizes_vector.begin(), count_pkgs), 0); send_buffer[end_of_wp + 4] = wp_start_index; - // whether this iteration is a control iteration - send_buffer[end_of_wp + 5] = control_iteration; + // control flags (bitmask) + int flags = (this->interp_enabled ? 1 : 0) | (this->dht_enabled ? 2 : 0) | + (this->warmup_enabled ? 4 : 0) | + (this->control_enabled ? 8 : 0); + send_buffer[end_of_wp + 5] = static_cast(flags); /* ATTENTION Worker p has rank p+1 */ // MPI_Send(send_buffer, end_of_wp + BUFFER_OFFSET, MPI_DOUBLE, p + 1, @@ -328,8 +299,11 @@ inline void poet::ChemistryModule::MasterRecvPkgs(worker_list_t &w_list, /* declare most of the variables here */ int need_to_receive = 1; double idle_a, idle_b; + double recv_ctrl_a, recv_ctrl_b; int p, size; + std::vector recv_buffer; + recv_buffer.reserve(wp_size * prop_count * 2); MPI_Status probe_status; // master_recv_a = MPI_Wtime(); /* start to loop as long there are packages to recv and the need to receive @@ -347,38 +321,51 @@ inline void poet::ChemistryModule::MasterRecvPkgs(worker_list_t &w_list, idle_b = MPI_Wtime(); this->idle_t += idle_b - idle_a; } - + if (!need_to_receive) { + continue; + } /* if need_to_receive was set to true above, so there is a message to * receive */ - if (need_to_receive) { - p = probe_status.MPI_SOURCE; - if (probe_status.MPI_TAG == LOOP_WORK) { - MPI_Get_count(&probe_status, MPI_DOUBLE, &size); - MPI_Recv(w_list[p - 1].send_addr, size, MPI_DOUBLE, p, LOOP_WORK, - this->group_comm, MPI_STATUS_IGNORE); - w_list[p - 1].has_work = 0; - pkg_to_recv -= 1; - free_workers++; - } - if (probe_status.MPI_TAG == LOOP_CTRL) { - MPI_Get_count(&probe_status, MPI_DOUBLE, &size); + p = probe_status.MPI_SOURCE; + bool handled = false; - // layout of buffer is [phreeqc][surrogate] - std::vector recv_buffer(size); + switch (probe_status.MPI_TAG) { + case LOOP_WORK: { + MPI_Get_count(&probe_status, MPI_DOUBLE, &size); + MPI_Recv(w_list[p - 1].send_addr, size, MPI_DOUBLE, p, LOOP_WORK, + this->group_comm, MPI_STATUS_IGNORE); + handled = true; + break; + } + case LOOP_CTRL: { + recv_ctrl_a = MPI_Wtime(); + /* layout of buffer is [phreeqc][surrogate] */ + MPI_Get_count(&probe_status, MPI_DOUBLE, &size); + recv_buffer.resize(size); + MPI_Recv(recv_buffer.data(), size, MPI_DOUBLE, p, LOOP_CTRL, + this->group_comm, MPI_STATUS_IGNORE); - MPI_Recv(recv_buffer.data(), size, MPI_DOUBLE, p, LOOP_CTRL, - this->group_comm, MPI_STATUS_IGNORE); + int half = size / 2; + std::copy(recv_buffer.begin(), recv_buffer.begin() + half, + w_list[p - 1].send_addr); - std::copy(recv_buffer.begin(), recv_buffer.begin() + (size / 2), - w_list[p - 1].send_addr); + std::copy(recv_buffer.begin() + (size / 2), recv_buffer.begin() + size, + w_list[p - 1].surrogate_addr); + recv_ctrl_b = MPI_Wtime(); + recv_ctrl_t += recv_ctrl_b - recv_ctrl_a; - std::copy(recv_buffer.begin() + (size / 2), recv_buffer.begin() + size, - w_list[p - 1].surrogate_addr); - - w_list[p - 1].has_work = 0; - pkg_to_recv -= 1; - free_workers++; - } + handled = true; + break; + } + default: { + throw std::runtime_error("Master received unknown MPI tag: " + + std::to_string(probe_status.MPI_TAG)); + } + } + if (handled) { + w_list[p - 1].has_work = 0; + pkg_to_recv -= 1; + free_workers++; } } } @@ -451,7 +438,7 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { ftype = CHEM_INTERP; PropagateFunctionType(ftype); - if(this->runtime_params->rollback_simulation){ + if (this->runtime_params->rollback_simulation) { this->interp_enabled = false; int interp_flag = 0; ChemBCast(&interp_flag, 1, MPI_INT); @@ -538,24 +525,32 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { /* do master stuff */ - if (control_iteration) { - control_iteration_counter++; + /* do master stuff */ + if (control_enabled) { + std::cout << "[Master] Control logic enabled for this iteration." + << std::endl; + std::vector sur_unshuffled{mpi_surr_buffer}; - std::vector sur_unshuffled{sur_shuffled}; - - unshuffleField(sur_shuffled, this->n_cells, this->prop_count, + shuf_a = MPI_Wtime(); + unshuffleField(mpi_surr_buffer, this->n_cells, this->prop_count, wp_sizes_vector.size(), sur_unshuffled); + shuf_b = MPI_Wtime(); + this->shuf_t += shuf_b - shuf_a; - error_stats stats(this->prop_count, control_iteration_counter * - runtime_params->control_iteration); + size_t N = out_vec.size(); + if (N != sur_unshuffled.size()) { + std::cerr << "[MASTER DBG] size mismatch out_vec=" << N + << " sur_unshuffled=" << sur_unshuffled.size() << std::endl; + } - computeStats(out_vec, sur_unshuffled, this->n_cells, this->prop_count, - stats); - error_stats_history.push_back(stats); - - // to do: control values to epsilon + metrics_a = MPI_Wtime(); + control_module->computeSpeciesErrorMetrics(out_vec, sur_unshuffled, + this->n_cells); + metrics_b = MPI_Wtime(); + this->metrics_t += metrics_b - metrics_a; } + /* start time measurement of master chemistry */ sim_e_chemistry = MPI_Wtime(); diff --git a/src/Control/ControlModule.cpp b/src/Control/ControlModule.cpp new file mode 100644 index 000000000..db12c3025 --- /dev/null +++ b/src/Control/ControlModule.cpp @@ -0,0 +1,193 @@ +#include "ControlModule.hpp" +#include "IO/Datatypes.hpp" +#include "IO/HDF5Functions.hpp" +#include "IO/StatsIO.hpp" +#include + +void poet::ControlModule::updateControlIteration(const uint32_t &iter, + const bool &dht_enabled, + const bool &interp_enabled) { + + /* dht_enabled and inter_enabled are user settings set before startig the + * simulation*/ + double prep_a, prep_b; + + prep_a = MPI_Wtime(); + if (control_interval == 0) { + control_interval_enabled = false; + return; + } + global_iteration = iter; + initiateWarmupPhase(dht_enabled, interp_enabled); + + control_interval_enabled = + (control_interval > 0 && iter % control_interval == 0); + + if (control_interval_enabled) { + MSG("[Control] Control interval enabled at iteration " + + std::to_string(iter)); + } + prep_b = MPI_Wtime(); + this->prep_t += prep_b - prep_a; +} + +void poet::ControlModule::initiateWarmupPhase(bool dht_enabled, + bool interp_enabled) { + + // user requested DHT/INTEP? keep them disabled but enable warmup-phase so + if (global_iteration <= control_interval || rollback_enabled) { + chem->SetWarmupEnabled(true); + chem->SetDhtEnabled(false); + chem->SetInterpEnabled(false); + MSG("Warmup enabled until next control interval at iteration " + + std::to_string(control_interval) + "."); + + if (rollback_enabled) { + if (sur_disabled_counter > 0) { + --sur_disabled_counter; + MSG("Rollback counter: " + std::to_string(sur_disabled_counter)); + } else { + rollback_enabled = false; + } + } + return; + } + + chem->SetWarmupEnabled(false); + chem->SetDhtEnabled(dht_enabled); + chem->SetInterpEnabled(interp_enabled); +} + +void poet::ControlModule::applyControlLogic(ChemistryModule &chem, + uint32_t &iter) { + if (!control_interval_enabled) { + return; + } + writeCheckpointAndMetrics(chem, iter); + + if (checkAndRollback(chem, iter) && rollback_count < 4) { + rollback_enabled = true; + rollback_count++; + sur_disabled_counter = control_interval; + MSG("Interpolation disabled for the next " + + std::to_string(control_interval) + "."); + } +} + +void poet::ControlModule::writeCheckpointAndMetrics(ChemistryModule &chem, + uint32_t iter) { + + double w_check_a, w_check_b, stats_a, stats_b; + MSG("Writing checkpoint of iteration " + std::to_string(iter)); + + w_check_a = MPI_Wtime(); + write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5", + {.field = chem.getField(), .iteration = iter}); + w_check_b = MPI_Wtime(); + this->w_check_t += w_check_b - w_check_a; + + stats_a = MPI_Wtime(); + writeStatsToCSV(metricsHistory, species_names, out_dir, "stats_overview"); + stats_b = MPI_Wtime(); + + this->stats_t += stats_b - stats_a; +} + +bool poet::ControlModule::checkAndRollback(ChemistryModule &chem, + uint32_t &iter) { + double r_check_a, r_check_b; + + if (metricsHistory.empty()) { + MSG("No error history yet; skipping rollback check."); + return false; + } + + const auto &mape = metricsHistory.back().mape; + + for (uint32_t i = 0; i < species_names.size(); ++i) { + if (mape[i] == 0) { + continue; + } + + if (mape[i] > mape_threshold[i]) { + uint32_t rollback_iter = + ((iter - 1) / checkpoint_interval) * checkpoint_interval; + + MSG("[THRESHOLD EXCEEDED] " + species_names[i] + + " has MAPE = " + std::to_string(mape[i]) + + " exceeding threshold = " + std::to_string(mape_threshold[i]) + + " → rolling back to iteration " + std::to_string(rollback_iter)); + + r_check_a = MPI_Wtime(); + Checkpoint_s checkpoint_read{.field = chem.getField()}; + read_checkpoint(out_dir, + "checkpoint" + std::to_string(rollback_iter) + ".hdf5", + checkpoint_read); + iter = checkpoint_read.iteration; + r_check_b = MPI_Wtime(); + r_check_t += r_check_b - r_check_a; + return true; + } + } + MSG("All species are within their MAPE thresholds."); + + return false; +} + +void poet::ControlModule::computeSpeciesErrorMetrics( + const std::vector &reference_values, + const std::vector &surrogate_values, const uint32_t size_per_prop) { + + SpeciesErrorMetrics metrics(this->species_names.size(), global_iteration, + rollback_count); + + if (reference_values.size() != surrogate_values.size()) { + MSG(" Reference and surrogate vectors differ in size: " + + std::to_string(reference_values.size()) + " vs " + + std::to_string(surrogate_values.size())); + return; + } + + const std::size_t expected = + static_cast(this->species_names.size()) * size_per_prop; + if (reference_values.size() < expected) { + std::cerr << "[CTRL ERROR] input vectors too small: expected >= " + << expected << " entries, got " << reference_values.size() + << "\n"; + return; + } + + for (uint32_t i = 0; i < this->species_names.size(); ++i) { + double err_sum = 0.0; + double sqr_err_sum = 0.0; + uint32_t base_idx = i * size_per_prop; + + int count = 0; + + for (uint32_t j = 0; j < size_per_prop; ++j) { + const double ref_value = reference_values[base_idx + j]; + const double sur_value = surrogate_values[base_idx + j]; + const double ZERO_ABS = 1e-13; + + if (std::isnan(ref_value) || std::isnan(sur_value)) { + continue; + } + + if (std::abs(ref_value) < ZERO_ABS) { + if (std::abs(sur_value) >= ZERO_ABS) { + err_sum += 1.0; + sqr_err_sum += 1.0; + } + } + // Both zero: skip + else { + double alpha = 1.0 - (sur_value / ref_value); + err_sum += std::abs(alpha); + sqr_err_sum += alpha * alpha; + } + } + metrics.mape[i] = 100.0 * (err_sum / size_per_prop); + metrics.rrmse[i] = std::sqrt(sqr_err_sum / size_per_prop); + } + metricsHistory.push_back(metrics); +} diff --git a/src/Control/ControlModule.hpp b/src/Control/ControlModule.hpp new file mode 100644 index 000000000..01c76ccb9 --- /dev/null +++ b/src/Control/ControlModule.hpp @@ -0,0 +1,118 @@ +#ifndef CONTROLMODULE_H_ +#define CONTROLMODULE_H_ + +#include "Base/Macros.hpp" +#include "Chemistry/ChemistryModule.hpp" +#include "poet.hpp" + +#include +#include +#include + +namespace poet { + +class ChemistryModule; + +class ControlModule { + +public: + /* Control configuration*/ + + // std::uint32_t global_iter = 0; + // std::uint32_t sur_disabled_counter = 0; + // std::uint32_t rollback_counter = 0; + + void updateControlIteration(const uint32_t &iter, const bool &dht_enabled, + const bool &interp_enaled); + + void initiateWarmupPhase(bool dht_enabled, bool interp_enabled); + + bool checkAndRollback(ChemistryModule &chem, uint32_t &iter); + + struct SpeciesErrorMetrics { + std::vector mape; + std::vector rrmse; + uint32_t iteration; // iterations in simulation after rollbacks + uint32_t rollback_count; + + SpeciesErrorMetrics(uint32_t species_count, uint32_t iter, uint32_t counter) + : mape(species_count, 0.0), rrmse(species_count, 0.0), iteration(iter), + rollback_count(counter) {} + }; + + void computeSpeciesErrorMetrics(const std::vector &reference_values, + const std::vector &surrogate_values, + const uint32_t size_per_prop); + + std::vector metricsHistory; + + struct ControlSetup { + std::string out_dir; + std::uint32_t checkpoint_interval; + std::uint32_t control_interval; + std::vector species_names; + std::vector mape_threshold; + }; + + void enableControlLogic(const ControlSetup &setup) { + this->out_dir = setup.out_dir; + this->checkpoint_interval = setup.checkpoint_interval; + this->control_interval = setup.control_interval; + this->species_names = setup.species_names; + this->mape_threshold = setup.mape_threshold; + } + + bool getControlIntervalEnabled() const { + return this->control_interval_enabled; + } + + void applyControlLogic(ChemistryModule &chem, uint32_t &iter); + + void writeCheckpointAndMetrics(ChemistryModule &chem, uint32_t iter); + + auto getGlobalIteration() const noexcept { return global_iteration; } + + void setChemistryModule(poet::ChemistryModule *c) { chem = c; } + + auto getControlInterval() const { return this->control_interval; } + + std::vector getMapeThreshold() const { return this->mape_threshold; } + + /* Profiling getters */ + + auto getUpdateCtrlLogicTime() const { return this->prep_t; } + + auto getWriteCheckpointTime() const { return this->w_check_t; } + + auto getReadCheckpointTime() const { return this->r_check_t; } + + auto getWriteMetricsTime() const { return this->stats_t; } + +private: + bool rollback_enabled = false; + bool control_interval_enabled = false; + + poet::ChemistryModule *chem = nullptr; + + std::uint32_t checkpoint_interval = 0; + std::uint32_t control_interval = 0; + std::uint32_t global_iteration = 0; + std::uint32_t rollback_count = 0; + std::uint32_t sur_disabled_counter = 0; + std::vector mape_threshold; + + std::vector species_names; + std::string out_dir; + + double prep_t = 0.; + double r_check_t = 0.; + double w_check_t = 0.; + double stats_t = 0.; + + /* Buffer for shuffled surrogate data */ + std::vector sur_shuffled; +}; + +} // namespace poet + +#endif // CONTROLMODULE_H_ \ No newline at end of file diff --git a/src/IO/StatsIO.cpp b/src/IO/StatsIO.cpp index 4312a46dd..1b7d58d0c 100644 --- a/src/IO/StatsIO.cpp +++ b/src/IO/StatsIO.cpp @@ -2,14 +2,19 @@ #include #include #include +#include +#include namespace poet { - void writeStatsToCSV(const std::vector &all_stats, + void writeStatsToCSV(const std::vector &all_stats, const std::vector &species_names, + const std::string &out_dir, const std::string &filename) { - std::ofstream out(filename); + std::filesystem::path full_path = std::filesystem::path(out_dir) / filename; + + std::ofstream out(full_path); if (!out.is_open()) { std::cerr << "Could not open " << filename << " !" << std::endl; @@ -17,21 +22,32 @@ namespace poet } // header - out << "Iteration, Species, MAPE, RRSME \n"; + out << std::left << std::setw(15) << "Iteration" + << std::setw(15) << "Rollback" + << std::setw(15) << "Species" + << std::setw(15) << "MAPE" + << std::setw(15) << "RRSME" << "\n"; + out << std::string(75, '-') << "\n"; + + // data rows for (size_t i = 0; i < all_stats.size(); ++i) { for (size_t j = 0; j < species_names.size(); ++j) { - out << all_stats[i].iteration << ",\t" - << species_names[j] << ",\t" - << all_stats[i].mape[j] << ",\t" - << all_stats[i].rrsme[j] << "\n"; + out << std::left + << std::setw(15) << all_stats[i].iteration + << std::setw(15) << all_stats[i].rollback_count + << std::setw(15) << species_names[j] + << std::setw(15) << all_stats[i].mape[j] + << std::setw(15) << all_stats[i].rrmse[j] + << "\n"; } - out << std::endl; + out << "\n"; } out.close(); - std::cout << "Stats written to " << filename << "\n"; + std::cout << "Error metrics written to " << out_dir << "/" << filename << "\n"; } -} // namespace poet \ No newline at end of file +} + // namespace poet \ No newline at end of file diff --git a/src/IO/StatsIO.hpp b/src/IO/StatsIO.hpp index a865cc64a..5333c4fd8 100644 --- a/src/IO/StatsIO.hpp +++ b/src/IO/StatsIO.hpp @@ -1,9 +1,10 @@ #include -#include "Chemistry/ChemistryModule.hpp" +#include "Control/ControlModule.hpp" namespace poet { - void writeStatsToCSV(const std::vector &all_stats, + void writeStatsToCSV(const std::vector &all_stats, const std::vector &species_names, + const std::string &out_dir, const std::string &filename); } // namespace poet diff --git a/src/poet.cpp b/src/poet.cpp index 0f558b5d7..5f7a92e82 100644 --- a/src/poet.cpp +++ b/src/poet.cpp @@ -25,10 +25,8 @@ #include "Base/RInsidePOET.hpp" #include "CLI/CLI.hpp" #include "Chemistry/ChemistryModule.hpp" +#include "Control/ControlModule.hpp" #include "DataStructures/Field.hpp" -#include "IO/Datatypes.hpp" -#include "IO/HDF5Functions.hpp" -#include "IO/StatsIO.hpp" #include "Init/InitialList.hpp" #include "Transport/DiffusionModule.hpp" @@ -68,8 +66,7 @@ static poet::DEFunc ReadRObj_R; static poet::DEFunc SaveRObj_R; static poet::DEFunc source_R; -static void init_global_functions(RInside &R) -{ +static void init_global_functions(RInside &R) { R.parseEval(kin_r_library); master_init_R = DEFunc("master_init"); master_iteration_end_R = DEFunc("master_iteration_end"); @@ -92,15 +89,9 @@ static void init_global_functions(RInside &R) // R.parseEval("mysetup$state_C <- TMP"); // } -enum ParseRet -{ - PARSER_OK, - PARSER_ERROR, - PARSER_HELP -}; +enum ParseRet { PARSER_OK, PARSER_ERROR, PARSER_HELP }; -int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) -{ +int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) { CLI::App app{"POET - Potsdam rEactive Transport simulator"}; @@ -182,12 +173,9 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) "Output directory of the simulation") ->required(); - try - { + try { app.parse(argc, argv); - } - catch (const CLI::ParseError &e) - { + } catch (const CLI::ParseError &e) { app.exit(e); return -1; } @@ -199,16 +187,14 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) if (params.as_qs) params.out_ext = "qs"; - if (MY_RANK == 0) - { + if (MY_RANK == 0) { // MSG("Complete results storage is " + BOOL_PRINT(simparams.store_result)); MSG("Output format/extension is " + params.out_ext); MSG("Work Package Size: " + std::to_string(params.work_package_size)); MSG("DHT is " + BOOL_PRINT(params.use_dht)); MSG("AI Surrogate is " + BOOL_PRINT(params.use_ai_surrogate)); - if (params.use_dht) - { + if (params.use_dht) { // MSG("DHT strategy is " + std::to_string(simparams.dht_strategy)); // MDL: these should be outdated (?) // MSG("DHT key default digits (ignored if 'signif_vector' is " @@ -222,8 +208,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) // MSG("DHT load file is " + chem_params.dht_file); } - if (params.use_interp) - { + if (params.use_interp) { MSG("PHT interpolation enabled: " + BOOL_PRINT(params.use_interp)); MSG("PHT interp-size = " + std::to_string(params.interp_size)); MSG("PHT interp-min = " + std::to_string(params.interp_min_entries)); @@ -251,8 +236,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) // // log before rounding? // R["dht_log"] = simparams.dht_log; - try - { + try { Rcpp::List init_params_(ReadRObj_R(init_file)); params.init_params = init_params_; @@ -266,536 +250,450 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) params.timesteps = Rcpp::as>(global_rt_setup->operator[]("timesteps")); - params.control_iteration = - Rcpp::as(global_rt_setup->operator[]("control_iteration")); - params.species_epsilon = - Rcpp::as>(global_rt_setup->operator[]("species_epsilon")); - params.penalty_iteration = - Rcpp::as(global_rt_setup->operator[]("penalty_iteration")); - params.max_penalty_iteration = - Rcpp::as(global_rt_setup->operator[]("max_penalty_iteration")); - } - catch (const std::exception &e) - { - ERRMSG("Error while parsing R scripts: " + std::string(e.what())); - return ParseRet::PARSER_ERROR; - } + params.control_interval = + Rcpp::as(global_rt_setup->operator[]("control_interval")); + params.mape_threshold = Rcpp::as>( + global_rt_setup->operator[]("mape_threshold")); - return ParseRet::PARSER_OK; -} - -// HACK: this is a step back as the order and also the count of fields is -// predefined, but it will change in the future -void call_master_iter_end(RInside &R, const Field &trans, const Field &chem) -{ - R["TMP"] = Rcpp::wrap(trans.AsVector()); - R["TMP_PROPS"] = Rcpp::wrap(trans.GetProps()); - R.parseEval(std::string("state_T <- setNames(data.frame(matrix(TMP, nrow=" + - std::to_string(trans.GetRequestedVecSize()) + - ")), TMP_PROPS)")); - - R["TMP"] = Rcpp::wrap(chem.AsVector()); - R["TMP_PROPS"] = Rcpp::wrap(chem.GetProps()); - R.parseEval(std::string("state_C <- setNames(data.frame(matrix(TMP, nrow=" + - std::to_string(chem.GetRequestedVecSize()) + - ")), TMP_PROPS)")); - R["setup"] = *global_rt_setup; - R.parseEval("setup <- master_iteration_end(setup, state_T, state_C)"); - *global_rt_setup = R["setup"]; -} - -bool checkAndRollback(ChemistryModule &chem, RuntimeParameters ¶ms, uint32_t &iter) -{ - const std::vector &latest_mape = chem.error_stats_history.back().mape; - - for (uint32_t j = 0; j < params.species_epsilon.size(); j++) - { - if (params.species_epsilon[j] < latest_mape[j] && latest_mape[j] != 0) - { - uint32_t rollback_iter = iter - (iter % params.control_iteration); - - std::cout << chem.getField().GetProps()[j] << " with a MAPE value of " << latest_mape[j] << " exceeds epsilon of " - << params.species_epsilon[j] << "! " << std::endl; - - Checkpoint_s checkpoint_read{.field = chem.getField()}; - read_checkpoint("checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read); - iter = checkpoint_read.iteration; - - return true; - } - } - MSG("All spezies are below their threshold values"); - return false; -} - -void updatePenaltyLogic(RuntimeParameters ¶ms, bool roolback_happend) -{ - if (roolback_happend) - { - params.rollback_simulation = true; - params.penalty_counter = params.penalty_iteration; - std::cout << "Penalty counter reset to: " << params.penalty_counter << std::endl; - MSG("Rollback! Penalty phase started for " + std::to_string(params.penalty_iteration) + " iterations."); - } - else - { - if (params.rollback_simulation && params.penalty_counter == 0) - { - params.rollback_simulation = false; - MSG("Penalty phase ended. Interpolation re-enabled."); - } - else if (!params.rollback_simulation) - { - params.penalty_iteration = std::min(params.penalty_iteration *= 2, params.max_penalty_iteration); - MSG("Stable surrogate phase detected. Penalty iteration doubled to " + std::to_string(params.penalty_iteration) + " iterations."); - } - } -} - -static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, - DiffusionModule &diffusion, - ChemistryModule &chem) -{ - - /* Iteration Count is dynamic, retrieving value from R (is only needed by - * master for the following loop) */ - uint32_t maxiter = params.timesteps.size(); - - if (params.print_progress) - { - chem.setProgressBarPrintout(true); - } - R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps()); - - params.next_penalty_check = params.penalty_iteration; - - /* SIMULATION LOOP */ - - double dSimTime{0}; - for (uint32_t iter = 1; iter < maxiter + 1; iter++) - { - // Penalty countdown - if (params.rollback_simulation && params.penalty_counter > 0) - { - params.penalty_counter--; - std::cout << "Penalty counter: " << params.penalty_counter << std::endl; + catch (const std::exception &e) { + ERRMSG("Error while parsing R scripts: " + std::string(e.what())); + return ParseRet::PARSER_ERROR; } - params.control_iteration_active = (iter % params.control_iteration == 0 /* && iter != 0 */); + return ParseRet::PARSER_OK; + } - double start_t = MPI_Wtime(); + // HACK: this is a step back as the order and also the count of fields is + // predefined, but it will change in the future + void call_master_iter_end(RInside & R, const Field &trans, + const Field &chem) { + R["TMP"] = Rcpp::wrap(trans.AsVector()); + R["TMP_PROPS"] = Rcpp::wrap(trans.GetProps()); + R.parseEval(std::string("state_T <- setNames(data.frame(matrix(TMP, nrow=" + + std::to_string(trans.GetRequestedVecSize()) + + ")), TMP_PROPS)")); - const double &dt = params.timesteps[iter - 1]; + R["TMP"] = Rcpp::wrap(chem.AsVector()); + R["TMP_PROPS"] = Rcpp::wrap(chem.GetProps()); + R.parseEval(std::string("state_C <- setNames(data.frame(matrix(TMP, nrow=" + + std::to_string(chem.GetRequestedVecSize()) + + ")), TMP_PROPS)")); + R["setup"] = *global_rt_setup; + R.parseEval("setup <- master_iteration_end(setup, state_T, state_C)"); + *global_rt_setup = R["setup"]; + } - std::cout << std::endl; + static Rcpp::List RunMasterLoop( + RInsidePOET & R, RuntimeParameters & params, DiffusionModule & diffusion, + ChemistryModule & chem, ControlModule & control) { - /* displaying iteration number, with C++ and R iterator */ - MSG("Going through iteration " + std::to_string(iter) + "/" + - std::to_string(maxiter)); + /* Iteration Count is dynamic, retrieving value from R (is only needed by + * master for the following loop) */ + uint32_t maxiter = params.timesteps.size(); - MSG("Current time step is " + std::to_string(dt)); + if (params.print_progress) { + chem.setProgressBarPrintout(true); + } + R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps()); - /* run transport */ - diffusion.simulate(dt); + /* SIMULATION LOOP */ - chem.runtime_params = ¶ms; + double dSimTime{0}; + for (uint32_t iter = 1; iter < maxiter + 1; iter++) { + control.updateControlIteration(iter, params.use_dht, params.use_interp); - chem.getField().update(diffusion.getField()); + double start_t = MPI_Wtime(); - // MSG("Chemistry start"); - if (params.use_ai_surrogate) - { - double ai_start_t = MPI_Wtime(); - // Save current values from the tug field as predictor for the ai step - R["TMP"] = Rcpp::wrap(chem.getField().AsVector()); - R.parseEval( - std::string("predictors <- setNames(data.frame(matrix(TMP, nrow=" + - std::to_string(chem.getField().GetRequestedVecSize()) + - ")), TMP_PROPS)")); - R.parseEval("predictors <- predictors[ai_surrogate_species]"); + const double &dt = params.timesteps[iter - 1]; - // Apply preprocessing - MSG("AI Preprocessing"); - R.parseEval("predictors_scaled <- preprocess(predictors)"); + std::cout << std::endl; - // Predict - MSG("AI Prediction"); - R.parseEval( - "aipreds_scaled <- prediction_step(model, predictors_scaled)"); + /* displaying iteration number, with C++ and R iterator */ + MSG("Going through iteration " + std::to_string(iter) + "/" + + std::to_string(maxiter)); - // Apply postprocessing - MSG("AI Postprocessing"); - R.parseEval("aipreds <- postprocess(aipreds_scaled)"); + MSG("Current time step is " + std::to_string(dt)); - // Validate prediction and write valid predictions to chem field - MSG("AI Validation"); - R.parseEval( - "validity_vector <- validate_predictions(predictors, aipreds)"); + /* run transport */ + diffusion.simulate(dt); - MSG("AI Marking accepted"); - chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector")); + chem.getField().update(diffusion.getField()); - MSG("AI TempField"); - std::vector> RTempField = - R.parseEval("set_valid_predictions(predictors,\ + // MSG("Chemistry start"); + if (params.use_ai_surrogate) { + double ai_start_t = MPI_Wtime(); + // Save current values from the tug field as predictor for the ai step + R["TMP"] = Rcpp::wrap(chem.getField().AsVector()); + R.parseEval( + std::string("predictors <- setNames(data.frame(matrix(TMP, nrow=" + + std::to_string(chem.getField().GetRequestedVecSize()) + + ")), TMP_PROPS)")); + R.parseEval("predictors <- predictors[ai_surrogate_species]"); + + // Apply preprocessing + MSG("AI Preprocessing"); + R.parseEval("predictors_scaled <- preprocess(predictors)"); + + // Predict + MSG("AI Prediction"); + R.parseEval( + "aipreds_scaled <- prediction_step(model, predictors_scaled)"); + + // Apply postprocessing + MSG("AI Postprocessing"); + R.parseEval("aipreds <- postprocess(aipreds_scaled)"); + + // Validate prediction and write valid predictions to chem field + MSG("AI Validation"); + R.parseEval( + "validity_vector <- validate_predictions(predictors, aipreds)"); + + MSG("AI Marking accepted"); + chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector")); + + MSG("AI TempField"); + std::vector> RTempField = + R.parseEval("set_valid_predictions(predictors,\ aipreds,\ validity_vector)"); - MSG("AI Set Field"); - Field predictions_field = - Field(R.parseEval("nrow(predictors)"), RTempField, - R.parseEval("colnames(predictors)")); + MSG("AI Set Field"); + Field predictions_field = + Field(R.parseEval("nrow(predictors)"), RTempField, + R.parseEval("colnames(predictors)")); - MSG("AI Update"); - chem.getField().update(predictions_field); - double ai_end_t = MPI_Wtime(); - R["ai_prediction_time"] = ai_end_t - ai_start_t; + MSG("AI Update"); + chem.getField().update(predictions_field); + double ai_end_t = MPI_Wtime(); + R["ai_prediction_time"] = ai_end_t - ai_start_t; + } + + chem.simulate(dt); + + /* AI surrogate iterative training*/ + if (params.use_ai_surrogate) { + double ai_start_t = MPI_Wtime(); + + R["TMP"] = Rcpp::wrap(chem.getField().AsVector()); + R.parseEval( + std::string("targets <- setNames(data.frame(matrix(TMP, nrow=" + + std::to_string(chem.getField().GetRequestedVecSize()) + + ")), TMP_PROPS)")); + R.parseEval("targets <- targets[ai_surrogate_species]"); + + // TODO: Check how to get the correct columns + R.parseEval("target_scaled <- preprocess(targets)"); + + MSG("AI: incremental training"); + R.parseEval("model <- training_step(model, predictors_scaled, " + "target_scaled, validity_vector)"); + double ai_end_t = MPI_Wtime(); + R["ai_training_time"] = ai_end_t - ai_start_t; + } + + // MPI_Barrier(MPI_COMM_WORLD); + double end_t = MPI_Wtime(); + dSimTime += end_t - start_t; + R["totaltime"] = dSimTime; + + // MDL master_iteration_end just writes on disk state_T and + // state_C after every iteration if the cmdline option + // --ignore-results is not given (and thus the R variable + // store_result is TRUE) + call_master_iter_end(R, diffusion.getField(), chem.getField()); + + // TODO: write checkpoint + // checkpoint struct --> field and iteration + + diffusion.getField().update(chem.getField()); + + MSG("End of *coupling* iteration " + std::to_string(iter) + "/" + + std::to_string(maxiter)); + + control.applyControlLogic(chem, iter); + // MSG(); + } // END SIMULATION LOOP + + std::cout << std::endl; + + Rcpp::List chem_profiling; + chem_profiling["simtime"] = chem.GetChemistryTime(); + chem_profiling["loop"] = chem.GetMasterLoopTime(); + chem_profiling["sequential"] = chem.GetMasterSequentialTime(); + chem_profiling["idle_master"] = chem.GetMasterIdleTime(); + chem_profiling["idle_worker"] = Rcpp::wrap(chem.GetWorkerIdleTimings()); + chem_profiling["phreeqc_time"] = Rcpp::wrap(chem.GetWorkerPhreeqcTimings()); + + Rcpp::List diffusion_profiling; + diffusion_profiling["simtime"] = diffusion.getTransportTime(); + + if (params.use_dht) { + chem_profiling["dht_hits"] = Rcpp::wrap(chem.GetWorkerDHTHits()); + chem_profiling["dht_evictions"] = + Rcpp::wrap(chem.GetWorkerDHTEvictions()); + chem_profiling["dht_get_time"] = + Rcpp::wrap(chem.GetWorkerDHTGetTimings()); + chem_profiling["dht_fill_time"] = + Rcpp::wrap(chem.GetWorkerDHTFillTimings()); } - chem.simulate(dt); - - /* AI surrogate iterative training*/ - if (params.use_ai_surrogate) - { - double ai_start_t = MPI_Wtime(); - - R["TMP"] = Rcpp::wrap(chem.getField().AsVector()); - R.parseEval( - std::string("targets <- setNames(data.frame(matrix(TMP, nrow=" + - std::to_string(chem.getField().GetRequestedVecSize()) + - ")), TMP_PROPS)")); - R.parseEval("targets <- targets[ai_surrogate_species]"); - - // TODO: Check how to get the correct columns - R.parseEval("target_scaled <- preprocess(targets)"); - - MSG("AI: incremental training"); - R.parseEval("model <- training_step(model, predictors_scaled, " - "target_scaled, validity_vector)"); - double ai_end_t = MPI_Wtime(); - R["ai_training_time"] = ai_end_t - ai_start_t; + if (params.use_interp) { + chem_profiling["interp_w"] = + Rcpp::wrap(chem.GetWorkerInterpolationWriteTimings()); + chem_profiling["interp_r"] = + Rcpp::wrap(chem.GetWorkerInterpolationReadTimings()); + chem_profiling["interp_g"] = + Rcpp::wrap(chem.GetWorkerInterpolationGatherTimings()); + chem_profiling["interp_fc"] = + Rcpp::wrap(chem.GetWorkerInterpolationFunctionCallTimings()); + chem_profiling["interp_calls"] = + Rcpp::wrap(chem.GetWorkerInterpolationCalls()); + chem_profiling["interp_cached"] = + Rcpp::wrap(chem.GetWorkerPHTCacheHits()); } - // MPI_Barrier(MPI_COMM_WORLD); - double end_t = MPI_Wtime(); - dSimTime += end_t - start_t; - R["totaltime"] = dSimTime; + Rcpp::List profiling; + profiling["simtime"] = dSimTime; + profiling["chemistry"] = chem_profiling; + profiling["diffusion"] = diffusion_profiling; - // MDL master_iteration_end just writes on disk state_T and - // state_C after every iteration if the cmdline option - // --ignore-results is not given (and thus the R variable - // store_result is TRUE) - call_master_iter_end(R, diffusion.getField(), chem.getField()); + chem.MasterLoopBreak(); - // TODO: write checkpoint - // checkpoint struct --> field and iteration - - diffusion.getField().update(chem.getField()); - - MSG("End of *coupling* iteration " + std::to_string(iter) + "/" + - std::to_string(maxiter)); - - if (iter % params.control_iteration == 0) - { - writeStatsToCSV(chem.error_stats_history, chem.getField().GetProps(), "stats_overview"); - write_checkpoint("checkpoint" + std::to_string(iter) + ".hdf5", - {.field = chem.getField(), .iteration = iter}); - } - - if (iter == params.next_penalty_check) - { - bool roolback_happend = checkAndRollback(chem, params, iter); - updatePenaltyLogic(params, roolback_happend); - - params.next_penalty_check = iter + params.penalty_iteration; - } - - // MSG(); - } // END SIMULATION LOOP - - std::cout << std::endl; - - Rcpp::List chem_profiling; - chem_profiling["simtime"] = chem.GetChemistryTime(); - chem_profiling["loop"] = chem.GetMasterLoopTime(); - chem_profiling["sequential"] = chem.GetMasterSequentialTime(); - chem_profiling["idle_master"] = chem.GetMasterIdleTime(); - chem_profiling["idle_worker"] = Rcpp::wrap(chem.GetWorkerIdleTimings()); - chem_profiling["phreeqc_time"] = Rcpp::wrap(chem.GetWorkerPhreeqcTimings()); - - Rcpp::List diffusion_profiling; - diffusion_profiling["simtime"] = diffusion.getTransportTime(); - - if (params.use_dht) - { - chem_profiling["dht_hits"] = Rcpp::wrap(chem.GetWorkerDHTHits()); - chem_profiling["dht_evictions"] = Rcpp::wrap(chem.GetWorkerDHTEvictions()); - chem_profiling["dht_get_time"] = Rcpp::wrap(chem.GetWorkerDHTGetTimings()); - chem_profiling["dht_fill_time"] = - Rcpp::wrap(chem.GetWorkerDHTFillTimings()); + return profiling; } - if (params.use_interp) - { - chem_profiling["interp_w"] = - Rcpp::wrap(chem.GetWorkerInterpolationWriteTimings()); - chem_profiling["interp_r"] = - Rcpp::wrap(chem.GetWorkerInterpolationReadTimings()); - chem_profiling["interp_g"] = - Rcpp::wrap(chem.GetWorkerInterpolationGatherTimings()); - chem_profiling["interp_fc"] = - Rcpp::wrap(chem.GetWorkerInterpolationFunctionCallTimings()); - chem_profiling["interp_calls"] = - Rcpp::wrap(chem.GetWorkerInterpolationCalls()); - chem_profiling["interp_cached"] = Rcpp::wrap(chem.GetWorkerPHTCacheHits()); - } + std::vector getSpeciesNames(const Field &&field, int root, + MPI_Comm comm) { + std::uint32_t n_elements; + std::uint32_t n_string_size; - Rcpp::List profiling; - profiling["simtime"] = dSimTime; - profiling["chemistry"] = chem_profiling; - profiling["diffusion"] = diffusion_profiling; + int rank; + MPI_Comm_rank(comm, &rank); - chem.MasterLoopBreak(); + const bool is_master = root == rank; - return profiling; -} + // first, the master sends all the species names iterative + if (is_master) { + n_elements = field.GetProps().size(); + MPI_Bcast(&n_elements, 1, MPI_UINT32_T, root, MPI_COMM_WORLD); -std::vector getSpeciesNames(const Field &&field, int root, - MPI_Comm comm) -{ - std::uint32_t n_elements; - std::uint32_t n_string_size; + for (std::uint32_t i = 0; i < n_elements; i++) { + n_string_size = field.GetProps()[i].size(); + MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD); + MPI_Bcast(const_cast(field.GetProps()[i].c_str()), + n_string_size, MPI_CHAR, root, MPI_COMM_WORLD); + } - int rank; - MPI_Comm_rank(comm, &rank); + return field.GetProps(); + } - const bool is_master = root == rank; + // now all the worker stuff + MPI_Bcast(&n_elements, 1, MPI_UINT32_T, root, comm); - // first, the master sends all the species names iterative - if (is_master) - { - n_elements = field.GetProps().size(); - MPI_Bcast(&n_elements, 1, MPI_UINT32_T, root, MPI_COMM_WORLD); + std::vector species_names_out(n_elements); - for (std::uint32_t i = 0; i < n_elements; i++) - { - n_string_size = field.GetProps()[i].size(); + for (std::uint32_t i = 0; i < n_elements; i++) { MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD); - MPI_Bcast(const_cast(field.GetProps()[i].c_str()), n_string_size, - MPI_CHAR, root, MPI_COMM_WORLD); + + char recv_buf[n_string_size]; + + MPI_Bcast(recv_buf, n_string_size, MPI_CHAR, root, MPI_COMM_WORLD); + + species_names_out[i] = std::string(recv_buf, n_string_size); } - return field.GetProps(); + return species_names_out; } - // now all the worker stuff - MPI_Bcast(&n_elements, 1, MPI_UINT32_T, root, comm); + std::array getBaseTotals(Field && field, int root, MPI_Comm comm) { + std::array base_totals; - std::vector species_names_out(n_elements); + int rank; + MPI_Comm_rank(comm, &rank); - for (std::uint32_t i = 0; i < n_elements; i++) - { - MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD); + const bool is_master = root == rank; - char recv_buf[n_string_size]; + if (is_master) { + const auto h_col = field["H"]; + const auto o_col = field["O"]; - MPI_Bcast(recv_buf, n_string_size, MPI_CHAR, root, MPI_COMM_WORLD); + base_totals[0] = *std::min_element(h_col.begin(), h_col.end()); + base_totals[1] = *std::min_element(o_col.begin(), o_col.end()); + MPI_Bcast(base_totals.data(), 2, MPI_DOUBLE, root, MPI_COMM_WORLD); + return base_totals; + } - species_names_out[i] = std::string(recv_buf, n_string_size); - } + MPI_Bcast(base_totals.data(), 2, MPI_DOUBLE, root, comm); - return species_names_out; -} - -std::array getBaseTotals(Field &&field, int root, MPI_Comm comm) -{ - std::array base_totals; - - int rank; - MPI_Comm_rank(comm, &rank); - - const bool is_master = root == rank; - - if (is_master) - { - const auto h_col = field["H"]; - const auto o_col = field["O"]; - - base_totals[0] = *std::min_element(h_col.begin(), h_col.end()); - base_totals[1] = *std::min_element(o_col.begin(), o_col.end()); - MPI_Bcast(base_totals.data(), 2, MPI_DOUBLE, root, MPI_COMM_WORLD); return base_totals; } - MPI_Bcast(base_totals.data(), 2, MPI_DOUBLE, root, comm); + bool getHasID(Field && field, int root, MPI_Comm comm) { + bool has_id; - return base_totals; -} + int rank; + MPI_Comm_rank(comm, &rank); -bool getHasID(Field &&field, int root, MPI_Comm comm) -{ - bool has_id; + const bool is_master = root == rank; - int rank; - MPI_Comm_rank(comm, &rank); + if (is_master) { + const auto ID_field = field["ID"]; - const bool is_master = root == rank; + std::set unique_IDs(ID_field.begin(), ID_field.end()); - if (is_master) - { - const auto ID_field = field["ID"]; + has_id = unique_IDs.size() > 1; - std::set unique_IDs(ID_field.begin(), ID_field.end()); + MPI_Bcast(&has_id, 1, MPI_C_BOOL, root, MPI_COMM_WORLD); - has_id = unique_IDs.size() > 1; + return has_id; + } - MPI_Bcast(&has_id, 1, MPI_C_BOOL, root, MPI_COMM_WORLD); + MPI_Bcast(&has_id, 1, MPI_C_BOOL, root, comm); return has_id; } - MPI_Bcast(&has_id, 1, MPI_C_BOOL, root, comm); + int main(int argc, char *argv[]) { + int world_size; - return has_id; -} + MPI_Init(&argc, &argv); -int main(int argc, char *argv[]) -{ - int world_size; - - MPI_Init(&argc, &argv); - - { - MPI_Comm_size(MPI_COMM_WORLD, &world_size); - MPI_Comm_rank(MPI_COMM_WORLD, &MY_RANK); - - RInsidePOET &R = RInsidePOET::getInstance(); - - if (MY_RANK == 0) { - MSG("Running POET version " + std::string(poet_version)); - } + MPI_Comm_size(MPI_COMM_WORLD, &world_size); + MPI_Comm_rank(MPI_COMM_WORLD, &MY_RANK); - init_global_functions(R); + RInsidePOET &R = RInsidePOET::getInstance(); - RuntimeParameters run_params; - - if (parseInitValues(argc, argv, run_params) != 0) - { - MPI_Finalize(); - return 0; - } - - // switch (parseInitValues(argc, argv, run_params)) { - // case ParseRet::PARSER_ERROR: - // case ParseRet::PARSER_HELP: - // MPI_Finalize(); - // return 0; - // case ParseRet::PARSER_OK: - // break; - // } - - InitialList init_list(R); - init_list.importList(run_params.init_params, MY_RANK != 0); - - MSG("RInside initialized on process " + std::to_string(MY_RANK)); - - std::cout << std::flush; - - MPI_Barrier(MPI_COMM_WORLD); - - ChemistryModule chemistry(run_params.work_package_size, - init_list.getChemistryInit(), MPI_COMM_WORLD); - - const ChemistryModule::SurrogateSetup surr_setup = { - getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), - getBaseTotals(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), - getHasID(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), - run_params.use_dht, - run_params.dht_size, - run_params.dht_snaps, - run_params.out_dir, - run_params.use_interp, - run_params.interp_bucket_entries, - run_params.interp_size, - run_params.interp_min_entries, - run_params.use_ai_surrogate}; - - chemistry.masterEnableSurrogates(surr_setup); - - if (MY_RANK > 0) - { - chemistry.WorkerLoop(); - } - else - { - // R.parseEvalQ("mysetup <- setup"); - // // if (MY_RANK == 0) { // get timestep vector from - // // grid_init function ... // - - *global_rt_setup = master_init_R(*global_rt_setup, run_params.out_dir, - init_list.getInitialGrid().asSEXP()); - - // MDL: store all parameters - // MSG("Calling R Function to store calling parameters"); - // R.parseEvalQ("StoreSetup(setup=mysetup)"); - R["out_ext"] = run_params.out_ext; - R["out_dir"] = run_params.out_dir; - - if (run_params.use_ai_surrogate) - { - /* Incorporate ai surrogate from R */ - R.parseEvalQ(ai_surrogate_r_library); - /* Use dht species for model input and output */ - R["ai_surrogate_species"] = - init_list.getChemistryInit().dht_species.getNames(); - - const std::string ai_surrogate_input_script = - init_list.getChemistryInit().ai_surrogate_input_script; - - MSG("AI: sourcing user-provided script"); - R.parseEvalQ(ai_surrogate_input_script); - - MSG("AI: initialize AI model"); - R.parseEval("model <- initiate_model()"); - R.parseEval("gpu_info()"); + if (MY_RANK == 0) { + MSG("Running POET version " + std::string(poet_version)); } - MSG("Init done on process with rank " + std::to_string(MY_RANK)); + init_global_functions(R); - // MPI_Barrier(MPI_COMM_WORLD); + RuntimeParameters run_params; - DiffusionModule diffusion(init_list.getDiffusionInit(), - init_list.getInitialGrid()); + if (parseInitValues(argc, argv, run_params) != 0) { + MPI_Finalize(); + return 0; + } - chemistry.masterSetField(init_list.getInitialGrid()); + // switch (parseInitValues(argc, argv, run_params)) { + // case ParseRet::PARSER_ERROR: + // case ParseRet::PARSER_HELP: + // MPI_Finalize(); + // return 0; + // case ParseRet::PARSER_OK: + // break; + // } - Rcpp::List profiling = RunMasterLoop(R, run_params, diffusion, chemistry); + InitialList init_list(R); + init_list.importList(run_params.init_params, MY_RANK != 0); - MSG("finished simulation loop"); + MSG("RInside initialized on process " + std::to_string(MY_RANK)); - R["profiling"] = profiling; - R["setup"] = *global_rt_setup; - R["setup$out_ext"] = run_params.out_ext; + std::cout << std::flush; - std::string r_vis_code; - r_vis_code = "SaveRObj(x = profiling, path = paste0(out_dir, " - "'/timings.', setup$out_ext));"; - R.parseEval(r_vis_code); + MPI_Barrier(MPI_COMM_WORLD); - MSG("Done! Results are stored as R objects into <" + run_params.out_dir + - "/timings." + run_params.out_ext); + ChemistryModule chemistry(run_params.work_package_size, + init_list.getChemistryInit(), MPI_COMM_WORLD); + ControlModule control; + chemistry.SetControlModule(&control); + control.setChemistryModule(&chemistry); + + const ChemistryModule::SurrogateSetup surr_setup = { + getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), + getBaseTotals(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), + getHasID(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), + run_params.use_dht, + run_params.dht_size, + run_params.dht_snaps, + run_params.out_dir, + run_params.use_interp, + run_params.interp_bucket_entries, + run_params.interp_size, + run_params.interp_min_entries, + run_params.use_ai_surrogate}; + + chemistry.masterEnableSurrogates(surr_setup); + + const ControlModule::ControlSetup ctrl_setup = { + run_params.out_dir, // added + run_params.checkpoint_interval, run_params.control_interval, + getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), + run_params.mape_threshold}; + + control.enableControlLogic(ctrl_setup); + + if (MY_RANK > 0) { + chemistry.WorkerLoop(); + } else { + // R.parseEvalQ("mysetup <- setup"); + // // if (MY_RANK == 0) { // get timestep vector from + // // grid_init function ... // + + *global_rt_setup = master_init_R(*global_rt_setup, run_params.out_dir, + init_list.getInitialGrid().asSEXP()); + + // MDL: store all parameters + // MSG("Calling R Function to store calling parameters"); + // R.parseEvalQ("StoreSetup(setup=mysetup)"); + R["out_ext"] = run_params.out_ext; + R["out_dir"] = run_params.out_dir; + + if (run_params.use_ai_surrogate) { + /* Incorporate ai surrogate from R */ + R.parseEvalQ(ai_surrogate_r_library); + /* Use dht species for model input and output */ + R["ai_surrogate_species"] = + init_list.getChemistryInit().dht_species.getNames(); + + const std::string ai_surrogate_input_script = + init_list.getChemistryInit().ai_surrogate_input_script; + + MSG("AI: sourcing user-provided script"); + R.parseEvalQ(ai_surrogate_input_script); + + MSG("AI: initialize AI model"); + R.parseEval("model <- initiate_model()"); + R.parseEval("gpu_info()"); + } + + MSG("Init done on process with rank " + std::to_string(MY_RANK)); + + // MPI_Barrier(MPI_COMM_WORLD); + + DiffusionModule diffusion(init_list.getDiffusionInit(), + init_list.getInitialGrid()); + + chemistry.masterSetField(init_list.getInitialGrid()); + + Rcpp::List profiling = + RunMasterLoop(R, run_params, diffusion, chemistry, control); + + MSG("finished simulation loop"); + + R["profiling"] = profiling; + R["setup"] = *global_rt_setup; + R["setup$out_ext"] = run_params.out_ext; + + std::string r_vis_code; + r_vis_code = "SaveRObj(x = profiling, path = paste0(out_dir, " + "'/timings.', setup$out_ext));"; + R.parseEval(r_vis_code); + + MSG("Done! Results are stored as R objects into <" + + run_params.out_dir + "/timings." + run_params.out_ext); + } } + + MSG("finished, cleanup of process " + std::to_string(MY_RANK)); + + MPI_Finalize(); + + if (MY_RANK == 0) { + MSG("done, bye!"); + } + + exit(0); } - - MSG("finished, cleanup of process " + std::to_string(MY_RANK)); - - MPI_Finalize(); - - if (MY_RANK == 0) - { - MSG("done, bye!"); - } - - exit(0); -}