[wip] fix: add base_totals to SurrogateSetup

This commit is contained in:
Max Luebke 2024-12-12 20:16:02 +01:00 committed by Max Lübke
parent dd9cc5e59f
commit 43d2a846c7
3 changed files with 45 additions and 13 deletions

View File

@ -309,9 +309,10 @@ void poet::ChemistryModule::initializeInterp(
map_copy = this->dht->getKeySpecies();
for (auto i = 0; i < map_copy.size(); i++) {
const std::uint32_t signif =
static_cast<std::uint32_t>(map_copy[i]) - (map_copy[i] > InterpolationModule::COARSE_DIFF
? InterpolationModule::COARSE_DIFF
: 0);
static_cast<std::uint32_t>(map_copy[i]) -
(map_copy[i] > InterpolationModule::COARSE_DIFF
? InterpolationModule::COARSE_DIFF
: 0);
map_copy[i] = signif;
}
}
@ -368,7 +369,8 @@ void poet::ChemistryModule::unshuffleField(const std::vector<double> &in_buffer,
}
}
}
void poet::ChemistryModule::set_ai_surrogate_validity_vector(std::vector<int> r_vector) {
void poet::ChemistryModule::set_ai_surrogate_validity_vector(
std::vector<int> r_vector) {
this->ai_surrogate_validity_vector = r_vector;
}

View File

@ -76,6 +76,7 @@ public:
struct SurrogateSetup {
std::vector<std::string> prop_names;
std::array<double, 2> base_totals;
bool dht_enabled;
std::uint32_t dht_size_mb;
@ -96,6 +97,8 @@ public:
this->interp_enabled = setup.interp_enabled;
this->ai_surrogate_enabled = setup.ai_surrogate_enabled;
this->base_totals = setup.base_totals;
if (this->dht_enabled || this->interp_enabled) {
this->initializeDHT(setup.dht_size_mb, this->params.dht_species);
}
@ -223,8 +226,8 @@ public:
};
/**
* **Master only** Set the ai surrogate validity vector from R
*/
* **Master only** Set the ai surrogate validity vector from R
*/
void set_ai_surrogate_validity_vector(std::vector<int> r_vector);
std::vector<uint32_t> GetWorkerInterpolationCalls() const;

View File

@ -34,6 +34,8 @@
#include <Rcpp/DataFrame.h>
#include <Rcpp/Function.h>
#include <Rcpp/vector/instantiation.h>
#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdlib>
#include <memory>
@ -150,7 +152,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
app.add_flag("--qs", params.as_qs,
"Save output as .qs file instead of default .qs2");
std::string init_file;
std::string runtime_file;
@ -178,8 +180,10 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
// set the output extension
params.out_ext = "qs2";
if (params.as_rds) params.out_ext = "rds";
if (params.as_qs) params.out_ext = "qs";
if (params.as_rds)
params.out_ext = "rds";
if (params.as_qs)
params.out_ext = "qs";
if (MY_RANK == 0) {
// MSG("Complete results storage is " + BOOL_PRINT(simparams.store_result));
@ -293,7 +297,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
const double &dt = params.timesteps[iter - 1];
std::cout << std::endl;
/* displaying iteration number, with C++ and R iterator */
MSG("Going through iteration " + std::to_string(iter) + "/" +
std::to_string(maxiter));
@ -396,7 +400,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
} // END SIMULATION LOOP
std::cout << std::endl;
Rcpp::List chem_profiling;
chem_profiling["simtime"] = chem.GetChemistryTime();
chem_profiling["loop"] = chem.GetMasterLoopTime();
@ -483,6 +487,29 @@ std::vector<std::string> getSpeciesNames(const Field &&field, int root,
return species_names_out;
}
std::array<double, 2> getBaseTotals(Field &&field, int root, MPI_Comm comm) {
std::array<double, 2> base_totals;
int rank;
MPI_Comm_rank(comm, &rank);
const bool is_master = root == rank;
if (is_master) {
const auto h_col = field["H"];
const auto o_col = field["O"];
base_totals[0] = *std::min_element(h_col.begin(), h_col.end());
base_totals[1] = *std::min_element(o_col.begin(), o_col.end());
MPI_Bcast(base_totals.data(), 2, MPI_DOUBLE, root, MPI_COMM_WORLD);
return base_totals;
}
MPI_Bcast(base_totals.data(), 2, MPI_DOUBLE, root, comm);
return base_totals;
}
int main(int argc, char *argv[]) {
int world_size;
@ -529,8 +556,8 @@ int main(int argc, char *argv[]) {
init_list.getChemistryInit(), MPI_COMM_WORLD);
const ChemistryModule::SurrogateSetup surr_setup = {
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
getBaseTotals(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
run_params.use_dht,
run_params.dht_size,
run_params.use_interp,