diff --git a/src/Chemistry/ChemistryModule.cpp b/src/Chemistry/ChemistryModule.cpp index cbfe0745c..32c14be90 100644 --- a/src/Chemistry/ChemistryModule.cpp +++ b/src/Chemistry/ChemistryModule.cpp @@ -185,7 +185,8 @@ poet::ChemistryModule::~ChemistryModule() { } void poet::ChemistryModule::initializeDHT( - uint32_t size_mb, const NamedVector &key_species) { + uint32_t size_mb, const NamedVector &key_species, + bool has_het_ids) { constexpr uint32_t MB_FACTOR = 1E6; MPI_Comm dht_comm; @@ -217,7 +218,7 @@ void poet::ChemistryModule::initializeDHT( this->dht = new DHT_Wrapper(dht_comm, dht_size, map_copy, key_indices, this->prop_names, params.hooks, - this->prop_count, interp_enabled); + this->prop_count, interp_enabled, has_het_ids); this->dht->setBaseTotals(base_totals.at(0), base_totals.at(1)); } } diff --git a/src/Chemistry/ChemistryModule.hpp b/src/Chemistry/ChemistryModule.hpp index f0c164754..697b50744 100644 --- a/src/Chemistry/ChemistryModule.hpp +++ b/src/Chemistry/ChemistryModule.hpp @@ -77,6 +77,7 @@ public: struct SurrogateSetup { std::vector prop_names; std::array base_totals; + bool has_het_ids; bool dht_enabled; std::uint32_t dht_size_mb; @@ -100,7 +101,8 @@ public: this->base_totals = setup.base_totals; if (this->dht_enabled || this->interp_enabled) { - this->initializeDHT(setup.dht_size_mb, this->params.dht_species); + this->initializeDHT(setup.dht_size_mb, this->params.dht_species, + setup.has_het_ids); } if (this->interp_enabled) { @@ -243,7 +245,8 @@ public: protected: void initializeDHT(uint32_t size_mb, - const NamedVector &key_species); + const NamedVector &key_species, + bool has_het_ids); void setDHTSnapshots(int type, const std::string &out_dir); void setDHTReadFile(const std::string &input_file); diff --git a/src/Chemistry/SurrogateModels/DHT_Wrapper.cpp b/src/Chemistry/SurrogateModels/DHT_Wrapper.cpp index 095f83cfa..0d9d2d2f6 100644 --- a/src/Chemistry/SurrogateModels/DHT_Wrapper.cpp +++ b/src/Chemistry/SurrogateModels/DHT_Wrapper.cpp @@ -43,11 +43,12 @@ DHT_Wrapper::DHT_Wrapper(MPI_Comm dht_comm, std::uint64_t dht_size, const std::vector &key_indices, const std::vector &_output_names, const InitialList::ChemistryHookFunctions &_hooks, - uint32_t data_count, bool _with_interp) + uint32_t data_count, bool _with_interp, + bool _has_het_ids) : key_count(key_indices.size()), data_count(data_count), input_key_elements(key_indices), communicator(dht_comm), key_species(key_species), output_names(_output_names), hooks(_hooks), - with_interp(_with_interp) { + with_interp(_with_interp), has_het_ids(_has_het_ids) { // initialize DHT object // key size = count of key elements + timestep uint32_t key_size = (key_count + 1) * sizeof(Lookup_Keyelement); @@ -270,7 +271,7 @@ LookupKey DHT_Wrapper::fuzzForDHT_R(const std::vector &cell, const std::vector eval_vec = Rcpp::as>(hooks.dht_fuzz(input_nv)); assert(eval_vec.size() == this->key_count); - LookupKey vecFuzz(this->key_count + 1, {.0}); + LookupKey vecFuzz(this->key_count + 1 + has_het_ids, {.0}); DHT_Rounder rounder; @@ -290,6 +291,9 @@ LookupKey DHT_Wrapper::fuzzForDHT_R(const std::vector &cell, } // add timestep to the end of the key as double value vecFuzz[this->key_count].fp_element = dt; + if (has_het_ids) { + vecFuzz[this->key_count + 1].fp_element = cell[0]; + } return vecFuzz; } @@ -297,7 +301,7 @@ LookupKey DHT_Wrapper::fuzzForDHT_R(const std::vector &cell, LookupKey DHT_Wrapper::fuzzForDHT(const std::vector &cell, double dt) { const auto c_zero_val = std::pow(10, AQUEOUS_EXP); - LookupKey vecFuzz(this->key_count + 1, {.0}); + LookupKey vecFuzz(this->key_count + 1 + has_het_ids, {.0}); DHT_Rounder rounder; int totals_i = 0; @@ -323,6 +327,9 @@ LookupKey DHT_Wrapper::fuzzForDHT(const std::vector &cell, double dt) { } // add timestep to the end of the key as double value vecFuzz[this->key_count].fp_element = dt; + if (has_het_ids) { + vecFuzz[this->key_count + 1].fp_element = cell[0]; + } return vecFuzz; } diff --git a/src/Chemistry/SurrogateModels/DHT_Wrapper.hpp b/src/Chemistry/SurrogateModels/DHT_Wrapper.hpp index e0b1c022d..9449692b4 100644 --- a/src/Chemistry/SurrogateModels/DHT_Wrapper.hpp +++ b/src/Chemistry/SurrogateModels/DHT_Wrapper.hpp @@ -87,7 +87,7 @@ public: const std::vector &key_indices, const std::vector &output_names, const InitialList::ChemistryHookFunctions &hooks, - uint32_t data_count, bool with_interp); + uint32_t data_count, bool with_interp, bool has_het_ids); /** * @brief Destroy the dht wrapper object * @@ -264,6 +264,7 @@ private: DHT_ResultObject dht_results; std::array base_totals{0}; + bool has_het_ids{false}; }; } // namespace poet diff --git a/src/poet.cpp b/src/poet.cpp index df51ce31c..c7864b4b6 100644 --- a/src/poet.cpp +++ b/src/poet.cpp @@ -40,6 +40,7 @@ #include #include #include +#include #include #include @@ -510,6 +511,31 @@ std::array getBaseTotals(Field &&field, int root, MPI_Comm comm) { return base_totals; } +bool getHasID(Field &&field, int root, MPI_Comm comm) { + bool has_id; + + int rank; + MPI_Comm_rank(comm, &rank); + + const bool is_master = root == rank; + + if (is_master) { + const auto ID_field = field["ID"]; + + std::set unique_IDs(ID_field.begin(), ID_field.end()); + + has_id = unique_IDs.size() > 1; + + MPI_Bcast(&has_id, 1, MPI_C_BOOL, root, MPI_COMM_WORLD); + + return has_id; + } + + MPI_Bcast(&has_id, 1, MPI_C_BOOL, root, comm); + + return has_id; +} + int main(int argc, char *argv[]) { int world_size; @@ -558,6 +584,7 @@ int main(int argc, char *argv[]) { const ChemistryModule::SurrogateSetup surr_setup = { getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), getBaseTotals(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), + getHasID(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), run_params.use_dht, run_params.dht_size, run_params.use_interp,