diff --git a/include/poet/DHT_Wrapper.hpp b/include/poet/DHT_Wrapper.hpp index a74db4f2d..115d5e769 100644 --- a/include/poet/DHT_Wrapper.hpp +++ b/include/poet/DHT_Wrapper.hpp @@ -1,4 +1,4 @@ -// Time-stamp: "Last modified 2023-08-15 14:57:51 mluebke" +// Time-stamp: "Last modified 2023-09-08 14:43:02 mluebke" /* ** Copyright (C) 2018-2021 Alexander Lindemann, Max Luebke (University of @@ -89,7 +89,7 @@ public: const std::vector &key_indices, const std::vector &output_names, const ChemistryParams::Chem_Hook_Functions &hooks, - uint32_t data_count); + uint32_t data_count, bool with_interp); /** * @brief Destroy the dht wrapper object * @@ -240,7 +240,14 @@ private: const std::vector &new_results); std::vector - inputAndRatesToOutput(const std::vector &dht_data); + inputAndRatesToOutput(const std::vector &dht_data, + const std::vector &input_values); + + std::vector outputToRates(const std::vector &old_results, + const std::vector &new_results); + + std::vector ratesToOutput(const std::vector &dht_data, + const std::vector &input_values); uint32_t dht_hits = 0; uint32_t dht_evictions = 0; @@ -254,6 +261,7 @@ private: const std::vector &output_names; const ChemistryParams::Chem_Hook_Functions &hooks; + const bool with_interp; DHT_ResultObject dht_results; diff --git a/src/ChemistryModule/ChemistryModule.cpp b/src/ChemistryModule/ChemistryModule.cpp index 3060e8617..a1b1a8837 100644 --- a/src/ChemistryModule/ChemistryModule.cpp +++ b/src/ChemistryModule/ChemistryModule.cpp @@ -378,9 +378,9 @@ void poet::ChemistryModule::initializeDHT( const std::uint64_t dht_size = size_mb * MB_FACTOR; - this->dht = - new DHT_Wrapper(dht_comm, dht_size, map_copy, key_indices, - this->prop_names, params.hooks, this->prop_count); + this->dht = new DHT_Wrapper(dht_comm, dht_size, map_copy, key_indices, + this->prop_names, params.hooks, + this->prop_count, params.use_interp); this->dht->setBaseTotals(base_totals.at(0), base_totals.at(1)); } } diff --git a/src/ChemistryModule/SurrogateModels/DHT_Wrapper.cpp b/src/ChemistryModule/SurrogateModels/DHT_Wrapper.cpp index d1ee112c2..af5fbf12e 100644 --- a/src/ChemistryModule/SurrogateModels/DHT_Wrapper.cpp +++ b/src/ChemistryModule/SurrogateModels/DHT_Wrapper.cpp @@ -1,4 +1,4 @@ -// Time-stamp: "Last modified 2023-08-16 16:44:17 mluebke" +// Time-stamp: "Last modified 2023-09-08 22:09:03 mluebke" /* ** Copyright (C) 2018-2021 Alexander Lindemann, Max Luebke (University of @@ -48,15 +48,17 @@ DHT_Wrapper::DHT_Wrapper(MPI_Comm dht_comm, std::uint64_t dht_size, const std::vector &key_indices, const std::vector &_output_names, const ChemistryParams::Chem_Hook_Functions &_hooks, - uint32_t data_count) + uint32_t data_count, bool _with_interp) : key_count(key_indices.size()), data_count(data_count), input_key_elements(key_indices), communicator(dht_comm), - key_species(key_species), output_names(_output_names), hooks(_hooks) { + key_species(key_species), output_names(_output_names), hooks(_hooks), + with_interp(_with_interp) { // initialize DHT object // key size = count of key elements + timestep uint32_t key_size = (key_count + 1) * sizeof(Lookup_Keyelement); uint32_t data_size = - (data_count + input_key_elements.size()) * sizeof(double); + (data_count + (with_interp ? input_key_elements.size() : 0)) * + sizeof(double); uint32_t buckets_per_process = static_cast(dht_size / (data_size + key_size)); dht_object = DHT_create(dht_comm, buckets_per_process, data_size, key_size, @@ -93,8 +95,8 @@ auto DHT_Wrapper::checkDHT(WorkPackage &work_package) const auto length = work_package.size; - std::vector bucket_writer(this->data_count + - input_key_elements.size()); + std::vector bucket_writer( + this->data_count + (with_interp ? input_key_elements.size() : 0)); // loop over every grid cell contained in work package for (int i = 0; i < length; i++) { @@ -107,7 +109,10 @@ auto DHT_Wrapper::checkDHT(WorkPackage &work_package) switch (res) { case DHT_SUCCESS: - work_package.output[i] = inputAndRatesToOutput(bucket_writer); + work_package.output[i] = + (with_interp + ? inputAndRatesToOutput(bucket_writer, work_package.input[i]) + : bucket_writer); work_package.mapping[i] = CHEM_DHT; this->dht_hits++; break; @@ -145,7 +150,10 @@ void DHT_Wrapper::fillDHT(const WorkPackage &work_package) { uint32_t proc, index; auto &key = dht_results.keys[i]; const auto data = - outputToInputAndRates(work_package.input[i], work_package.output[i]); + (with_interp + ? outputToInputAndRates(work_package.input[i], + work_package.output[i]) + : work_package.output[i]); // void *data = (void *)&(work_package[i * this->data_count]); // fuzz data (round, logarithm etc.) @@ -165,7 +173,7 @@ void DHT_Wrapper::fillDHT(const WorkPackage &work_package) { } } -std::vector +inline std::vector DHT_Wrapper::outputToInputAndRates(const std::vector &old_results, const std::vector &new_results) { const int prefix_size = this->input_key_elements.size(); @@ -183,11 +191,12 @@ DHT_Wrapper::outputToInputAndRates(const std::vector &old_results, return output; } -std::vector -DHT_Wrapper::inputAndRatesToOutput(const std::vector &dht_data) { +inline std::vector +DHT_Wrapper::inputAndRatesToOutput(const std::vector &dht_data, + const std::vector &input_values) { const int prefix_size = this->input_key_elements.size(); - std::vector output{dht_data.begin() + prefix_size, dht_data.end()}; + std::vector output(input_values); for (int i = 0; i < prefix_size; i++) { const int data_elem_i = input_key_elements[i]; @@ -197,6 +206,30 @@ DHT_Wrapper::inputAndRatesToOutput(const std::vector &dht_data) { return output; } +inline std::vector +DHT_Wrapper::outputToRates(const std::vector &old_results, + const std::vector &new_results) { + std::vector output(new_results); + + for (const auto &data_elem_i : input_key_elements) { + output[data_elem_i] -= old_results[data_elem_i]; + } + + return output; +} + +inline std::vector +DHT_Wrapper::ratesToOutput(const std::vector &dht_data, + const std::vector &input_values) { + std::vector output(input_values); + + for (const auto &data_elem_i : input_key_elements) { + output[data_elem_i] += dht_data[data_elem_i]; + } + + return output; +} + // void DHT_Wrapper::resultsToWP(std::vector &work_package) { // for (int i = 0; i < dht_results.length; i++) { // if (!dht_results.needPhreeqc[i]) {