mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-16 12:54:50 +01:00
feat: add has_het_ids parameter to DHT initialization and related functions
This commit is contained in:
parent
5f56ce9e3f
commit
cd53f43bd2
@ -185,7 +185,8 @@ poet::ChemistryModule::~ChemistryModule() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void poet::ChemistryModule::initializeDHT(
|
void poet::ChemistryModule::initializeDHT(
|
||||||
uint32_t size_mb, const NamedVector<std::uint32_t> &key_species) {
|
uint32_t size_mb, const NamedVector<std::uint32_t> &key_species,
|
||||||
|
bool has_het_ids) {
|
||||||
constexpr uint32_t MB_FACTOR = 1E6;
|
constexpr uint32_t MB_FACTOR = 1E6;
|
||||||
|
|
||||||
MPI_Comm dht_comm;
|
MPI_Comm dht_comm;
|
||||||
@ -217,7 +218,7 @@ void poet::ChemistryModule::initializeDHT(
|
|||||||
|
|
||||||
this->dht = new DHT_Wrapper(dht_comm, dht_size, map_copy, key_indices,
|
this->dht = new DHT_Wrapper(dht_comm, dht_size, map_copy, key_indices,
|
||||||
this->prop_names, params.hooks,
|
this->prop_names, params.hooks,
|
||||||
this->prop_count, interp_enabled);
|
this->prop_count, interp_enabled, has_het_ids);
|
||||||
this->dht->setBaseTotals(base_totals.at(0), base_totals.at(1));
|
this->dht->setBaseTotals(base_totals.at(0), base_totals.at(1));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -77,6 +77,7 @@ public:
|
|||||||
struct SurrogateSetup {
|
struct SurrogateSetup {
|
||||||
std::vector<std::string> prop_names;
|
std::vector<std::string> prop_names;
|
||||||
std::array<double, 2> base_totals;
|
std::array<double, 2> base_totals;
|
||||||
|
bool has_het_ids;
|
||||||
|
|
||||||
bool dht_enabled;
|
bool dht_enabled;
|
||||||
std::uint32_t dht_size_mb;
|
std::uint32_t dht_size_mb;
|
||||||
@ -100,7 +101,8 @@ public:
|
|||||||
this->base_totals = setup.base_totals;
|
this->base_totals = setup.base_totals;
|
||||||
|
|
||||||
if (this->dht_enabled || this->interp_enabled) {
|
if (this->dht_enabled || this->interp_enabled) {
|
||||||
this->initializeDHT(setup.dht_size_mb, this->params.dht_species);
|
this->initializeDHT(setup.dht_size_mb, this->params.dht_species,
|
||||||
|
setup.has_het_ids);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (this->interp_enabled) {
|
if (this->interp_enabled) {
|
||||||
@ -243,7 +245,8 @@ public:
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
void initializeDHT(uint32_t size_mb,
|
void initializeDHT(uint32_t size_mb,
|
||||||
const NamedVector<std::uint32_t> &key_species);
|
const NamedVector<std::uint32_t> &key_species,
|
||||||
|
bool has_het_ids);
|
||||||
void setDHTSnapshots(int type, const std::string &out_dir);
|
void setDHTSnapshots(int type, const std::string &out_dir);
|
||||||
void setDHTReadFile(const std::string &input_file);
|
void setDHTReadFile(const std::string &input_file);
|
||||||
|
|
||||||
|
|||||||
@ -43,11 +43,12 @@ DHT_Wrapper::DHT_Wrapper(MPI_Comm dht_comm, std::uint64_t dht_size,
|
|||||||
const std::vector<std::int32_t> &key_indices,
|
const std::vector<std::int32_t> &key_indices,
|
||||||
const std::vector<std::string> &_output_names,
|
const std::vector<std::string> &_output_names,
|
||||||
const InitialList::ChemistryHookFunctions &_hooks,
|
const InitialList::ChemistryHookFunctions &_hooks,
|
||||||
uint32_t data_count, bool _with_interp)
|
uint32_t data_count, bool _with_interp,
|
||||||
|
bool _has_het_ids)
|
||||||
: key_count(key_indices.size()), data_count(data_count),
|
: key_count(key_indices.size()), data_count(data_count),
|
||||||
input_key_elements(key_indices), communicator(dht_comm),
|
input_key_elements(key_indices), communicator(dht_comm),
|
||||||
key_species(key_species), output_names(_output_names), hooks(_hooks),
|
key_species(key_species), output_names(_output_names), hooks(_hooks),
|
||||||
with_interp(_with_interp) {
|
with_interp(_with_interp), has_het_ids(_has_het_ids) {
|
||||||
// initialize DHT object
|
// initialize DHT object
|
||||||
// key size = count of key elements + timestep
|
// key size = count of key elements + timestep
|
||||||
uint32_t key_size = (key_count + 1) * sizeof(Lookup_Keyelement);
|
uint32_t key_size = (key_count + 1) * sizeof(Lookup_Keyelement);
|
||||||
@ -270,7 +271,7 @@ LookupKey DHT_Wrapper::fuzzForDHT_R(const std::vector<double> &cell,
|
|||||||
const std::vector<double> eval_vec =
|
const std::vector<double> eval_vec =
|
||||||
Rcpp::as<std::vector<double>>(hooks.dht_fuzz(input_nv));
|
Rcpp::as<std::vector<double>>(hooks.dht_fuzz(input_nv));
|
||||||
assert(eval_vec.size() == this->key_count);
|
assert(eval_vec.size() == this->key_count);
|
||||||
LookupKey vecFuzz(this->key_count + 1, {.0});
|
LookupKey vecFuzz(this->key_count + 1 + has_het_ids, {.0});
|
||||||
|
|
||||||
DHT_Rounder rounder;
|
DHT_Rounder rounder;
|
||||||
|
|
||||||
@ -290,6 +291,9 @@ LookupKey DHT_Wrapper::fuzzForDHT_R(const std::vector<double> &cell,
|
|||||||
}
|
}
|
||||||
// add timestep to the end of the key as double value
|
// add timestep to the end of the key as double value
|
||||||
vecFuzz[this->key_count].fp_element = dt;
|
vecFuzz[this->key_count].fp_element = dt;
|
||||||
|
if (has_het_ids) {
|
||||||
|
vecFuzz[this->key_count + 1].fp_element = cell[0];
|
||||||
|
}
|
||||||
|
|
||||||
return vecFuzz;
|
return vecFuzz;
|
||||||
}
|
}
|
||||||
@ -297,7 +301,7 @@ LookupKey DHT_Wrapper::fuzzForDHT_R(const std::vector<double> &cell,
|
|||||||
LookupKey DHT_Wrapper::fuzzForDHT(const std::vector<double> &cell, double dt) {
|
LookupKey DHT_Wrapper::fuzzForDHT(const std::vector<double> &cell, double dt) {
|
||||||
const auto c_zero_val = std::pow(10, AQUEOUS_EXP);
|
const auto c_zero_val = std::pow(10, AQUEOUS_EXP);
|
||||||
|
|
||||||
LookupKey vecFuzz(this->key_count + 1, {.0});
|
LookupKey vecFuzz(this->key_count + 1 + has_het_ids, {.0});
|
||||||
DHT_Rounder rounder;
|
DHT_Rounder rounder;
|
||||||
|
|
||||||
int totals_i = 0;
|
int totals_i = 0;
|
||||||
@ -323,6 +327,9 @@ LookupKey DHT_Wrapper::fuzzForDHT(const std::vector<double> &cell, double dt) {
|
|||||||
}
|
}
|
||||||
// add timestep to the end of the key as double value
|
// add timestep to the end of the key as double value
|
||||||
vecFuzz[this->key_count].fp_element = dt;
|
vecFuzz[this->key_count].fp_element = dt;
|
||||||
|
if (has_het_ids) {
|
||||||
|
vecFuzz[this->key_count + 1].fp_element = cell[0];
|
||||||
|
}
|
||||||
|
|
||||||
return vecFuzz;
|
return vecFuzz;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -87,7 +87,7 @@ public:
|
|||||||
const std::vector<std::int32_t> &key_indices,
|
const std::vector<std::int32_t> &key_indices,
|
||||||
const std::vector<std::string> &output_names,
|
const std::vector<std::string> &output_names,
|
||||||
const InitialList::ChemistryHookFunctions &hooks,
|
const InitialList::ChemistryHookFunctions &hooks,
|
||||||
uint32_t data_count, bool with_interp);
|
uint32_t data_count, bool with_interp, bool has_het_ids);
|
||||||
/**
|
/**
|
||||||
* @brief Destroy the dht wrapper object
|
* @brief Destroy the dht wrapper object
|
||||||
*
|
*
|
||||||
@ -264,6 +264,7 @@ private:
|
|||||||
DHT_ResultObject dht_results;
|
DHT_ResultObject dht_results;
|
||||||
|
|
||||||
std::array<double, 2> base_totals{0};
|
std::array<double, 2> base_totals{0};
|
||||||
|
bool has_het_ids{false};
|
||||||
};
|
};
|
||||||
} // namespace poet
|
} // namespace poet
|
||||||
|
|
||||||
|
|||||||
27
src/poet.cpp
27
src/poet.cpp
@ -40,6 +40,7 @@
|
|||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <mpi.h>
|
#include <mpi.h>
|
||||||
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include <CLI/CLI.hpp>
|
#include <CLI/CLI.hpp>
|
||||||
@ -510,6 +511,31 @@ std::array<double, 2> getBaseTotals(Field &&field, int root, MPI_Comm comm) {
|
|||||||
return base_totals;
|
return base_totals;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool getHasID(Field &&field, int root, MPI_Comm comm) {
|
||||||
|
bool has_id;
|
||||||
|
|
||||||
|
int rank;
|
||||||
|
MPI_Comm_rank(comm, &rank);
|
||||||
|
|
||||||
|
const bool is_master = root == rank;
|
||||||
|
|
||||||
|
if (is_master) {
|
||||||
|
const auto ID_field = field["ID"];
|
||||||
|
|
||||||
|
std::set<double> unique_IDs(ID_field.begin(), ID_field.end());
|
||||||
|
|
||||||
|
has_id = unique_IDs.size() > 1;
|
||||||
|
|
||||||
|
MPI_Bcast(&has_id, 1, MPI_C_BOOL, root, MPI_COMM_WORLD);
|
||||||
|
|
||||||
|
return has_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
MPI_Bcast(&has_id, 1, MPI_C_BOOL, root, comm);
|
||||||
|
|
||||||
|
return has_id;
|
||||||
|
}
|
||||||
|
|
||||||
int main(int argc, char *argv[]) {
|
int main(int argc, char *argv[]) {
|
||||||
int world_size;
|
int world_size;
|
||||||
|
|
||||||
@ -558,6 +584,7 @@ int main(int argc, char *argv[]) {
|
|||||||
const ChemistryModule::SurrogateSetup surr_setup = {
|
const ChemistryModule::SurrogateSetup surr_setup = {
|
||||||
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
||||||
getBaseTotals(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
getBaseTotals(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
||||||
|
getHasID(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
||||||
run_params.use_dht,
|
run_params.use_dht,
|
||||||
run_params.dht_size,
|
run_params.dht_size,
|
||||||
run_params.use_interp,
|
run_params.use_interp,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user