feat: bcast control_cell_ids to workers

This commit is contained in:
rastogi 2025-11-02 22:26:14 +01:00
parent 9810221182
commit d0faa8e6c3
5 changed files with 554 additions and 554 deletions

View File

@ -2,19 +2,16 @@
#ifndef CHEMISTRYMODULE_H_
#define CHEMISTRYMODULE_H_
#include "ChemistryDefs.hpp"
#include "Control/ControlModule.hpp"
#include "DataStructures/Field.hpp"
#include "DataStructures/NamedVector.hpp"
#include "ChemistryDefs.hpp"
#include "Init/InitialList.hpp"
#include "NameDouble.h"
#include "PhreeqcRunner.hpp"
#include "SurrogateModels/DHT_Wrapper.hpp"
#include "SurrogateModels/Interpolation.hpp"
#include "poet.hpp"
#include "PhreeqcRunner.hpp"
#include <array>
#include <cstdint>
#include <map>
@ -24,6 +21,7 @@
#include <vector>
namespace poet {
class ControlModule;
/**
* \brief Wrapper around PhreeqcRM to provide POET specific parallelization with
* easy access.
@ -173,7 +171,7 @@ public:
/**
* **Master only** Return the time in seconds the master spent in the
* send/receive loop.
*/
*/
auto GetMasterLoopTime() const { return this->send_recv_t; }
auto GetMasterRecvCtrlDataTime() const { return this->recv_ctrl_t; }
@ -211,6 +209,8 @@ public:
*/
std::vector<double> GetWorkerIdleTimings() const;
std::vector<double> GetWorkerControlTimings() const;
/**
* **Master only** Collect and return DHT hits of all workers.
*
@ -257,25 +257,15 @@ public:
std::vector<int> ai_surrogate_validity_vector;
RuntimeParameters *runtime_params = nullptr;
uint32_t control_iteration_counter = 0;
void SetControlModule(poet::ControlModule *ctrl) { control_module = ctrl; }
struct error_stats {
std::vector<double> mape;
std::vector<double> rrsme;
uint32_t iteration;
void SetDhtEnabled(bool enabled) { dht_enabled = enabled; }
bool GetDhtEnabled() const { return dht_enabled; }
error_stats(size_t species_count, size_t iter)
: mape(species_count, 0.0), rrsme(species_count, 0.0), iteration(iter) {
}
};
void SetInterpEnabled(bool enabled) { interp_enabled = enabled; }
bool GetInterpEnabled() const { return interp_enabled; }
std::vector<error_stats> error_stats_history;
static void computeStats(const std::vector<double> &pqc_vector,
const std::vector<double> &sur_vector,
uint32_t size_per_prop, uint32_t species_count,
error_stats &stats);
void SetWarmupEnabled(bool enabled) { warmup_enabled = enabled; }
protected:
void initializeDHT(uint32_t size_mb,
@ -290,12 +280,13 @@ protected:
enum {
CHEM_FIELD_INIT,
CHEM_DHT_ENABLE,
//CHEM_DHT_ENABLE,
CHEM_DHT_SIGNIF_VEC,
CHEM_DHT_SNAPS,
CHEM_DHT_READ_FILE,
CHEM_INTERP,
CHEM_IP_ENABLE,
//CHEM_WARMUP_PHASE, // Control flag
//CHEM_CTRL_ENABLE, // Control flag
//CHEM_IP_ENABLE,
CHEM_IP_MIN_ENTRIES,
CHEM_IP_SIGNIF_VEC,
CHEM_WORK_LOOP,
@ -308,6 +299,7 @@ protected:
enum {
WORKER_PHREEQC,
WORKER_CTRL_ITER,
WORKER_DHT_GET,
WORKER_DHT_FILL,
WORKER_IDLE,
@ -330,6 +322,7 @@ protected:
double dht_get = 0.;
double dht_fill = 0.;
double idle_t = 0.;
double ctrl_t = 0.;
};
struct worker_info_s {
@ -347,7 +340,7 @@ protected:
void MasterSendPkgs(worker_list_t &w_list, workpointer_t &work_pointer,
workpointer_t &sur_pointer, int &pkg_to_send,
int &count_pkgs, int &free_workers, double dt,
uint32_t iteration, uint32_t control_iteration,
uint32_t iteration,
const std::vector<uint32_t> &wp_sizes_vector);
void MasterRecvPkgs(worker_list_t &w_list, int &pkg_to_recv, bool to_send,
int &free_workers);
@ -385,6 +378,10 @@ protected:
void BCastStringVec(std::vector<std::string> &io);
int packResultsIntoBuffer(std::vector<double> &mpi_buffer, int base_count,
const WorkPackage &wp,
const WorkPackage &wp_control);
int comm_size, comm_rank;
MPI_Comm group_comm;
@ -412,6 +409,7 @@ protected:
inline void PropagateFunctionType(int &type) const {
ChemBCast(&type, 1, MPI_INT);
}
double simtime = 0.;
double idle_t = 0.;
double seq_t = 0.;
@ -419,10 +417,9 @@ protected:
double recv_ctrl_t = 0.;
double shuf_t = 0.;
double metrics_t = 0.
double metrics_t = 0.;
std::array<double, 2>
base_totals{0};
std::array<double, 2> base_totals{0};
bool print_progessbar{false};
@ -442,8 +439,12 @@ protected:
poet::ControlModule *control_module = nullptr;
std::vector<double> mpi_surr_buffer;
bool control_enabled{false};
bool warmup_enabled{false};
// std::vector<double> sur_shuffled;
};
} // namespace poet

View File

@ -3,7 +3,6 @@
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iomanip>
#include <mpi.h>
#include <vector>
@ -41,6 +40,12 @@ std::vector<double> poet::ChemistryModule::GetWorkerPhreeqcTimings() const {
return MasterGatherWorkerTimings(WORKER_PHREEQC);
}
std::vector<double> poet::ChemistryModule::GetWorkerControlTimings() const {
int type = CHEM_PERF;
MPI_Bcast(&type, 1, MPI_INT, 0, this->group_comm);
return MasterGatherWorkerTimings(WORKER_CTRL_ITER);
}
std::vector<double> poet::ChemistryModule::GetWorkerDHTGetTimings() const {
int type = CHEM_PERF;
MPI_Bcast(&type, 1, MPI_INT, 0, this->group_comm);
@ -252,6 +257,8 @@ inline void poet::ChemistryModule::MasterSendPkgs(
/* note current processed work package in workerlist */
w_list[p].send_addr = work_pointer.base();
w_list[p].surrogate_addr = sur_pointer.base();
// this->control_enabled ? sur_pointer.base() : w_list[p].surrogate_addr =
// nullptr;
/* push work pointer to next work package */
const uint32_t end_of_wp = local_work_package_size * this->prop_count;
@ -349,6 +356,11 @@ inline void poet::ChemistryModule::MasterRecvPkgs(worker_list_t &w_list,
std::copy(recv_buffer.begin(), recv_buffer.begin() + half,
w_list[p - 1].send_addr);
/*
if (w_list[p - 1].surrogate_addr == nullptr) {
throw std::runtime_error("MasterRecvPkgs: surrogate_addr is null");
}*/
std::copy(recv_buffer.begin() + (size / 2), recv_buffer.begin() + size,
w_list[p - 1].surrogate_addr);
recv_ctrl_b = MPI_Wtime();
@ -418,6 +430,7 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
int free_workers;
int i_pkgs;
int ftype;
double shuf_a, shuf_b, metrics_a, metrics_b;
const std::vector<uint32_t> wp_sizes_vector =
CalculateWPSizesVector(this->n_cells, this->wp_size);
@ -435,47 +448,34 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
ftype = CHEM_WORK_LOOP;
PropagateFunctionType(ftype);
ftype = CHEM_INTERP;
PropagateFunctionType(ftype);
if (this->runtime_params->rollback_simulation) {
this->interp_enabled = false;
int interp_flag = 0;
ChemBCast(&interp_flag, 1, MPI_INT);
} else {
this->interp_enabled = true;
int interp_flag = 1;
ChemBCast(&interp_flag, 1, MPI_INT);
}
MPI_Barrier(this->group_comm);
static uint32_t iteration = 0;
uint32_t control_iteration = static_cast<uint32_t>(
this->runtime_params->control_iteration_active ? 1 : 0);
if (control_iteration) {
sur_shuffled.clear();
sur_shuffled.reserve(this->n_cells * this->prop_count);
this->control_enabled = this->control_module->getControlIntervalEnabled();
if (this->control_enabled) {
this->mpi_surr_buffer.assign(this->n_cells * this->prop_count, 0.0);
}
static uint32_t iteration = 0;
/* start time measurement of sequential part */
seq_a = MPI_Wtime();
/* shuffle grid */
// grid.shuffleAndExport(mpi_buffer);
std::vector<double> mpi_buffer =
shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count,
wp_sizes_vector.size());
this->sur_shuffled.resize(mpi_buffer.size());
//this->mpi_surr_buffer.resize(mpi_buffer.size());
/* setup local variables */
pkg_to_send = wp_sizes_vector.size();
pkg_to_recv = wp_sizes_vector.size();
workpointer_t work_pointer = mpi_buffer.begin();
workpointer_t sur_pointer = sur_shuffled.begin();
workpointer_t sur_pointer = this->mpi_surr_buffer.begin();
//(this->control_enabled ? this->mpi_surr_buffer.begin()
// : mpi_buffer.end());
worker_list_t worker_list(this->comm_size - 1);
free_workers = this->comm_size - 1;
@ -499,8 +499,7 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
if (pkg_to_send > 0) {
// send packages to all free workers ...
MasterSendPkgs(worker_list, work_pointer, sur_pointer, pkg_to_send,
i_pkgs, free_workers, dt, iteration, control_iteration,
wp_sizes_vector);
i_pkgs, free_workers, dt, iteration, wp_sizes_vector);
}
// ... and try to receive them from workers who has finished their work
MasterRecvPkgs(worker_list, pkg_to_recv, pkg_to_send > 0, free_workers);
@ -524,15 +523,13 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
chem_field = out_vec;
/* do master stuff */
/* do master stuff */
if (control_enabled) {
if (this->control_enabled) {
std::cout << "[Master] Control logic enabled for this iteration."
<< std::endl;
std::vector<double> sur_unshuffled{mpi_surr_buffer};
shuf_a = MPI_Wtime();
unshuffleField(mpi_surr_buffer, this->n_cells, this->prop_count,
unshuffleField(this->mpi_surr_buffer, this->n_cells, this->prop_count,
wp_sizes_vector.size(), sur_unshuffled);
shuf_b = MPI_Wtime();
this->shuf_t += shuf_b - shuf_a;
@ -550,7 +547,6 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
this->metrics_t += metrics_b - metrics_a;
}
/* start time measurement of master chemistry */
sim_e_chemistry = MPI_Wtime();

File diff suppressed because it is too large Load Diff

View File

@ -52,7 +52,7 @@ public:
std::uint32_t control_interval;
std::vector<std::string> species_names;
std::vector<double> mape_threshold;
std::vector<double> ctrl_cell_ids;
std::vector<uint32_t> ctrl_cell_ids;
};
void enableControlLogic(const ControlSetup &setup) {

View File

@ -250,12 +250,13 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
params.timesteps =
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("timesteps"));
params.checkpoint_interval = Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval"));
params.checkpoint_interval =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval"));
params.control_interval =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("control_interval"));
params.mape_threshold = Rcpp::as<std::vector<double>>(
global_rt_setup->operator[]("mape_threshold"));
params.ctrl_cell_ids = Rcpp::as<std::vector<uint32_t>>(
params.ctrl_cell_ids = Rcpp::as<std::vector<uint32_t>>(
global_rt_setup->operator[]("ctrl_cell_ids"));
catch (const std::exception &e) {
@ -465,6 +466,30 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
return profiling;
}
static void getControlCellIds(const vector<uint32_t> &ids, int root,
MPI_Comm comm) {
std::uint32_t n_ids = 0;
int rank;
MPI_Comm_rank(comm, &rank);
bool is_master = root == rank;
if (is_master) {
n_ids = ids.size();
}
// broadcast size of id vector
MPI_Bcast(n_ids, 1, MPI_UINT32_T, root, comm);
// worker
if (!is_master) {
ids.resize(n_ids);
}
// broadcast control cell ids
if (n_ids > 0) {
MPI_Bcast(ids.data(), n_ids, MPI_UINT32_T, root, comm);
}
}
std::vector<std::string> getSpeciesNames(const Field &&field, int root,
MPI_Comm comm) {
std::uint32_t n_elements;