feat: implement caching for interpolation calculations and add NaN handling

This commit is contained in:
Max Lübke 2024-12-13 08:21:25 +01:00
parent 43d2a846c7
commit f6b4ce017a
5 changed files with 37 additions and 1 deletions

View File

@ -8,6 +8,7 @@
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <memory>
#include <mpi.h>

View File

@ -261,6 +261,8 @@ private:
const InitialList::ChemistryHookFunctions &hooks;
const std::vector<std::string> &out_names;
const std::vector<std::string> dht_names;
std::unordered_map<int, std::vector<std::int32_t>> to_calc_cache;
};
} // namespace poet

View File

@ -14,6 +14,7 @@
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
@ -116,10 +117,25 @@ void InterpolationModule::tryInterpolation(WorkPackage &work_package) {
this->pht->incrementReadCounter(roundKey(rounded_key));
#endif
const int cell_id = static_cast<int>(work_package.input[wp_i][0]);
if (!to_calc_cache.contains(cell_id)) {
std::vector<std::int32_t> to_calc = dht_instance.getKeyElements();
std::vector<std::int32_t> keep_indices;
for (std::size_t i = 0; i < to_calc.size(); i++) {
if (!std::isnan(work_package.input[wp_i][to_calc[i]])) {
keep_indices.push_back(to_calc[i]);
}
}
to_calc_cache[cell_id] = keep_indices;
}
double start_fc = MPI_Wtime();
work_package.output[wp_i] =
f_interpolate(dht_instance.getKeyElements(), work_package.input[wp_i],
f_interpolate(to_calc_cache[cell_id], work_package.input[wp_i],
pht_result.in_values, pht_result.out_values);
if (hooks.interp_post.isValid()) {

View File

@ -10,9 +10,21 @@
namespace poet {
constexpr std::int8_t SC_NOTATION_EXPONENT_MASK = -128;
constexpr std::int64_t SC_NOTATION_SIGNIFICANT_MASK = 0xFFFFFFFFFFFF;
struct Lookup_SC_notation {
std::int8_t exp : 8;
std::int64_t significant : 56;
constexpr static Lookup_SC_notation nan() {
return {SC_NOTATION_EXPONENT_MASK, SC_NOTATION_SIGNIFICANT_MASK};
}
constexpr bool isnan() {
return !!(exp == SC_NOTATION_EXPONENT_MASK &&
significant == SC_NOTATION_SIGNIFICANT_MASK);
}
};
union Lookup_Keyelement {

View File

@ -20,6 +20,11 @@ class DHT_Rounder {
public:
Lookup_Keyelement round(const double &value, std::uint32_t signif,
bool is_ho) {
if (std::isnan(value)) {
return {.sc_notation = Lookup_SC_notation::nan()};
}
std::int8_t exp =
static_cast<std::int8_t>(std::floor(std::log10(std::fabs(value))));