feat: add has_het_ids parameter to DHT initialization and related functions

This commit is contained in:
Max Lübke 2024-12-13 14:33:02 +01:00
parent 9b9fd898b7
commit ac96f24a33
5 changed files with 48 additions and 9 deletions

View File

@ -185,7 +185,8 @@ poet::ChemistryModule::~ChemistryModule() {
}
void poet::ChemistryModule::initializeDHT(
uint32_t size_mb, const NamedVector<std::uint32_t> &key_species) {
uint32_t size_mb, const NamedVector<std::uint32_t> &key_species,
bool has_het_ids) {
constexpr uint32_t MB_FACTOR = 1E6;
MPI_Comm dht_comm;
@ -217,7 +218,7 @@ void poet::ChemistryModule::initializeDHT(
this->dht = new DHT_Wrapper(dht_comm, dht_size, map_copy, key_indices,
this->prop_names, params.hooks,
this->prop_count, interp_enabled);
this->prop_count, interp_enabled, has_het_ids);
this->dht->setBaseTotals(base_totals.at(0), base_totals.at(1));
}
}

View File

@ -77,6 +77,7 @@ public:
struct SurrogateSetup {
std::vector<std::string> prop_names;
std::array<double, 2> base_totals;
bool has_het_ids;
bool dht_enabled;
std::uint32_t dht_size_mb;
@ -100,7 +101,8 @@ public:
this->base_totals = setup.base_totals;
if (this->dht_enabled || this->interp_enabled) {
this->initializeDHT(setup.dht_size_mb, this->params.dht_species);
this->initializeDHT(setup.dht_size_mb, this->params.dht_species,
setup.has_het_ids);
}
if (this->interp_enabled) {
@ -243,7 +245,8 @@ public:
protected:
void initializeDHT(uint32_t size_mb,
const NamedVector<std::uint32_t> &key_species);
const NamedVector<std::uint32_t> &key_species,
bool has_het_ids);
void setDHTSnapshots(int type, const std::string &out_dir);
void setDHTReadFile(const std::string &input_file);

View File

@ -43,11 +43,12 @@ DHT_Wrapper::DHT_Wrapper(MPI_Comm dht_comm, std::uint64_t dht_size,
const std::vector<std::int32_t> &key_indices,
const std::vector<std::string> &_output_names,
const InitialList::ChemistryHookFunctions &_hooks,
uint32_t data_count, bool _with_interp)
uint32_t data_count, bool _with_interp,
bool _has_het_ids)
: key_count(key_indices.size()), data_count(data_count),
input_key_elements(key_indices), communicator(dht_comm),
key_species(key_species), output_names(_output_names), hooks(_hooks),
with_interp(_with_interp) {
with_interp(_with_interp), has_het_ids(_has_het_ids) {
// initialize DHT object
// key size = count of key elements + timestep
uint32_t key_size = (key_count + 1) * sizeof(Lookup_Keyelement);
@ -270,7 +271,7 @@ LookupKey DHT_Wrapper::fuzzForDHT_R(const std::vector<double> &cell,
const std::vector<double> eval_vec =
Rcpp::as<std::vector<double>>(hooks.dht_fuzz(input_nv));
assert(eval_vec.size() == this->key_count);
LookupKey vecFuzz(this->key_count + 1, {.0});
LookupKey vecFuzz(this->key_count + 1 + has_het_ids, {.0});
DHT_Rounder rounder;
@ -290,6 +291,9 @@ LookupKey DHT_Wrapper::fuzzForDHT_R(const std::vector<double> &cell,
}
// add timestep to the end of the key as double value
vecFuzz[this->key_count].fp_element = dt;
if (has_het_ids) {
vecFuzz[this->key_count + 1].fp_element = cell[0];
}
return vecFuzz;
}
@ -297,7 +301,7 @@ LookupKey DHT_Wrapper::fuzzForDHT_R(const std::vector<double> &cell,
LookupKey DHT_Wrapper::fuzzForDHT(const std::vector<double> &cell, double dt) {
const auto c_zero_val = std::pow(10, AQUEOUS_EXP);
LookupKey vecFuzz(this->key_count + 1, {.0});
LookupKey vecFuzz(this->key_count + 1 + has_het_ids, {.0});
DHT_Rounder rounder;
int totals_i = 0;
@ -323,6 +327,9 @@ LookupKey DHT_Wrapper::fuzzForDHT(const std::vector<double> &cell, double dt) {
}
// add timestep to the end of the key as double value
vecFuzz[this->key_count].fp_element = dt;
if (has_het_ids) {
vecFuzz[this->key_count + 1].fp_element = cell[0];
}
return vecFuzz;
}

View File

@ -87,7 +87,7 @@ public:
const std::vector<std::int32_t> &key_indices,
const std::vector<std::string> &output_names,
const InitialList::ChemistryHookFunctions &hooks,
uint32_t data_count, bool with_interp);
uint32_t data_count, bool with_interp, bool has_het_ids);
/**
* @brief Destroy the dht wrapper object
*
@ -264,6 +264,7 @@ private:
DHT_ResultObject dht_results;
std::array<double, 2> base_totals{0};
bool has_het_ids{false};
};
} // namespace poet

View File

@ -40,6 +40,7 @@
#include <cstdlib>
#include <memory>
#include <mpi.h>
#include <set>
#include <string>
#include <CLI/CLI.hpp>
@ -510,6 +511,31 @@ std::array<double, 2> getBaseTotals(Field &&field, int root, MPI_Comm comm) {
return base_totals;
}
bool getHasID(Field &&field, int root, MPI_Comm comm) {
bool has_id;
int rank;
MPI_Comm_rank(comm, &rank);
const bool is_master = root == rank;
if (is_master) {
const auto ID_field = field["ID"];
std::set<double> unique_IDs(ID_field.begin(), ID_field.end());
has_id = unique_IDs.size() > 1;
MPI_Bcast(&has_id, 1, MPI_C_BOOL, root, MPI_COMM_WORLD);
return has_id;
}
MPI_Bcast(&has_id, 1, MPI_C_BOOL, root, comm);
return has_id;
}
int main(int argc, char *argv[]) {
int world_size;
@ -558,6 +584,7 @@ int main(int argc, char *argv[]) {
const ChemistryModule::SurrogateSetup surr_setup = {
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
getBaseTotals(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
getHasID(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
run_params.use_dht,
run_params.dht_size,
run_params.use_interp,