diff --git a/src/Chemistry/ChemistryModule.cpp b/src/Chemistry/ChemistryModule.cpp index dc2d430f1..a6ab03c36 100644 --- a/src/Chemistry/ChemistryModule.cpp +++ b/src/Chemistry/ChemistryModule.cpp @@ -309,9 +309,10 @@ void poet::ChemistryModule::initializeInterp( map_copy = this->dht->getKeySpecies(); for (auto i = 0; i < map_copy.size(); i++) { const std::uint32_t signif = - static_cast(map_copy[i]) - (map_copy[i] > InterpolationModule::COARSE_DIFF - ? InterpolationModule::COARSE_DIFF - : 0); + static_cast(map_copy[i]) - + (map_copy[i] > InterpolationModule::COARSE_DIFF + ? InterpolationModule::COARSE_DIFF + : 0); map_copy[i] = signif; } } @@ -368,7 +369,8 @@ void poet::ChemistryModule::unshuffleField(const std::vector &in_buffer, } } } - -void poet::ChemistryModule::set_ai_surrogate_validity_vector(std::vector r_vector) { + +void poet::ChemistryModule::set_ai_surrogate_validity_vector( + std::vector r_vector) { this->ai_surrogate_validity_vector = r_vector; } diff --git a/src/Chemistry/ChemistryModule.hpp b/src/Chemistry/ChemistryModule.hpp index d2a9fca0c..f0c164754 100644 --- a/src/Chemistry/ChemistryModule.hpp +++ b/src/Chemistry/ChemistryModule.hpp @@ -76,6 +76,7 @@ public: struct SurrogateSetup { std::vector prop_names; + std::array base_totals; bool dht_enabled; std::uint32_t dht_size_mb; @@ -96,6 +97,8 @@ public: this->interp_enabled = setup.interp_enabled; this->ai_surrogate_enabled = setup.ai_surrogate_enabled; + this->base_totals = setup.base_totals; + if (this->dht_enabled || this->interp_enabled) { this->initializeDHT(setup.dht_size_mb, this->params.dht_species); } @@ -223,8 +226,8 @@ public: }; /** - * **Master only** Set the ai surrogate validity vector from R - */ + * **Master only** Set the ai surrogate validity vector from R + */ void set_ai_surrogate_validity_vector(std::vector r_vector); std::vector GetWorkerInterpolationCalls() const; diff --git a/src/poet.cpp b/src/poet.cpp index 3da740096..df51ce31c 100644 --- a/src/poet.cpp +++ b/src/poet.cpp @@ -34,6 +34,8 @@ #include #include #include +#include +#include #include #include #include @@ -150,7 +152,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) { app.add_flag("--qs", params.as_qs, "Save output as .qs file instead of default .qs2"); - + std::string init_file; std::string runtime_file; @@ -178,8 +180,10 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) { // set the output extension params.out_ext = "qs2"; - if (params.as_rds) params.out_ext = "rds"; - if (params.as_qs) params.out_ext = "qs"; + if (params.as_rds) + params.out_ext = "rds"; + if (params.as_qs) + params.out_ext = "qs"; if (MY_RANK == 0) { // MSG("Complete results storage is " + BOOL_PRINT(simparams.store_result)); @@ -293,7 +297,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms, const double &dt = params.timesteps[iter - 1]; std::cout << std::endl; - + /* displaying iteration number, with C++ and R iterator */ MSG("Going through iteration " + std::to_string(iter) + "/" + std::to_string(maxiter)); @@ -396,7 +400,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms, } // END SIMULATION LOOP std::cout << std::endl; - + Rcpp::List chem_profiling; chem_profiling["simtime"] = chem.GetChemistryTime(); chem_profiling["loop"] = chem.GetMasterLoopTime(); @@ -483,6 +487,29 @@ std::vector getSpeciesNames(const Field &&field, int root, return species_names_out; } +std::array getBaseTotals(Field &&field, int root, MPI_Comm comm) { + std::array base_totals; + + int rank; + MPI_Comm_rank(comm, &rank); + + const bool is_master = root == rank; + + if (is_master) { + const auto h_col = field["H"]; + const auto o_col = field["O"]; + + base_totals[0] = *std::min_element(h_col.begin(), h_col.end()); + base_totals[1] = *std::min_element(o_col.begin(), o_col.end()); + MPI_Bcast(base_totals.data(), 2, MPI_DOUBLE, root, MPI_COMM_WORLD); + return base_totals; + } + + MPI_Bcast(base_totals.data(), 2, MPI_DOUBLE, root, comm); + + return base_totals; +} + int main(int argc, char *argv[]) { int world_size; @@ -529,8 +556,8 @@ int main(int argc, char *argv[]) { init_list.getChemistryInit(), MPI_COMM_WORLD); const ChemistryModule::SurrogateSetup surr_setup = { - getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), + getBaseTotals(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), run_params.use_dht, run_params.dht_size, run_params.use_interp,