feat: bcast control_cell_ids to workers

This commit is contained in:
rastogi 2025-11-02 22:26:14 +01:00
parent 7c97f29fa6
commit 1b2d942960
5 changed files with 554 additions and 554 deletions

View File

@ -2,19 +2,16 @@
#ifndef CHEMISTRYMODULE_H_ #ifndef CHEMISTRYMODULE_H_
#define CHEMISTRYMODULE_H_ #define CHEMISTRYMODULE_H_
#include "ChemistryDefs.hpp"
#include "Control/ControlModule.hpp"
#include "DataStructures/Field.hpp" #include "DataStructures/Field.hpp"
#include "DataStructures/NamedVector.hpp" #include "DataStructures/NamedVector.hpp"
#include "ChemistryDefs.hpp"
#include "Init/InitialList.hpp" #include "Init/InitialList.hpp"
#include "NameDouble.h" #include "NameDouble.h"
#include "PhreeqcRunner.hpp"
#include "SurrogateModels/DHT_Wrapper.hpp" #include "SurrogateModels/DHT_Wrapper.hpp"
#include "SurrogateModels/Interpolation.hpp" #include "SurrogateModels/Interpolation.hpp"
#include "poet.hpp"
#include "PhreeqcRunner.hpp"
#include <array> #include <array>
#include <cstdint> #include <cstdint>
#include <map> #include <map>
@ -24,6 +21,7 @@
#include <vector> #include <vector>
namespace poet { namespace poet {
class ControlModule;
/** /**
* \brief Wrapper around PhreeqcRM to provide POET specific parallelization with * \brief Wrapper around PhreeqcRM to provide POET specific parallelization with
* easy access. * easy access.
@ -173,7 +171,7 @@ public:
/** /**
* **Master only** Return the time in seconds the master spent in the * **Master only** Return the time in seconds the master spent in the
* send/receive loop. * send/receive loop.
*/ */
auto GetMasterLoopTime() const { return this->send_recv_t; } auto GetMasterLoopTime() const { return this->send_recv_t; }
auto GetMasterRecvCtrlDataTime() const { return this->recv_ctrl_t; } auto GetMasterRecvCtrlDataTime() const { return this->recv_ctrl_t; }
@ -211,6 +209,8 @@ public:
*/ */
std::vector<double> GetWorkerIdleTimings() const; std::vector<double> GetWorkerIdleTimings() const;
std::vector<double> GetWorkerControlTimings() const;
/** /**
* **Master only** Collect and return DHT hits of all workers. * **Master only** Collect and return DHT hits of all workers.
* *
@ -257,25 +257,15 @@ public:
std::vector<int> ai_surrogate_validity_vector; std::vector<int> ai_surrogate_validity_vector;
RuntimeParameters *runtime_params = nullptr; void SetControlModule(poet::ControlModule *ctrl) { control_module = ctrl; }
uint32_t control_iteration_counter = 0;
struct error_stats { void SetDhtEnabled(bool enabled) { dht_enabled = enabled; }
std::vector<double> mape; bool GetDhtEnabled() const { return dht_enabled; }
std::vector<double> rrsme;
uint32_t iteration;
error_stats(size_t species_count, size_t iter) void SetInterpEnabled(bool enabled) { interp_enabled = enabled; }
: mape(species_count, 0.0), rrsme(species_count, 0.0), iteration(iter) { bool GetInterpEnabled() const { return interp_enabled; }
}
};
std::vector<error_stats> error_stats_history; void SetWarmupEnabled(bool enabled) { warmup_enabled = enabled; }
static void computeStats(const std::vector<double> &pqc_vector,
const std::vector<double> &sur_vector,
uint32_t size_per_prop, uint32_t species_count,
error_stats &stats);
protected: protected:
void initializeDHT(uint32_t size_mb, void initializeDHT(uint32_t size_mb,
@ -290,12 +280,13 @@ protected:
enum { enum {
CHEM_FIELD_INIT, CHEM_FIELD_INIT,
CHEM_DHT_ENABLE, //CHEM_DHT_ENABLE,
CHEM_DHT_SIGNIF_VEC, CHEM_DHT_SIGNIF_VEC,
CHEM_DHT_SNAPS, CHEM_DHT_SNAPS,
CHEM_DHT_READ_FILE, CHEM_DHT_READ_FILE,
CHEM_INTERP, //CHEM_WARMUP_PHASE, // Control flag
CHEM_IP_ENABLE, //CHEM_CTRL_ENABLE, // Control flag
//CHEM_IP_ENABLE,
CHEM_IP_MIN_ENTRIES, CHEM_IP_MIN_ENTRIES,
CHEM_IP_SIGNIF_VEC, CHEM_IP_SIGNIF_VEC,
CHEM_WORK_LOOP, CHEM_WORK_LOOP,
@ -308,6 +299,7 @@ protected:
enum { enum {
WORKER_PHREEQC, WORKER_PHREEQC,
WORKER_CTRL_ITER,
WORKER_DHT_GET, WORKER_DHT_GET,
WORKER_DHT_FILL, WORKER_DHT_FILL,
WORKER_IDLE, WORKER_IDLE,
@ -330,6 +322,7 @@ protected:
double dht_get = 0.; double dht_get = 0.;
double dht_fill = 0.; double dht_fill = 0.;
double idle_t = 0.; double idle_t = 0.;
double ctrl_t = 0.;
}; };
struct worker_info_s { struct worker_info_s {
@ -347,7 +340,7 @@ protected:
void MasterSendPkgs(worker_list_t &w_list, workpointer_t &work_pointer, void MasterSendPkgs(worker_list_t &w_list, workpointer_t &work_pointer,
workpointer_t &sur_pointer, int &pkg_to_send, workpointer_t &sur_pointer, int &pkg_to_send,
int &count_pkgs, int &free_workers, double dt, int &count_pkgs, int &free_workers, double dt,
uint32_t iteration, uint32_t control_iteration, uint32_t iteration,
const std::vector<uint32_t> &wp_sizes_vector); const std::vector<uint32_t> &wp_sizes_vector);
void MasterRecvPkgs(worker_list_t &w_list, int &pkg_to_recv, bool to_send, void MasterRecvPkgs(worker_list_t &w_list, int &pkg_to_recv, bool to_send,
int &free_workers); int &free_workers);
@ -385,6 +378,10 @@ protected:
void BCastStringVec(std::vector<std::string> &io); void BCastStringVec(std::vector<std::string> &io);
int packResultsIntoBuffer(std::vector<double> &mpi_buffer, int base_count,
const WorkPackage &wp,
const WorkPackage &wp_control);
int comm_size, comm_rank; int comm_size, comm_rank;
MPI_Comm group_comm; MPI_Comm group_comm;
@ -412,6 +409,7 @@ protected:
inline void PropagateFunctionType(int &type) const { inline void PropagateFunctionType(int &type) const {
ChemBCast(&type, 1, MPI_INT); ChemBCast(&type, 1, MPI_INT);
} }
double simtime = 0.; double simtime = 0.;
double idle_t = 0.; double idle_t = 0.;
double seq_t = 0.; double seq_t = 0.;
@ -419,10 +417,9 @@ protected:
double recv_ctrl_t = 0.; double recv_ctrl_t = 0.;
double shuf_t = 0.; double shuf_t = 0.;
double metrics_t = 0. double metrics_t = 0.;
std::array<double, 2> std::array<double, 2> base_totals{0};
base_totals{0};
bool print_progessbar{false}; bool print_progessbar{false};
@ -442,8 +439,12 @@ protected:
poet::ControlModule *control_module = nullptr; poet::ControlModule *control_module = nullptr;
std::vector<double> mpi_surr_buffer;
bool control_enabled{false}; bool control_enabled{false};
bool warmup_enabled{false}; bool warmup_enabled{false};
// std::vector<double> sur_shuffled;
}; };
} // namespace poet } // namespace poet

View File

@ -3,7 +3,6 @@
#include <algorithm> #include <algorithm>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <iomanip>
#include <mpi.h> #include <mpi.h>
#include <vector> #include <vector>
@ -41,6 +40,12 @@ std::vector<double> poet::ChemistryModule::GetWorkerPhreeqcTimings() const {
return MasterGatherWorkerTimings(WORKER_PHREEQC); return MasterGatherWorkerTimings(WORKER_PHREEQC);
} }
std::vector<double> poet::ChemistryModule::GetWorkerControlTimings() const {
int type = CHEM_PERF;
MPI_Bcast(&type, 1, MPI_INT, 0, this->group_comm);
return MasterGatherWorkerTimings(WORKER_CTRL_ITER);
}
std::vector<double> poet::ChemistryModule::GetWorkerDHTGetTimings() const { std::vector<double> poet::ChemistryModule::GetWorkerDHTGetTimings() const {
int type = CHEM_PERF; int type = CHEM_PERF;
MPI_Bcast(&type, 1, MPI_INT, 0, this->group_comm); MPI_Bcast(&type, 1, MPI_INT, 0, this->group_comm);
@ -252,6 +257,8 @@ inline void poet::ChemistryModule::MasterSendPkgs(
/* note current processed work package in workerlist */ /* note current processed work package in workerlist */
w_list[p].send_addr = work_pointer.base(); w_list[p].send_addr = work_pointer.base();
w_list[p].surrogate_addr = sur_pointer.base(); w_list[p].surrogate_addr = sur_pointer.base();
// this->control_enabled ? sur_pointer.base() : w_list[p].surrogate_addr =
// nullptr;
/* push work pointer to next work package */ /* push work pointer to next work package */
const uint32_t end_of_wp = local_work_package_size * this->prop_count; const uint32_t end_of_wp = local_work_package_size * this->prop_count;
@ -349,6 +356,11 @@ inline void poet::ChemistryModule::MasterRecvPkgs(worker_list_t &w_list,
std::copy(recv_buffer.begin(), recv_buffer.begin() + half, std::copy(recv_buffer.begin(), recv_buffer.begin() + half,
w_list[p - 1].send_addr); w_list[p - 1].send_addr);
/*
if (w_list[p - 1].surrogate_addr == nullptr) {
throw std::runtime_error("MasterRecvPkgs: surrogate_addr is null");
}*/
std::copy(recv_buffer.begin() + (size / 2), recv_buffer.begin() + size, std::copy(recv_buffer.begin() + (size / 2), recv_buffer.begin() + size,
w_list[p - 1].surrogate_addr); w_list[p - 1].surrogate_addr);
recv_ctrl_b = MPI_Wtime(); recv_ctrl_b = MPI_Wtime();
@ -418,6 +430,7 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
int free_workers; int free_workers;
int i_pkgs; int i_pkgs;
int ftype; int ftype;
double shuf_a, shuf_b, metrics_a, metrics_b;
const std::vector<uint32_t> wp_sizes_vector = const std::vector<uint32_t> wp_sizes_vector =
CalculateWPSizesVector(this->n_cells, this->wp_size); CalculateWPSizesVector(this->n_cells, this->wp_size);
@ -435,47 +448,34 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
ftype = CHEM_WORK_LOOP; ftype = CHEM_WORK_LOOP;
PropagateFunctionType(ftype); PropagateFunctionType(ftype);
ftype = CHEM_INTERP;
PropagateFunctionType(ftype);
if (this->runtime_params->rollback_simulation) {
this->interp_enabled = false;
int interp_flag = 0;
ChemBCast(&interp_flag, 1, MPI_INT);
} else {
this->interp_enabled = true;
int interp_flag = 1;
ChemBCast(&interp_flag, 1, MPI_INT);
}
MPI_Barrier(this->group_comm); MPI_Barrier(this->group_comm);
static uint32_t iteration = 0; this->control_enabled = this->control_module->getControlIntervalEnabled();
uint32_t control_iteration = static_cast<uint32_t>( if (this->control_enabled) {
this->runtime_params->control_iteration_active ? 1 : 0); this->mpi_surr_buffer.assign(this->n_cells * this->prop_count, 0.0);
if (control_iteration) {
sur_shuffled.clear();
sur_shuffled.reserve(this->n_cells * this->prop_count);
} }
static uint32_t iteration = 0;
/* start time measurement of sequential part */ /* start time measurement of sequential part */
seq_a = MPI_Wtime(); seq_a = MPI_Wtime();
/* shuffle grid */ /* shuffle grid */
// grid.shuffleAndExport(mpi_buffer); // grid.shuffleAndExport(mpi_buffer);
std::vector<double> mpi_buffer = std::vector<double> mpi_buffer =
shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count, shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count,
wp_sizes_vector.size()); wp_sizes_vector.size());
this->sur_shuffled.resize(mpi_buffer.size()); //this->mpi_surr_buffer.resize(mpi_buffer.size());
/* setup local variables */ /* setup local variables */
pkg_to_send = wp_sizes_vector.size(); pkg_to_send = wp_sizes_vector.size();
pkg_to_recv = wp_sizes_vector.size(); pkg_to_recv = wp_sizes_vector.size();
workpointer_t work_pointer = mpi_buffer.begin(); workpointer_t work_pointer = mpi_buffer.begin();
workpointer_t sur_pointer = sur_shuffled.begin(); workpointer_t sur_pointer = this->mpi_surr_buffer.begin();
//(this->control_enabled ? this->mpi_surr_buffer.begin()
// : mpi_buffer.end());
worker_list_t worker_list(this->comm_size - 1); worker_list_t worker_list(this->comm_size - 1);
free_workers = this->comm_size - 1; free_workers = this->comm_size - 1;
@ -499,8 +499,7 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
if (pkg_to_send > 0) { if (pkg_to_send > 0) {
// send packages to all free workers ... // send packages to all free workers ...
MasterSendPkgs(worker_list, work_pointer, sur_pointer, pkg_to_send, MasterSendPkgs(worker_list, work_pointer, sur_pointer, pkg_to_send,
i_pkgs, free_workers, dt, iteration, control_iteration, i_pkgs, free_workers, dt, iteration, wp_sizes_vector);
wp_sizes_vector);
} }
// ... and try to receive them from workers who has finished their work // ... and try to receive them from workers who has finished their work
MasterRecvPkgs(worker_list, pkg_to_recv, pkg_to_send > 0, free_workers); MasterRecvPkgs(worker_list, pkg_to_recv, pkg_to_send > 0, free_workers);
@ -524,15 +523,13 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
chem_field = out_vec; chem_field = out_vec;
/* do master stuff */ /* do master stuff */
if (this->control_enabled) {
/* do master stuff */
if (control_enabled) {
std::cout << "[Master] Control logic enabled for this iteration." std::cout << "[Master] Control logic enabled for this iteration."
<< std::endl; << std::endl;
std::vector<double> sur_unshuffled{mpi_surr_buffer}; std::vector<double> sur_unshuffled{mpi_surr_buffer};
shuf_a = MPI_Wtime(); shuf_a = MPI_Wtime();
unshuffleField(mpi_surr_buffer, this->n_cells, this->prop_count, unshuffleField(this->mpi_surr_buffer, this->n_cells, this->prop_count,
wp_sizes_vector.size(), sur_unshuffled); wp_sizes_vector.size(), sur_unshuffled);
shuf_b = MPI_Wtime(); shuf_b = MPI_Wtime();
this->shuf_t += shuf_b - shuf_a; this->shuf_t += shuf_b - shuf_a;
@ -550,7 +547,6 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
this->metrics_t += metrics_b - metrics_a; this->metrics_t += metrics_b - metrics_a;
} }
/* start time measurement of master chemistry */ /* start time measurement of master chemistry */
sim_e_chemistry = MPI_Wtime(); sim_e_chemistry = MPI_Wtime();

File diff suppressed because it is too large Load Diff

View File

@ -52,7 +52,7 @@ public:
std::uint32_t control_interval; std::uint32_t control_interval;
std::vector<std::string> species_names; std::vector<std::string> species_names;
std::vector<double> mape_threshold; std::vector<double> mape_threshold;
std::vector<double> ctrl_cell_ids; std::vector<uint32_t> ctrl_cell_ids;
}; };
void enableControlLogic(const ControlSetup &setup) { void enableControlLogic(const ControlSetup &setup) {

View File

@ -250,12 +250,13 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
params.timesteps = params.timesteps =
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("timesteps")); Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("timesteps"));
params.checkpoint_interval = Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval")); params.checkpoint_interval =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval"));
params.control_interval = params.control_interval =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("control_interval")); Rcpp::as<uint32_t>(global_rt_setup->operator[]("control_interval"));
params.mape_threshold = Rcpp::as<std::vector<double>>( params.mape_threshold = Rcpp::as<std::vector<double>>(
global_rt_setup->operator[]("mape_threshold")); global_rt_setup->operator[]("mape_threshold"));
params.ctrl_cell_ids = Rcpp::as<std::vector<uint32_t>>( params.ctrl_cell_ids = Rcpp::as<std::vector<uint32_t>>(
global_rt_setup->operator[]("ctrl_cell_ids")); global_rt_setup->operator[]("ctrl_cell_ids"));
catch (const std::exception &e) { catch (const std::exception &e) {
@ -465,6 +466,30 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
return profiling; return profiling;
} }
static void getControlCellIds(const vector<uint32_t> &ids, int root,
MPI_Comm comm) {
std::uint32_t n_ids = 0;
int rank;
MPI_Comm_rank(comm, &rank);
bool is_master = root == rank;
if (is_master) {
n_ids = ids.size();
}
// broadcast size of id vector
MPI_Bcast(n_ids, 1, MPI_UINT32_T, root, comm);
// worker
if (!is_master) {
ids.resize(n_ids);
}
// broadcast control cell ids
if (n_ids > 0) {
MPI_Bcast(ids.data(), n_ids, MPI_UINT32_T, root, comm);
}
}
std::vector<std::string> getSpeciesNames(const Field &&field, int root, std::vector<std::string> getSpeciesNames(const Field &&field, int root,
MPI_Comm comm) { MPI_Comm comm) {
std::uint32_t n_elements; std::uint32_t n_elements;