mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-15 20:38:23 +01:00
feat: add has_het_ids parameter to DHT initialization and related functions
This commit is contained in:
parent
9b9fd898b7
commit
ac96f24a33
@ -185,7 +185,8 @@ poet::ChemistryModule::~ChemistryModule() {
|
||||
}
|
||||
|
||||
void poet::ChemistryModule::initializeDHT(
|
||||
uint32_t size_mb, const NamedVector<std::uint32_t> &key_species) {
|
||||
uint32_t size_mb, const NamedVector<std::uint32_t> &key_species,
|
||||
bool has_het_ids) {
|
||||
constexpr uint32_t MB_FACTOR = 1E6;
|
||||
|
||||
MPI_Comm dht_comm;
|
||||
@ -217,7 +218,7 @@ void poet::ChemistryModule::initializeDHT(
|
||||
|
||||
this->dht = new DHT_Wrapper(dht_comm, dht_size, map_copy, key_indices,
|
||||
this->prop_names, params.hooks,
|
||||
this->prop_count, interp_enabled);
|
||||
this->prop_count, interp_enabled, has_het_ids);
|
||||
this->dht->setBaseTotals(base_totals.at(0), base_totals.at(1));
|
||||
}
|
||||
}
|
||||
|
||||
@ -77,6 +77,7 @@ public:
|
||||
struct SurrogateSetup {
|
||||
std::vector<std::string> prop_names;
|
||||
std::array<double, 2> base_totals;
|
||||
bool has_het_ids;
|
||||
|
||||
bool dht_enabled;
|
||||
std::uint32_t dht_size_mb;
|
||||
@ -100,7 +101,8 @@ public:
|
||||
this->base_totals = setup.base_totals;
|
||||
|
||||
if (this->dht_enabled || this->interp_enabled) {
|
||||
this->initializeDHT(setup.dht_size_mb, this->params.dht_species);
|
||||
this->initializeDHT(setup.dht_size_mb, this->params.dht_species,
|
||||
setup.has_het_ids);
|
||||
}
|
||||
|
||||
if (this->interp_enabled) {
|
||||
@ -243,7 +245,8 @@ public:
|
||||
|
||||
protected:
|
||||
void initializeDHT(uint32_t size_mb,
|
||||
const NamedVector<std::uint32_t> &key_species);
|
||||
const NamedVector<std::uint32_t> &key_species,
|
||||
bool has_het_ids);
|
||||
void setDHTSnapshots(int type, const std::string &out_dir);
|
||||
void setDHTReadFile(const std::string &input_file);
|
||||
|
||||
|
||||
@ -43,11 +43,12 @@ DHT_Wrapper::DHT_Wrapper(MPI_Comm dht_comm, std::uint64_t dht_size,
|
||||
const std::vector<std::int32_t> &key_indices,
|
||||
const std::vector<std::string> &_output_names,
|
||||
const InitialList::ChemistryHookFunctions &_hooks,
|
||||
uint32_t data_count, bool _with_interp)
|
||||
uint32_t data_count, bool _with_interp,
|
||||
bool _has_het_ids)
|
||||
: key_count(key_indices.size()), data_count(data_count),
|
||||
input_key_elements(key_indices), communicator(dht_comm),
|
||||
key_species(key_species), output_names(_output_names), hooks(_hooks),
|
||||
with_interp(_with_interp) {
|
||||
with_interp(_with_interp), has_het_ids(_has_het_ids) {
|
||||
// initialize DHT object
|
||||
// key size = count of key elements + timestep
|
||||
uint32_t key_size = (key_count + 1) * sizeof(Lookup_Keyelement);
|
||||
@ -270,7 +271,7 @@ LookupKey DHT_Wrapper::fuzzForDHT_R(const std::vector<double> &cell,
|
||||
const std::vector<double> eval_vec =
|
||||
Rcpp::as<std::vector<double>>(hooks.dht_fuzz(input_nv));
|
||||
assert(eval_vec.size() == this->key_count);
|
||||
LookupKey vecFuzz(this->key_count + 1, {.0});
|
||||
LookupKey vecFuzz(this->key_count + 1 + has_het_ids, {.0});
|
||||
|
||||
DHT_Rounder rounder;
|
||||
|
||||
@ -290,6 +291,9 @@ LookupKey DHT_Wrapper::fuzzForDHT_R(const std::vector<double> &cell,
|
||||
}
|
||||
// add timestep to the end of the key as double value
|
||||
vecFuzz[this->key_count].fp_element = dt;
|
||||
if (has_het_ids) {
|
||||
vecFuzz[this->key_count + 1].fp_element = cell[0];
|
||||
}
|
||||
|
||||
return vecFuzz;
|
||||
}
|
||||
@ -297,7 +301,7 @@ LookupKey DHT_Wrapper::fuzzForDHT_R(const std::vector<double> &cell,
|
||||
LookupKey DHT_Wrapper::fuzzForDHT(const std::vector<double> &cell, double dt) {
|
||||
const auto c_zero_val = std::pow(10, AQUEOUS_EXP);
|
||||
|
||||
LookupKey vecFuzz(this->key_count + 1, {.0});
|
||||
LookupKey vecFuzz(this->key_count + 1 + has_het_ids, {.0});
|
||||
DHT_Rounder rounder;
|
||||
|
||||
int totals_i = 0;
|
||||
@ -323,6 +327,9 @@ LookupKey DHT_Wrapper::fuzzForDHT(const std::vector<double> &cell, double dt) {
|
||||
}
|
||||
// add timestep to the end of the key as double value
|
||||
vecFuzz[this->key_count].fp_element = dt;
|
||||
if (has_het_ids) {
|
||||
vecFuzz[this->key_count + 1].fp_element = cell[0];
|
||||
}
|
||||
|
||||
return vecFuzz;
|
||||
}
|
||||
|
||||
@ -87,7 +87,7 @@ public:
|
||||
const std::vector<std::int32_t> &key_indices,
|
||||
const std::vector<std::string> &output_names,
|
||||
const InitialList::ChemistryHookFunctions &hooks,
|
||||
uint32_t data_count, bool with_interp);
|
||||
uint32_t data_count, bool with_interp, bool has_het_ids);
|
||||
/**
|
||||
* @brief Destroy the dht wrapper object
|
||||
*
|
||||
@ -264,6 +264,7 @@ private:
|
||||
DHT_ResultObject dht_results;
|
||||
|
||||
std::array<double, 2> base_totals{0};
|
||||
bool has_het_ids{false};
|
||||
};
|
||||
} // namespace poet
|
||||
|
||||
|
||||
27
src/poet.cpp
27
src/poet.cpp
@ -40,6 +40,7 @@
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
#include <mpi.h>
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
#include <CLI/CLI.hpp>
|
||||
@ -510,6 +511,31 @@ std::array<double, 2> getBaseTotals(Field &&field, int root, MPI_Comm comm) {
|
||||
return base_totals;
|
||||
}
|
||||
|
||||
bool getHasID(Field &&field, int root, MPI_Comm comm) {
|
||||
bool has_id;
|
||||
|
||||
int rank;
|
||||
MPI_Comm_rank(comm, &rank);
|
||||
|
||||
const bool is_master = root == rank;
|
||||
|
||||
if (is_master) {
|
||||
const auto ID_field = field["ID"];
|
||||
|
||||
std::set<double> unique_IDs(ID_field.begin(), ID_field.end());
|
||||
|
||||
has_id = unique_IDs.size() > 1;
|
||||
|
||||
MPI_Bcast(&has_id, 1, MPI_C_BOOL, root, MPI_COMM_WORLD);
|
||||
|
||||
return has_id;
|
||||
}
|
||||
|
||||
MPI_Bcast(&has_id, 1, MPI_C_BOOL, root, comm);
|
||||
|
||||
return has_id;
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
int world_size;
|
||||
|
||||
@ -558,6 +584,7 @@ int main(int argc, char *argv[]) {
|
||||
const ChemistryModule::SurrogateSetup surr_setup = {
|
||||
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
||||
getBaseTotals(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
||||
getHasID(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
||||
run_params.use_dht,
|
||||
run_params.dht_size,
|
||||
run_params.use_interp,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user