mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-15 20:38:23 +01:00
Merge branch 'ml/fix-interpolation' into 'main'
Fix Interpolation and add optional features to POET benchmark description See merge request naaice/poet!48
This commit is contained in:
commit
ab01dbc0a7
@ -35,15 +35,15 @@ diffusion_setup <- list(
|
||||
)
|
||||
|
||||
dht_species <- c(
|
||||
"H" = 7,
|
||||
"O" = 7,
|
||||
"Charge" = 4,
|
||||
"Ba" = 7,
|
||||
"Cl" = 7,
|
||||
"S(6)" = 7,
|
||||
"Sr" = 7,
|
||||
"Barite" = 4,
|
||||
"Celestite" = 4
|
||||
"H" = 3,
|
||||
"O" = 3,
|
||||
"Charge" = 6,
|
||||
"Ba" = 6,
|
||||
"Cl" = 6,
|
||||
"S" = 6,
|
||||
"Sr" = 6,
|
||||
"Barite" = 5,
|
||||
"Celestite" = 5
|
||||
)
|
||||
|
||||
chemistry_setup <- list(
|
||||
|
||||
@ -11,9 +11,9 @@ grid_def <- matrix(2, nrow = rows, ncol = cols)
|
||||
grid_setup <- list(
|
||||
pqc_in_file = "./barite.pqi",
|
||||
pqc_db_file = "./db_barite.dat", ## Path to the database file for Phreeqc
|
||||
grid_def = grid_def, ## Definition of the grid, containing IDs according to the Phreeqc input script
|
||||
grid_size = c(s_rows, s_cols), ## Size of the grid in meters
|
||||
constant_cells = c() ## IDs of cells with constant concentration
|
||||
grid_def = grid_def, ## Definition of the grid, containing IDs according to the Phreeqc input script
|
||||
grid_size = c(s_rows, s_cols), ## Size of the grid in meters
|
||||
constant_cells = c() ## IDs of cells with constant concentration
|
||||
)
|
||||
|
||||
bound_length <- 2
|
||||
@ -36,15 +36,15 @@ diffusion_setup <- list(
|
||||
)
|
||||
|
||||
dht_species <- c(
|
||||
"H" = 4,
|
||||
"O" = 10,
|
||||
"Charge" = 4,
|
||||
"Ba" = 7,
|
||||
"Cl" = 4,
|
||||
"S(6)" = 7,
|
||||
"Sr" = 4,
|
||||
"Barite" = 2,
|
||||
"Celestite" = 2
|
||||
"H" = 3,
|
||||
"O" = 3,
|
||||
"Charge" = 3,
|
||||
"Ba" = 6,
|
||||
"Cl" = 6,
|
||||
"S" = 6,
|
||||
"Sr" = 6,
|
||||
"Barite" = 5,
|
||||
"Celestite" = 5
|
||||
)
|
||||
|
||||
chemistry_setup <- list(
|
||||
|
||||
@ -1 +1 @@
|
||||
Subproject commit 38268b4aad03e6ce4755315f4cd690f007fa2720
|
||||
Subproject commit 6e727e2f896e853745b4dd123c5772a9b40ad705
|
||||
@ -14,10 +14,7 @@ struct WorkPackage {
|
||||
std::vector<std::vector<double>> output;
|
||||
std::vector<std::uint8_t> mapping;
|
||||
|
||||
WorkPackage(std::size_t _size) : size(_size) {
|
||||
input.resize(size);
|
||||
output.resize(size);
|
||||
mapping.resize(size, CHEM_PQC);
|
||||
}
|
||||
WorkPackage(std::size_t _size)
|
||||
: size(_size), input(size), output(size), mapping(size, CHEM_PQC) {}
|
||||
};
|
||||
} // namespace poet
|
||||
@ -8,6 +8,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <mpi.h>
|
||||
@ -65,7 +66,8 @@ inverseDistanceWeighting(const std::vector<std::int32_t> &to_calc,
|
||||
distance += std::pow(
|
||||
rescaled[key_comp_i][point_i] - rescaled[key_comp_i][data_set_n], 2);
|
||||
}
|
||||
weights[point_i] = 1 / std::sqrt(distance);
|
||||
|
||||
weights[point_i] = distance != 0 ? 1 / std::sqrt(distance) : 0;
|
||||
assert(!std::isnan(weights[point_i]));
|
||||
inv_sum += weights[point_i];
|
||||
}
|
||||
@ -96,63 +98,9 @@ inverseDistanceWeighting(const std::vector<std::int32_t> &to_calc,
|
||||
key_delta /= inv_sum;
|
||||
|
||||
results[output_comp_i] = from[output_comp_i] + key_delta;
|
||||
assert(!std::isnan(results[output_comp_i]));
|
||||
}
|
||||
|
||||
// if (!has_h) {
|
||||
// double new_val = 0;
|
||||
// for (int j = 0; j < data_set_n; j++) {
|
||||
// new_val += weights[j] * output[j][0];
|
||||
// }
|
||||
// results[0] = new_val / inv_sum;
|
||||
// }
|
||||
|
||||
// if (!has_h) {
|
||||
// double new_val = 0;
|
||||
// for (int j = 0; j < data_set_n; j++) {
|
||||
// new_val += weights[j] * output[j][1];
|
||||
// }
|
||||
// results[1] = new_val / inv_sum;
|
||||
// }
|
||||
|
||||
// for (std::uint32_t i = 0; i < to_calc.size(); i++) {
|
||||
// const std::uint32_t interp_i = to_calc[i];
|
||||
|
||||
// // rescale input between 0 and 1
|
||||
// for (int j = 0; j < input.size(); j++) {
|
||||
// buffer[j] = input[j].at(i);
|
||||
// }
|
||||
|
||||
// buffer[buffer_size - 1] = from[interp_i];
|
||||
|
||||
// const double min = *std::min_element(buffer, buffer + buffer_size);
|
||||
// const double max = *std::max_element(buffer, buffer + buffer_size);
|
||||
|
||||
// for (int j = 0; j < input.size(); j++) {
|
||||
// buffer[j] = ((max - min) != 0 ? (buffer[j] - min) / (max - min) : 1);
|
||||
// }
|
||||
// from_rescaled =
|
||||
// ((max - min) != 0 ? (from[interp_i] - min) / (max - min) : 0);
|
||||
|
||||
// double inv_sum = 0;
|
||||
|
||||
// // calculate distances for each point
|
||||
// for (int i = 0; i < input.size(); i++) {
|
||||
// const double distance = std::pow(buffer[i] - from_rescaled, 2);
|
||||
|
||||
// buffer[i] = distance > 0 ? (1 / std::sqrt(distance)) : 0;
|
||||
// inv_sum += buffer[i];
|
||||
// }
|
||||
// // calculate new values
|
||||
// double new_val = 0;
|
||||
// for (int i = 0; i < output.size(); i++) {
|
||||
// new_val += buffer[i] * output[i][interp_i];
|
||||
// }
|
||||
// results[interp_i] = new_val / inv_sum;
|
||||
// if (std::isnan(results[interp_i])) {
|
||||
// std::cout << "nan with new_val = " << output[0][i] << std::endl;
|
||||
// }
|
||||
// }
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
@ -170,7 +118,8 @@ poet::ChemistryModule::ChemistryModule(
|
||||
|
||||
if (!is_master) {
|
||||
PhreeqcMatrix pqc_mat =
|
||||
PhreeqcMatrix(chem_params.database, chem_params.pqc_script);
|
||||
PhreeqcMatrix(chem_params.database, chem_params.pqc_script,
|
||||
chem_params.with_h0_o0, chem_params.with_redox);
|
||||
|
||||
this->pqc_runner =
|
||||
std::make_unique<PhreeqcRunner>(pqc_mat.subset(chem_params.pqc_ids));
|
||||
@ -184,11 +133,10 @@ poet::ChemistryModule::~ChemistryModule() {
|
||||
}
|
||||
|
||||
void poet::ChemistryModule::initializeDHT(
|
||||
uint32_t size_mb, const NamedVector<std::uint32_t> &key_species) {
|
||||
uint32_t size_mb, const NamedVector<std::uint32_t> &key_species,
|
||||
bool has_het_ids) {
|
||||
constexpr uint32_t MB_FACTOR = 1E6;
|
||||
|
||||
this->dht_enabled = true;
|
||||
|
||||
MPI_Comm dht_comm;
|
||||
|
||||
if (this->is_master) {
|
||||
@ -218,7 +166,7 @@ void poet::ChemistryModule::initializeDHT(
|
||||
|
||||
this->dht = new DHT_Wrapper(dht_comm, dht_size, map_copy, key_indices,
|
||||
this->prop_names, params.hooks,
|
||||
this->prop_count, interp_enabled);
|
||||
this->prop_count, interp_enabled, has_het_ids);
|
||||
this->dht->setBaseTotals(base_totals.at(0), base_totals.at(1));
|
||||
}
|
||||
}
|
||||
@ -309,9 +257,10 @@ void poet::ChemistryModule::initializeInterp(
|
||||
map_copy = this->dht->getKeySpecies();
|
||||
for (auto i = 0; i < map_copy.size(); i++) {
|
||||
const std::uint32_t signif =
|
||||
static_cast<std::uint32_t>(map_copy[i]) - (map_copy[i] > InterpolationModule::COARSE_DIFF
|
||||
? InterpolationModule::COARSE_DIFF
|
||||
: 0);
|
||||
static_cast<std::uint32_t>(map_copy[i]) -
|
||||
(map_copy[i] > InterpolationModule::COARSE_DIFF
|
||||
? InterpolationModule::COARSE_DIFF
|
||||
: 0);
|
||||
map_copy[i] = signif;
|
||||
}
|
||||
}
|
||||
@ -368,7 +317,8 @@ void poet::ChemistryModule::unshuffleField(const std::vector<double> &in_buffer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void poet::ChemistryModule::set_ai_surrogate_validity_vector(std::vector<int> r_vector) {
|
||||
|
||||
void poet::ChemistryModule::set_ai_surrogate_validity_vector(
|
||||
std::vector<int> r_vector) {
|
||||
this->ai_surrogate_validity_vector = r_vector;
|
||||
}
|
||||
|
||||
@ -76,9 +76,13 @@ public:
|
||||
|
||||
struct SurrogateSetup {
|
||||
std::vector<std::string> prop_names;
|
||||
std::array<double, 2> base_totals;
|
||||
bool has_het_ids;
|
||||
|
||||
bool dht_enabled;
|
||||
std::uint32_t dht_size_mb;
|
||||
int dht_snaps;
|
||||
std::string dht_out_dir;
|
||||
|
||||
bool interp_enabled;
|
||||
std::uint32_t interp_bucket_size;
|
||||
@ -96,8 +100,15 @@ public:
|
||||
this->interp_enabled = setup.interp_enabled;
|
||||
this->ai_surrogate_enabled = setup.ai_surrogate_enabled;
|
||||
|
||||
this->base_totals = setup.base_totals;
|
||||
|
||||
if (this->dht_enabled || this->interp_enabled) {
|
||||
this->initializeDHT(setup.dht_size_mb, this->params.dht_species);
|
||||
this->initializeDHT(setup.dht_size_mb, this->params.dht_species,
|
||||
setup.has_het_ids);
|
||||
|
||||
if (setup.dht_snaps != DHT_SNAPS_DISABLED) {
|
||||
this->setDHTSnapshots(setup.dht_snaps, setup.dht_out_dir);
|
||||
}
|
||||
}
|
||||
|
||||
if (this->interp_enabled) {
|
||||
@ -223,8 +234,8 @@ public:
|
||||
};
|
||||
|
||||
/**
|
||||
* **Master only** Set the ai surrogate validity vector from R
|
||||
*/
|
||||
* **Master only** Set the ai surrogate validity vector from R
|
||||
*/
|
||||
void set_ai_surrogate_validity_vector(std::vector<int> r_vector);
|
||||
|
||||
std::vector<uint32_t> GetWorkerInterpolationCalls() const;
|
||||
@ -240,7 +251,8 @@ public:
|
||||
|
||||
protected:
|
||||
void initializeDHT(uint32_t size_mb,
|
||||
const NamedVector<std::uint32_t> &key_species);
|
||||
const NamedVector<std::uint32_t> &key_species,
|
||||
bool has_het_ids);
|
||||
void setDHTSnapshots(int type, const std::string &out_dir);
|
||||
void setDHTReadFile(const std::string &input_file);
|
||||
|
||||
|
||||
@ -43,11 +43,12 @@ DHT_Wrapper::DHT_Wrapper(MPI_Comm dht_comm, std::uint64_t dht_size,
|
||||
const std::vector<std::int32_t> &key_indices,
|
||||
const std::vector<std::string> &_output_names,
|
||||
const InitialList::ChemistryHookFunctions &_hooks,
|
||||
uint32_t data_count, bool _with_interp)
|
||||
uint32_t data_count, bool _with_interp,
|
||||
bool _has_het_ids)
|
||||
: key_count(key_indices.size()), data_count(data_count),
|
||||
input_key_elements(key_indices), communicator(dht_comm),
|
||||
key_species(key_species), output_names(_output_names), hooks(_hooks),
|
||||
with_interp(_with_interp) {
|
||||
with_interp(_with_interp), has_het_ids(_has_het_ids) {
|
||||
// initialize DHT object
|
||||
// key size = count of key elements + timestep
|
||||
uint32_t key_size = (key_count + 1) * sizeof(Lookup_Keyelement);
|
||||
@ -128,42 +129,43 @@ void DHT_Wrapper::fillDHT(const WorkPackage &work_package) {
|
||||
dht_results.filledDHT = std::vector<bool>(length, false);
|
||||
for (int i = 0; i < length; i++) {
|
||||
// If true grid cell was simulated, needs to be inserted into dht
|
||||
if (work_package.mapping[i] == CHEM_PQC) {
|
||||
|
||||
// check if calcite or dolomite is absent and present, resp.n and vice
|
||||
// versa in input/output. If this is the case -> Do not write to DHT!
|
||||
// HACK: hardcoded, should be fixed!
|
||||
if (hooks.dht_fill.isValid()) {
|
||||
NamedVector<double> old_values(output_names, work_package.input[i]);
|
||||
NamedVector<double> new_values(output_names, work_package.output[i]);
|
||||
|
||||
if (hooks.dht_fill(old_values, new_values)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t proc, index;
|
||||
auto &key = dht_results.keys[i];
|
||||
const auto data =
|
||||
(with_interp ? outputToInputAndRates(work_package.input[i],
|
||||
work_package.output[i])
|
||||
: work_package.output[i]);
|
||||
// void *data = (void *)&(work_package[i * this->data_count]);
|
||||
// fuzz data (round, logarithm etc.)
|
||||
|
||||
// insert simulated data with fuzzed key into DHT
|
||||
int res = DHT_write(this->dht_object, key.data(),
|
||||
const_cast<double *>(data.data()), &proc, &index);
|
||||
|
||||
dht_results.locations[i] = {proc, index};
|
||||
|
||||
// if data was successfully written ...
|
||||
if ((res != DHT_SUCCESS) && (res == DHT_WRITE_SUCCESS_WITH_EVICTION)) {
|
||||
dht_evictions++;
|
||||
}
|
||||
|
||||
dht_results.filledDHT[i] = true;
|
||||
if (work_package.mapping[i] != CHEM_PQC) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// check if calcite or dolomite is absent and present, resp.n and vice
|
||||
// versa in input/output. If this is the case -> Do not write to DHT!
|
||||
// HACK: hardcoded, should be fixed!
|
||||
if (hooks.dht_fill.isValid()) {
|
||||
NamedVector<double> old_values(output_names, work_package.input[i]);
|
||||
NamedVector<double> new_values(output_names, work_package.output[i]);
|
||||
|
||||
if (hooks.dht_fill(old_values, new_values)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t proc, index;
|
||||
auto &key = dht_results.keys[i];
|
||||
const auto data =
|
||||
(with_interp ? outputToInputAndRates(work_package.input[i],
|
||||
work_package.output[i])
|
||||
: work_package.output[i]);
|
||||
// void *data = (void *)&(work_package[i * this->data_count]);
|
||||
// fuzz data (round, logarithm etc.)
|
||||
|
||||
// insert simulated data with fuzzed key into DHT
|
||||
int res = DHT_write(this->dht_object, key.data(),
|
||||
const_cast<double *>(data.data()), &proc, &index);
|
||||
|
||||
dht_results.locations[i] = {proc, index};
|
||||
|
||||
// if data was successfully written ...
|
||||
if ((res != DHT_SUCCESS) && (res == DHT_WRITE_SUCCESS_WITH_EVICTION)) {
|
||||
dht_evictions++;
|
||||
}
|
||||
|
||||
dht_results.filledDHT[i] = true;
|
||||
}
|
||||
}
|
||||
|
||||
@ -270,7 +272,7 @@ LookupKey DHT_Wrapper::fuzzForDHT_R(const std::vector<double> &cell,
|
||||
const std::vector<double> eval_vec =
|
||||
Rcpp::as<std::vector<double>>(hooks.dht_fuzz(input_nv));
|
||||
assert(eval_vec.size() == this->key_count);
|
||||
LookupKey vecFuzz(this->key_count + 1, {.0});
|
||||
LookupKey vecFuzz(this->key_count + 1 + has_het_ids, {.0});
|
||||
|
||||
DHT_Rounder rounder;
|
||||
|
||||
@ -290,6 +292,9 @@ LookupKey DHT_Wrapper::fuzzForDHT_R(const std::vector<double> &cell,
|
||||
}
|
||||
// add timestep to the end of the key as double value
|
||||
vecFuzz[this->key_count].fp_element = dt;
|
||||
if (has_het_ids) {
|
||||
vecFuzz[this->key_count + 1].fp_element = cell[0];
|
||||
}
|
||||
|
||||
return vecFuzz;
|
||||
}
|
||||
@ -297,7 +302,7 @@ LookupKey DHT_Wrapper::fuzzForDHT_R(const std::vector<double> &cell,
|
||||
LookupKey DHT_Wrapper::fuzzForDHT(const std::vector<double> &cell, double dt) {
|
||||
const auto c_zero_val = std::pow(10, AQUEOUS_EXP);
|
||||
|
||||
LookupKey vecFuzz(this->key_count + 1, {.0});
|
||||
LookupKey vecFuzz(this->key_count + 1 + has_het_ids, {.0});
|
||||
DHT_Rounder rounder;
|
||||
|
||||
int totals_i = 0;
|
||||
@ -323,6 +328,9 @@ LookupKey DHT_Wrapper::fuzzForDHT(const std::vector<double> &cell, double dt) {
|
||||
}
|
||||
// add timestep to the end of the key as double value
|
||||
vecFuzz[this->key_count].fp_element = dt;
|
||||
if (has_het_ids) {
|
||||
vecFuzz[this->key_count + 1].fp_element = cell[0];
|
||||
}
|
||||
|
||||
return vecFuzz;
|
||||
}
|
||||
|
||||
@ -87,7 +87,7 @@ public:
|
||||
const std::vector<std::int32_t> &key_indices,
|
||||
const std::vector<std::string> &output_names,
|
||||
const InitialList::ChemistryHookFunctions &hooks,
|
||||
uint32_t data_count, bool with_interp);
|
||||
uint32_t data_count, bool with_interp, bool has_het_ids);
|
||||
/**
|
||||
* @brief Destroy the dht wrapper object
|
||||
*
|
||||
@ -264,6 +264,7 @@ private:
|
||||
DHT_ResultObject dht_results;
|
||||
|
||||
std::array<double, 2> base_totals{0};
|
||||
bool has_het_ids{false};
|
||||
};
|
||||
} // namespace poet
|
||||
|
||||
|
||||
@ -261,6 +261,8 @@ private:
|
||||
const InitialList::ChemistryHookFunctions &hooks;
|
||||
const std::vector<std::string> &out_names;
|
||||
const std::vector<std::string> dht_names;
|
||||
|
||||
std::unordered_map<int, std::vector<std::int32_t>> to_calc_cache;
|
||||
};
|
||||
} // namespace poet
|
||||
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
@ -116,10 +117,25 @@ void InterpolationModule::tryInterpolation(WorkPackage &work_package) {
|
||||
this->pht->incrementReadCounter(roundKey(rounded_key));
|
||||
#endif
|
||||
|
||||
const int cell_id = static_cast<int>(work_package.input[wp_i][0]);
|
||||
|
||||
if (!to_calc_cache.contains(cell_id)) {
|
||||
const std::vector<std::int32_t> &to_calc = dht_instance.getKeyElements();
|
||||
std::vector<std::int32_t> keep_indices;
|
||||
|
||||
for (std::size_t i = 0; i < to_calc.size(); i++) {
|
||||
if (!std::isnan(work_package.input[wp_i][to_calc[i]])) {
|
||||
keep_indices.push_back(to_calc[i]);
|
||||
}
|
||||
}
|
||||
|
||||
to_calc_cache[cell_id] = keep_indices;
|
||||
}
|
||||
|
||||
double start_fc = MPI_Wtime();
|
||||
|
||||
work_package.output[wp_i] =
|
||||
f_interpolate(dht_instance.getKeyElements(), work_package.input[wp_i],
|
||||
f_interpolate(to_calc_cache[cell_id], work_package.input[wp_i],
|
||||
pht_result.in_values, pht_result.out_values);
|
||||
|
||||
if (hooks.interp_post.isValid()) {
|
||||
|
||||
@ -10,9 +10,21 @@
|
||||
|
||||
namespace poet {
|
||||
|
||||
constexpr std::int8_t SC_NOTATION_EXPONENT_MASK = -128;
|
||||
constexpr std::int64_t SC_NOTATION_SIGNIFICANT_MASK = 0xFFFFFFFFFFFF;
|
||||
|
||||
struct Lookup_SC_notation {
|
||||
std::int8_t exp : 8;
|
||||
std::int64_t significant : 56;
|
||||
|
||||
constexpr static Lookup_SC_notation nan() {
|
||||
return {SC_NOTATION_EXPONENT_MASK, SC_NOTATION_SIGNIFICANT_MASK};
|
||||
}
|
||||
|
||||
constexpr bool isnan() const {
|
||||
return !!(exp == SC_NOTATION_EXPONENT_MASK &&
|
||||
significant == SC_NOTATION_SIGNIFICANT_MASK);
|
||||
}
|
||||
};
|
||||
|
||||
union Lookup_Keyelement {
|
||||
@ -23,6 +35,10 @@ union Lookup_Keyelement {
|
||||
return std::memcmp(this, &other, sizeof(Lookup_Keyelement)) == 0 ? true
|
||||
: false;
|
||||
}
|
||||
|
||||
template <typename T> bool operator>(const T &other) const {
|
||||
return this->sc_notation.significant > other;
|
||||
}
|
||||
};
|
||||
|
||||
class LookupKey : public std::vector<Lookup_Keyelement> {
|
||||
|
||||
@ -20,6 +20,11 @@ class DHT_Rounder {
|
||||
public:
|
||||
Lookup_Keyelement round(const double &value, std::uint32_t signif,
|
||||
bool is_ho) {
|
||||
|
||||
if (std::isnan(value)) {
|
||||
return {.sc_notation = Lookup_SC_notation::nan()};
|
||||
}
|
||||
|
||||
std::int8_t exp =
|
||||
static_cast<std::int8_t>(std::floor(std::log10(std::fabs(value))));
|
||||
|
||||
@ -60,6 +65,14 @@ public:
|
||||
std::uint32_t signif) {
|
||||
Lookup_Keyelement new_val = value;
|
||||
|
||||
if (value.sc_notation.isnan()) {
|
||||
return {.sc_notation = Lookup_SC_notation::nan()};
|
||||
}
|
||||
|
||||
if (signif == 0) {
|
||||
return {.sc_notation = {0, value > 0}};
|
||||
}
|
||||
|
||||
std::uint32_t diff_signif =
|
||||
static_cast<std::uint32_t>(
|
||||
std::ceil(std::log10(std::abs(value.sc_notation.significant)))) -
|
||||
|
||||
@ -69,6 +69,8 @@ InitialList::ChemistryInit InitialList::getChemistryInit() const {
|
||||
chem_init.database = database;
|
||||
chem_init.pqc_script = pqc_script;
|
||||
chem_init.pqc_ids = pqc_ids;
|
||||
chem_init.with_h0_o0 = with_h0_o0;
|
||||
chem_init.with_redox = with_redox;
|
||||
// chem_init.pqc_scripts = pqc_scripts;
|
||||
// chem_init.pqc_ids = pqc_ids;
|
||||
|
||||
|
||||
@ -6,10 +6,8 @@
|
||||
#include <Rcpp/Function.h>
|
||||
#include <Rcpp/vector/Matrix.h>
|
||||
#include <Rcpp/vector/instantiation.h>
|
||||
#include <cstdint>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <regex>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
@ -17,42 +15,6 @@
|
||||
|
||||
namespace poet {
|
||||
|
||||
// static Rcpp::NumericMatrix pqcMatToR(const PhreeqcMatrix &phreeqc, RInside
|
||||
// &R) {
|
||||
|
||||
// PhreeqcMatrix::STLExport phreeqc_mat = phreeqc.get();
|
||||
|
||||
// // PhreeqcInit::PhreeqcMat phreeqc_mat = phreeqc->getPhreeqcMat();
|
||||
|
||||
// // add "id" to the front of the column names
|
||||
|
||||
// const std::size_t ncols = phreeqc_mat.names.size();
|
||||
// const std::size_t nrows = phreeqc_mat.values.size();
|
||||
|
||||
// phreeqc_mat.names.insert(phreeqc_mat.names.begin(), "ID");
|
||||
|
||||
// Rcpp::NumericMatrix mat(nrows, ncols + 1);
|
||||
|
||||
// for (std::size_t i = 0; i < nrows; i++) {
|
||||
// mat(i, 0) = phreeqc_mat.ids[i];
|
||||
// for (std::size_t j = 0; j < ncols; ++j) {
|
||||
// mat(i, j + 1) = phreeqc_mat.values[i][j];
|
||||
// }
|
||||
// }
|
||||
|
||||
// Rcpp::colnames(mat) = Rcpp::wrap(phreeqc_mat.names);
|
||||
|
||||
// return mat;
|
||||
// }
|
||||
|
||||
// static inline Rcpp::List matToGrid(RInside &R, const Rcpp::NumericMatrix
|
||||
// &mat,
|
||||
// const Rcpp::NumericMatrix &grid) {
|
||||
// Rcpp::Function pqc_to_grid_R("pqc_to_grid");
|
||||
|
||||
// return pqc_to_grid_R(mat, grid);
|
||||
// }
|
||||
|
||||
static inline std::map<int, std::string>
|
||||
replaceRawKeywordIDs(std::map<int, std::string> raws) {
|
||||
for (auto &raw : raws) {
|
||||
@ -66,26 +28,6 @@ replaceRawKeywordIDs(std::map<int, std::string> raws) {
|
||||
return raws;
|
||||
}
|
||||
|
||||
// static inline uint32_t getSolutionCount(std::unique_ptr<PhreeqcInit>
|
||||
// &phreeqc,
|
||||
// const Rcpp::List &initial_grid) {
|
||||
// PhreeqcInit::ModulesArray mod_array;
|
||||
// Rcpp::Function unique_R("unique");
|
||||
|
||||
// std::vector<int> row_ids =
|
||||
// Rcpp::as<std::vector<int>>(unique_R(initial_grid["ID"]));
|
||||
|
||||
// // std::vector<std::uint32_t> sizes_vec(sizes.begin(), sizes.end());
|
||||
|
||||
// // Rcpp::Function modify_sizes("modify_module_sizes");
|
||||
// // sizes_vec = Rcpp::as<std::vector<std::uint32_t>>(
|
||||
// // modify_sizes(sizes_vec, phreeqc_mat, initial_grid));
|
||||
|
||||
// // std::copy(sizes_vec.begin(), sizes_vec.end(), sizes.begin());
|
||||
|
||||
// return phreeqc->getModuleSizes(row_ids)[POET_SOL];
|
||||
// }
|
||||
|
||||
static std::string readFile(const std::string &path) {
|
||||
std::string string_rpath(PATH_MAX, '\0');
|
||||
|
||||
@ -180,7 +122,30 @@ PhreeqcMatrix InitialList::prepareGrid(const Rcpp::List &grid_input) {
|
||||
throw std::runtime_error("Grid size must be positive.");
|
||||
}
|
||||
|
||||
PhreeqcMatrix pqc_mat = PhreeqcMatrix(database, script);
|
||||
bool with_redox =
|
||||
grid_input.containsElementNamed(
|
||||
GRID_MEMBER_STR(GridMembers::PQC_WITH_REDOX))
|
||||
? Rcpp::as<bool>(
|
||||
grid_input[GRID_MEMBER_STR(GridMembers::PQC_WITH_REDOX)])
|
||||
: false;
|
||||
|
||||
bool with_h0_o0 =
|
||||
grid_input.containsElementNamed(
|
||||
GRID_MEMBER_STR(GridMembers::PQC_WITH_H0_O0))
|
||||
? Rcpp::as<bool>(
|
||||
grid_input[GRID_MEMBER_STR(GridMembers::PQC_WITH_H0_O0)])
|
||||
: false;
|
||||
|
||||
if (with_h0_o0 && !with_redox) {
|
||||
throw std::runtime_error(
|
||||
"Output of H(0) and O(0) can only be used with redox.");
|
||||
}
|
||||
|
||||
this->with_h0_o0 = with_h0_o0;
|
||||
this->with_redox = with_redox;
|
||||
|
||||
PhreeqcMatrix pqc_mat =
|
||||
PhreeqcMatrix(database, script, with_h0_o0, with_redox);
|
||||
|
||||
this->transport_names = pqc_mat.getSolutionNames();
|
||||
|
||||
|
||||
@ -56,6 +56,10 @@ void InitialList::importList(const Rcpp::List &setup, bool minimal) {
|
||||
Rcpp::as<std::string>(setup[static_cast<int>(ExportList::CHEM_DATABASE)]);
|
||||
this->pqc_script = Rcpp::as<std::string>(
|
||||
setup[static_cast<int>(ExportList::CHEM_PQC_SCRIPT)]);
|
||||
this->with_h0_o0 =
|
||||
Rcpp::as<bool>(setup[static_cast<int>(ExportList::CHEM_PQC_WITH_H0_O0)]);
|
||||
this->with_redox =
|
||||
Rcpp::as<bool>(setup[static_cast<int>(ExportList::CHEM_PQC_WITH_REDOX)]);
|
||||
this->field_header = Rcpp::as<std::vector<std::string>>(
|
||||
setup[static_cast<int>(ExportList::CHEM_FIELD_HEADER)]);
|
||||
this->pqc_ids = Rcpp::as<std::vector<int>>(
|
||||
@ -111,6 +115,10 @@ Rcpp::List InitialList::exportList() {
|
||||
out[static_cast<int>(ExportList::CHEM_DATABASE)] = Rcpp::wrap(this->database);
|
||||
out[static_cast<int>(ExportList::CHEM_PQC_SCRIPT)] =
|
||||
Rcpp::wrap(this->pqc_script);
|
||||
out[static_cast<int>(ExportList::CHEM_PQC_WITH_H0_O0)] =
|
||||
Rcpp::wrap(this->with_h0_o0);
|
||||
out[static_cast<int>(ExportList::CHEM_PQC_WITH_REDOX)] =
|
||||
Rcpp::wrap(this->with_redox);
|
||||
out[static_cast<int>(ExportList::CHEM_FIELD_HEADER)] =
|
||||
Rcpp::wrap(this->field_header);
|
||||
out[static_cast<int>(ExportList::CHEM_PQC_IDS)] = Rcpp::wrap(this->pqc_ids);
|
||||
|
||||
@ -14,7 +14,6 @@
|
||||
#include <cstdint>
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
@ -34,7 +33,7 @@ public:
|
||||
void importList(const Rcpp::List &setup, bool minimal = false);
|
||||
Rcpp::List exportList();
|
||||
|
||||
Field getInitialGrid() const { return Field(this->initial_grid); }
|
||||
Field getInitialGrid() const { return Field(this->initial_grid); }
|
||||
|
||||
private:
|
||||
RInside &R;
|
||||
@ -53,16 +52,10 @@ private:
|
||||
DIFFU_ALPHA_Y,
|
||||
CHEM_DATABASE,
|
||||
CHEM_PQC_SCRIPT,
|
||||
CHEM_PQC_WITH_H0_O0,
|
||||
CHEM_PQC_WITH_REDOX,
|
||||
CHEM_PQC_IDS,
|
||||
CHEM_FIELD_HEADER,
|
||||
// CHEM_PQC_SCRIPTS,
|
||||
// CHEM_PQC_SOLUTIONS,
|
||||
// CHEM_PQC_SOLUTION_PRIMARY,
|
||||
// CHEM_PQC_EXCHANGER,
|
||||
// CHEM_PQC_KINETICS,
|
||||
// CHEM_PQC_EQUILIBRIUM,
|
||||
// CHEM_PQC_SURFACE_COMPS,
|
||||
// CHEM_PQC_SURFACE_CHARGES,
|
||||
CHEM_DHT_SPECIES,
|
||||
CHEM_INTERP_SPECIES,
|
||||
CHEM_HOOKS,
|
||||
@ -76,6 +69,8 @@ private:
|
||||
enum class GridMembers {
|
||||
PQC_SCRIPT_STRING,
|
||||
PQC_SCRIPT_FILE,
|
||||
PQC_WITH_REDOX,
|
||||
PQC_WITH_H0_O0,
|
||||
PQC_DB_STRING,
|
||||
PQC_DB_FILE,
|
||||
GRID_DEF,
|
||||
@ -89,9 +84,10 @@ private:
|
||||
static_cast<std::size_t>(InitialList::GridMembers::ENUM_SIZE);
|
||||
|
||||
static constexpr std::array<const char *, size_GridMembers>
|
||||
GridMembersString = {"pqc_in_string", "pqc_in_file", "pqc_db_string",
|
||||
"pqc_db_file", "grid_def", "grid_size",
|
||||
"constant_cells", "porosity"};
|
||||
GridMembersString = {"pqc_in_string", "pqc_in_file", "pqc_with_redox",
|
||||
"pqc_wth_h0_o0", "pqc_db_string", "pqc_db_file",
|
||||
"grid_def", "grid_size", "constant_cells",
|
||||
"porosity"};
|
||||
|
||||
constexpr const char *GRID_MEMBER_STR(GridMembers member) const {
|
||||
return GridMembersString[static_cast<std::size_t>(member)];
|
||||
@ -190,13 +186,15 @@ private:
|
||||
|
||||
std::string database;
|
||||
std::string pqc_script;
|
||||
bool with_h0_o0{false};
|
||||
bool with_redox{false};
|
||||
// std::vector<std::string> pqc_scripts;
|
||||
std::vector<int> pqc_ids;
|
||||
|
||||
NamedVector<std::uint32_t> dht_species;
|
||||
|
||||
NamedVector<std::uint32_t> interp_species;
|
||||
|
||||
|
||||
// Path to R script that the user defines in the input file
|
||||
std::string ai_surrogate_input_script;
|
||||
|
||||
@ -220,6 +218,8 @@ public:
|
||||
|
||||
std::string database;
|
||||
std::string pqc_script;
|
||||
bool with_h0_o0;
|
||||
bool with_redox;
|
||||
std::vector<int> pqc_ids;
|
||||
|
||||
// std::map<int, std::string> pqc_input;
|
||||
|
||||
@ -34,13 +34,11 @@ int main(int argc, char **argv) {
|
||||
"input script")
|
||||
->default_val(false);
|
||||
|
||||
bool asRDS;
|
||||
app.add_flag("-r, --rds", asRDS, "Save output as .rds")
|
||||
->default_val(false);
|
||||
bool asRDS{false};
|
||||
app.add_flag("-r, --rds", asRDS, "Save output as .rds")->default_val(false);
|
||||
|
||||
bool asQS;
|
||||
app.add_flag("-q, --qs", asQS, "Save output as .qs")
|
||||
->default_val(false);
|
||||
bool asQS{false};
|
||||
app.add_flag("-q, --qs", asQS, "Save output as .qs")->default_val(false);
|
||||
|
||||
CLI11_PARSE(app, argc, argv);
|
||||
|
||||
@ -74,13 +72,13 @@ int main(int argc, char **argv) {
|
||||
|
||||
// append the correct file extension
|
||||
if (asRDS) {
|
||||
output_file += ".rds";
|
||||
output_file += ".rds";
|
||||
} else if (asQS) {
|
||||
output_file += ".qs";
|
||||
output_file += ".qs";
|
||||
} else {
|
||||
output_file += ".qs2";
|
||||
output_file += ".qs2";
|
||||
}
|
||||
|
||||
|
||||
// set working directory to the directory of the input script
|
||||
if (setwd) {
|
||||
const std::string dir_path = Rcpp::as<std::string>(
|
||||
|
||||
69
src/poet.cpp
69
src/poet.cpp
@ -34,10 +34,13 @@
|
||||
#include <Rcpp/DataFrame.h>
|
||||
#include <Rcpp/Function.h>
|
||||
#include <Rcpp/vector/instantiation.h>
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
#include <mpi.h>
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
#include <CLI/CLI.hpp>
|
||||
@ -148,6 +151,9 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) {
|
||||
app.add_flag("--rds", params.as_rds,
|
||||
"Save output as .rds file instead of default .qs2");
|
||||
|
||||
app.add_flag("--qs", params.as_qs,
|
||||
"Save output as .qs file instead of default .qs2");
|
||||
|
||||
app.add_flag("--qs", params.as_qs,
|
||||
"Save output as .qs file instead of default .qs2");
|
||||
|
||||
@ -178,8 +184,10 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) {
|
||||
|
||||
// set the output extension
|
||||
params.out_ext = "qs2";
|
||||
if (params.as_rds) params.out_ext = "rds";
|
||||
if (params.as_qs) params.out_ext = "qs";
|
||||
if (params.as_rds)
|
||||
params.out_ext = "rds";
|
||||
if (params.as_qs)
|
||||
params.out_ext = "qs";
|
||||
|
||||
if (MY_RANK == 0) {
|
||||
// MSG("Complete results storage is " + BOOL_PRINT(simparams.store_result));
|
||||
@ -293,7 +301,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
const double &dt = params.timesteps[iter - 1];
|
||||
|
||||
std::cout << std::endl;
|
||||
|
||||
|
||||
/* displaying iteration number, with C++ and R iterator */
|
||||
MSG("Going through iteration " + std::to_string(iter) + "/" +
|
||||
std::to_string(maxiter));
|
||||
@ -396,7 +404,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
} // END SIMULATION LOOP
|
||||
|
||||
std::cout << std::endl;
|
||||
|
||||
|
||||
Rcpp::List chem_profiling;
|
||||
chem_profiling["simtime"] = chem.GetChemistryTime();
|
||||
chem_profiling["loop"] = chem.GetMasterLoopTime();
|
||||
@ -483,6 +491,54 @@ std::vector<std::string> getSpeciesNames(const Field &&field, int root,
|
||||
return species_names_out;
|
||||
}
|
||||
|
||||
std::array<double, 2> getBaseTotals(Field &&field, int root, MPI_Comm comm) {
|
||||
std::array<double, 2> base_totals;
|
||||
|
||||
int rank;
|
||||
MPI_Comm_rank(comm, &rank);
|
||||
|
||||
const bool is_master = root == rank;
|
||||
|
||||
if (is_master) {
|
||||
const auto h_col = field["H"];
|
||||
const auto o_col = field["O"];
|
||||
|
||||
base_totals[0] = *std::min_element(h_col.begin(), h_col.end());
|
||||
base_totals[1] = *std::min_element(o_col.begin(), o_col.end());
|
||||
MPI_Bcast(base_totals.data(), 2, MPI_DOUBLE, root, MPI_COMM_WORLD);
|
||||
return base_totals;
|
||||
}
|
||||
|
||||
MPI_Bcast(base_totals.data(), 2, MPI_DOUBLE, root, comm);
|
||||
|
||||
return base_totals;
|
||||
}
|
||||
|
||||
bool getHasID(Field &&field, int root, MPI_Comm comm) {
|
||||
bool has_id;
|
||||
|
||||
int rank;
|
||||
MPI_Comm_rank(comm, &rank);
|
||||
|
||||
const bool is_master = root == rank;
|
||||
|
||||
if (is_master) {
|
||||
const auto ID_field = field["ID"];
|
||||
|
||||
std::set<double> unique_IDs(ID_field.begin(), ID_field.end());
|
||||
|
||||
has_id = unique_IDs.size() > 1;
|
||||
|
||||
MPI_Bcast(&has_id, 1, MPI_C_BOOL, root, MPI_COMM_WORLD);
|
||||
|
||||
return has_id;
|
||||
}
|
||||
|
||||
MPI_Bcast(&has_id, 1, MPI_C_BOOL, root, comm);
|
||||
|
||||
return has_id;
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
int world_size;
|
||||
|
||||
@ -529,10 +585,13 @@ int main(int argc, char *argv[]) {
|
||||
init_list.getChemistryInit(), MPI_COMM_WORLD);
|
||||
|
||||
const ChemistryModule::SurrogateSetup surr_setup = {
|
||||
|
||||
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
||||
getBaseTotals(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
||||
getHasID(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
||||
run_params.use_dht,
|
||||
run_params.dht_size,
|
||||
run_params.dht_snaps,
|
||||
run_params.out_dir,
|
||||
run_params.use_interp,
|
||||
run_params.interp_bucket_entries,
|
||||
run_params.interp_size,
|
||||
|
||||
@ -23,7 +23,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@ -47,27 +46,27 @@ struct RuntimeParameters {
|
||||
|
||||
// MDL added to accomodate for qs::qsave/qread
|
||||
bool as_rds = false;
|
||||
bool as_qs = false;
|
||||
std::string out_ext;
|
||||
bool as_qs = false;
|
||||
std::string out_ext;
|
||||
|
||||
bool print_progress = false;
|
||||
|
||||
static constexpr std::uint32_t WORK_PACKAGE_SIZE_DEFAULT = 32;
|
||||
std::uint32_t work_package_size;
|
||||
std::uint32_t work_package_size = WORK_PACKAGE_SIZE_DEFAULT;
|
||||
|
||||
bool use_dht = false;
|
||||
static constexpr std::uint32_t DHT_SIZE_DEFAULT = 1.5E3;
|
||||
std::uint32_t dht_size;
|
||||
std::uint32_t dht_size = DHT_SIZE_DEFAULT;
|
||||
static constexpr std::uint8_t DHT_SNAPS_DEFAULT = 0;
|
||||
std::uint8_t dht_snaps;
|
||||
std::uint8_t dht_snaps = DHT_SNAPS_DEFAULT;
|
||||
|
||||
bool use_interp = false;
|
||||
static constexpr std::uint32_t INTERP_SIZE_DEFAULT = 100;
|
||||
std::uint32_t interp_size;
|
||||
std::uint32_t interp_size = INTERP_SIZE_DEFAULT;
|
||||
static constexpr std::uint32_t INTERP_MIN_ENTRIES_DEFAULT = 5;
|
||||
std::uint32_t interp_min_entries;
|
||||
std::uint32_t interp_min_entries = INTERP_MIN_ENTRIES_DEFAULT;
|
||||
static constexpr std::uint32_t INTERP_BUCKET_ENTRIES_DEFAULT = 20;
|
||||
std::uint32_t interp_bucket_entries;
|
||||
std::uint32_t interp_bucket_entries = INTERP_BUCKET_ENTRIES_DEFAULT;
|
||||
|
||||
bool use_ai_surrogate = false;
|
||||
};
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user