From 7bb103fade765a6a5ae07d768892a5143c408627 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20L=C3=BCbke?= Date: Fri, 13 Dec 2024 08:21:25 +0100 Subject: [PATCH] feat: implement caching for interpolation calculations and add NaN handling --- src/Chemistry/ChemistryModule.cpp | 1 + .../SurrogateModels/Interpolation.hpp | 2 ++ .../SurrogateModels/InterpolationModule.cpp | 18 +++++++++++++++++- src/Chemistry/SurrogateModels/LookupKey.hpp | 12 ++++++++++++ src/Chemistry/SurrogateModels/Rounding.hpp | 5 +++++ 5 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/Chemistry/ChemistryModule.cpp b/src/Chemistry/ChemistryModule.cpp index a6ab03c36..ed0b17d9a 100644 --- a/src/Chemistry/ChemistryModule.cpp +++ b/src/Chemistry/ChemistryModule.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include diff --git a/src/Chemistry/SurrogateModels/Interpolation.hpp b/src/Chemistry/SurrogateModels/Interpolation.hpp index b6b9a69f5..efa79e4f0 100644 --- a/src/Chemistry/SurrogateModels/Interpolation.hpp +++ b/src/Chemistry/SurrogateModels/Interpolation.hpp @@ -261,6 +261,8 @@ private: const InitialList::ChemistryHookFunctions &hooks; const std::vector &out_names; const std::vector dht_names; + + std::unordered_map> to_calc_cache; }; } // namespace poet diff --git a/src/Chemistry/SurrogateModels/InterpolationModule.cpp b/src/Chemistry/SurrogateModels/InterpolationModule.cpp index a2871bb8a..4185e0292 100644 --- a/src/Chemistry/SurrogateModels/InterpolationModule.cpp +++ b/src/Chemistry/SurrogateModels/InterpolationModule.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -116,10 +117,25 @@ void InterpolationModule::tryInterpolation(WorkPackage &work_package) { this->pht->incrementReadCounter(roundKey(rounded_key)); #endif + const int cell_id = static_cast(work_package.input[wp_i][0]); + + if (!to_calc_cache.contains(cell_id)) { + std::vector to_calc = dht_instance.getKeyElements(); + std::vector keep_indices; + + for (std::size_t i = 0; i < to_calc.size(); i++) { + if (!std::isnan(work_package.input[wp_i][to_calc[i]])) { + keep_indices.push_back(to_calc[i]); + } + } + + to_calc_cache[cell_id] = keep_indices; + } + double start_fc = MPI_Wtime(); work_package.output[wp_i] = - f_interpolate(dht_instance.getKeyElements(), work_package.input[wp_i], + f_interpolate(to_calc_cache[cell_id], work_package.input[wp_i], pht_result.in_values, pht_result.out_values); if (hooks.interp_post.isValid()) { diff --git a/src/Chemistry/SurrogateModels/LookupKey.hpp b/src/Chemistry/SurrogateModels/LookupKey.hpp index e6dc7e697..6e7b0681c 100644 --- a/src/Chemistry/SurrogateModels/LookupKey.hpp +++ b/src/Chemistry/SurrogateModels/LookupKey.hpp @@ -10,9 +10,21 @@ namespace poet { +constexpr std::int8_t SC_NOTATION_EXPONENT_MASK = -128; +constexpr std::int64_t SC_NOTATION_SIGNIFICANT_MASK = 0xFFFFFFFFFFFF; + struct Lookup_SC_notation { std::int8_t exp : 8; std::int64_t significant : 56; + + constexpr static Lookup_SC_notation nan() { + return {SC_NOTATION_EXPONENT_MASK, SC_NOTATION_SIGNIFICANT_MASK}; + } + + constexpr bool isnan() { + return !!(exp == SC_NOTATION_EXPONENT_MASK && + significant == SC_NOTATION_SIGNIFICANT_MASK); + } }; union Lookup_Keyelement { diff --git a/src/Chemistry/SurrogateModels/Rounding.hpp b/src/Chemistry/SurrogateModels/Rounding.hpp index 3f659290e..688ac4707 100644 --- a/src/Chemistry/SurrogateModels/Rounding.hpp +++ b/src/Chemistry/SurrogateModels/Rounding.hpp @@ -20,6 +20,11 @@ class DHT_Rounder { public: Lookup_Keyelement round(const double &value, std::uint32_t signif, bool is_ho) { + + if (std::isnan(value)) { + return {.sc_notation = Lookup_SC_notation::nan()}; + } + std::int8_t exp = static_cast(std::floor(std::log10(std::fabs(value))));