Add exclusion cond. for Cl

This commit is contained in:
rastogi 2025-11-30 17:58:13 +01:00
parent 1de30ad0db
commit c637f5c787
16 changed files with 533 additions and 438 deletions

Binary file not shown.

View File

@ -1,7 +1,7 @@
#!/bin/bash #!/bin/bash
#SBATCH --job-name=proto2_eps0035 #SBATCH --job-name=p2_eps01_no_skip
#SBATCH --output=proto2_eps0035_no_rb_v2_%j.out #SBATCH --output=p2_eps01_no_skip_%j.out
#SBATCH --error=proto2_eps0035_no_rb_v2%j.err #SBATCH --error=p2_eps01_no_skip_%j.err
#SBATCH --partition=long #SBATCH --partition=long
#SBATCH --nodes=6 #SBATCH --nodes=6
#SBATCH --ntasks-per-node=24 #SBATCH --ntasks-per-node=24
@ -15,5 +15,5 @@ module purge
module load cmake gcc openmpi module load cmake gcc openmpi
#mpirun -n 144 ./poet dolo_fgcs_3.R dolo_fgcs_3.qs2 dolo_only_pqc #mpirun -n 144 ./poet dolo_fgcs_3.R dolo_fgcs_3.qs2 dolo_only_pqc
mpirun -n 144 ./poet --interp --rds dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 proto2_eps0035 mpirun -n 144 ./poet --interp --rds dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 p2_eps01_no_skip
#mpirun -n 144 ./poet --interp barite_fgcs_4_new/barite_fgcs_4_new_rt.R barite_fgcs_4_new/barite_fgcs_4_new.qs2 barite #mpirun -n 144 ./poet --interp barite_fgcs_4_new/barite_fgcs_4_new_rt.R barite_fgcs_4_new/barite_fgcs_4_new.qs2 barite

165
docs/class_diagram.md Normal file
View File

@ -0,0 +1,165 @@
# POET Class Diagram
```mermaid
classDiagram
class RuntimeParameters {
+bool print_progress
+uint32_t work_package_size
+bool use_dht
+uint32_t dht_size
+uint32_t dht_snaps
+bool use_interp
+uint32_t interp_size
+uint32_t interp_min_entries
+uint32_t interp_bucket_entries
+bool use_ai_surrogate
+bool as_rds
+bool as_qs
+string out_ext
+string out_dir
+vector~double~ timesteps
+uint32_t checkpoint_interval
+uint32_t stab_interval
+double zero_abs
+vector~double~ mape_threshold
+vector~uint32_t~ ctrl_cell_ids
+Rcpp::List init_params
}
class Field {
+GetProps() vector~string~
+AsVector() vector~double~
+GetRequestedVecSize() size_t
+update(Field) void
+asSEXP() SEXP
+operator[](string) vector~double~
}
class InitialList {
-RInside& R
+InitialList(RInside&)
+importList(Rcpp::List, bool) void
+getChemistryInit() ChemistryInit
+getDiffusionInit() DiffusionInit
+getInitialGrid() Field
}
class ChemistryModule {
+ChemistryModule(uint32_t, ChemistryInit, MPI_Comm)
+simulate(double) void
+getField() Field&
+WorkerLoop() void
+masterSetField(Field) void
+masterEnableSurrogates(SurrogateSetup) void
+SetControlCellIds(vector~uint32_t~) void
+SetControlModule(ControlModule*) void
+setProgressBarPrintout(bool) void
+set_ai_surrogate_validity_vector(SEXP) void
+MasterLoopBreak() void
+GetChemistryTime() double
+GetMasterLoopTime() double
+GetWorkerIdleTimings() vector~double~
+GetWorkerPhreeqcTimings() vector~double~
+GetWorkerDHTHits() vector~uint32_t~
+GetWorkerDHTEvictions() vector~uint32_t~
-Field field
-uint32_t work_package_size
-MPI_Comm comm
}
class DiffusionModule {
+DiffusionModule(DiffusionInit, Field)
+simulate(double) void
+getField() Field&
+getTransportTime() double
-Field field
}
class RInsidePOET {
+getInstance()$ RInsidePOET&
+parseEval(string) SEXP
+parseEvalQ(string) void
+operator[](string) Proxy
}
class ChemistryInit {
+dht_species SpeciesList
+ai_surrogate_input_script string
}
class DiffusionInit {
}
class SurrogateSetup {
+vector~string~ species_names
+array~double,2~ base_totals
+bool has_id
+bool use_dht
+uint32_t dht_size
+uint32_t dht_snaps
+string out_dir
+bool use_interp
+uint32_t interp_bucket_entries
+uint32_t interp_size
+uint32_t interp_min_entries
+bool use_ai_surrogate
}
class Main {
+main(int, char**) int
-parseInitValues(int, char**, RuntimeParameters&) int
-init_global_functions(RInside&) void
-call_master_iter_end(RInside&, Field&, Field&) void
-RunMasterLoop(RInsidePOET&, RuntimeParameters&, DiffusionModule&, ChemistryModule&, ControlModule&) Rcpp::List
-getControlCellIds(vector~uint32_t~&, int, MPI_Comm) void
-getSpeciesNames(Field&&, int, MPI_Comm) vector~string~
-getBaseTotals(Field&&, int, MPI_Comm) array~double,2~
-getHasID(Field&&, int, MPI_Comm) bool
}
Main --> RuntimeParameters : uses
Main --> InitialList : creates
Main --> ChemistryModule : creates
Main --> DiffusionModule : creates
Main --> RInsidePOET : uses
Main --> Field : exchanges
InitialList --> RInsidePOET : uses
InitialList --> Field : creates
InitialList --> ChemistryInit : provides
InitialList --> DiffusionInit : provides
ChemistryModule --> Field : manages
ChemistryModule --> ChemistryInit : initialized with
ChemistryModule --> SurrogateSetup : configured with
DiffusionModule --> Field : manages
DiffusionModule --> DiffusionInit : initialized with
ChemistryModule ..> DiffusionModule : exchanges Field data
DiffusionModule ..> ChemistryModule : exchanges Field data
RuntimeParameters --> ChemistryInit : contains
```
## Key Relationships
- **Main** orchestrates the entire simulation, coordinating between modules
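- **InitialList** parses the R configuration and provides the initial grid (`Field`) together with the `ChemistryInit` and `DiffusionInit` settings from which **Main** creates the modules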
- **ChemistryModule** and **DiffusionModule** exchange data via **Field** objects
- **Field** is the core data structure representing the simulation grid
- **RInsidePOET** provides the R runtime interface (singleton pattern)
- **RuntimeParameters** holds all command-line and configuration parameters
- **SurrogateSetup** configures advanced features (DHT, interpolation, AI surrogate)
## Module Communication Flow
1. Main reads configuration via `parseInitValues()`
2. `InitialList` imports R scripts and creates initial `Field`
3. `ChemistryModule` and `DiffusionModule` are initialized with their respective configurations
4. In the simulation loop:
   - `DiffusionModule.simulate()` updates the transport field
   - `ChemistryModule` receives the updated field via `Field.update()`
   - `ChemistryModule.simulate()` computes the chemistry
   - `DiffusionModule` receives the updated field back
5. MPI communication is handled internally by the modules

Binary file not shown.

View File

@ -6,7 +6,7 @@
namespace poet { namespace poet {
enum DHT_PROP_TYPES { DHT_TYPE_DEFAULT, DHT_TYPE_CHARGE, DHT_TYPE_TOTAL }; enum DHT_PROP_TYPES { DHT_TYPE_DEFAULT, DHT_TYPE_CHARGE, DHT_TYPE_TOTAL };
enum CHEMISTRY_OUT_SOURCE { CHEM_PQC, CHEM_DHT, CHEM_INTERP, CHEM_AISURR }; enum CHEMISTRY_OUT_SOURCE { CHEM_PQC, CHEM_DHT, CHEM_INTERP, CHEM_AISURR, CHEM_SKIP };
struct WorkPackage { struct WorkPackage {
std::size_t size; std::size_t size;

View File

@ -406,9 +406,10 @@ protected:
flags |= STAB_ENABLE; flags |= STAB_ENABLE;
return flags; return flags;
} }
inline bool hasFlag(int flags, int type) { return (flags & type) != 0; } inline bool hasFlag(int flags, int type) { return (flags & type) != 0; }
int comm_size, comm_rank; int comm_size, comm_rank;
MPI_Comm group_comm; MPI_Comm group_comm;
@ -437,6 +438,8 @@ protected:
ChemBCast(&type, 1, MPI_INT); ChemBCast(&type, 1, MPI_INT);
} }
double simtime = 0.; double simtime = 0.;
double idle_t = 0.; double idle_t = 0.;
double seq_t = 0.; double seq_t = 0.;

View File

@ -6,29 +6,25 @@
#include <mpi.h> #include <mpi.h>
#include <vector> #include <vector>
std::vector<uint32_t> std::vector<uint32_t> poet::ChemistryModule::MasterGatherWorkerMetrics(int type) const {
poet::ChemistryModule::MasterGatherWorkerMetrics(int type) const {
MPI_Bcast(&type, 1, MPI_INT, 0, this->group_comm); MPI_Bcast(&type, 1, MPI_INT, 0, this->group_comm);
uint32_t dummy; uint32_t dummy;
std::vector<uint32_t> metrics(this->comm_size); std::vector<uint32_t> metrics(this->comm_size);
MPI_Gather(&dummy, 1, MPI_UINT32_T, metrics.data(), 1, MPI_UINT32_T, 0, MPI_Gather(&dummy, 1, MPI_UINT32_T, metrics.data(), 1, MPI_UINT32_T, 0, this->group_comm);
this->group_comm);
metrics.erase(metrics.begin()); metrics.erase(metrics.begin());
return metrics; return metrics;
} }
std::vector<double> std::vector<double> poet::ChemistryModule::MasterGatherWorkerTimings(int type) const {
poet::ChemistryModule::MasterGatherWorkerTimings(int type) const {
MPI_Bcast(&type, 1, MPI_INT, 0, this->group_comm); MPI_Bcast(&type, 1, MPI_INT, 0, this->group_comm);
double dummy; double dummy;
std::vector<double> timings(this->comm_size); std::vector<double> timings(this->comm_size);
MPI_Gather(&dummy, 1, MPI_DOUBLE, timings.data(), 1, MPI_DOUBLE, 0, MPI_Gather(&dummy, 1, MPI_DOUBLE, timings.data(), 1, MPI_DOUBLE, 0, this->group_comm);
this->group_comm);
timings.erase(timings.begin()); timings.erase(timings.begin());
return timings; return timings;
@ -76,8 +72,8 @@ std::vector<uint32_t> poet::ChemistryModule::GetWorkerDHTHits() const {
MPI_Get_count(&probe, MPI_UINT32_T, &count); MPI_Get_count(&probe, MPI_UINT32_T, &count);
std::vector<uint32_t> ret(count); std::vector<uint32_t> ret(count);
MPI_Recv(ret.data(), count, MPI_UINT32_T, probe.MPI_SOURCE, WORKER_DHT_HITS, MPI_Recv(ret.data(), count, MPI_UINT32_T, probe.MPI_SOURCE, WORKER_DHT_HITS, this->group_comm,
this->group_comm, NULL); NULL);
return ret; return ret;
} }
@ -94,42 +90,37 @@ std::vector<uint32_t> poet::ChemistryModule::GetWorkerDHTEvictions() const {
MPI_Get_count(&probe, MPI_UINT32_T, &count); MPI_Get_count(&probe, MPI_UINT32_T, &count);
std::vector<uint32_t> ret(count); std::vector<uint32_t> ret(count);
MPI_Recv(ret.data(), count, MPI_UINT32_T, probe.MPI_SOURCE, MPI_Recv(ret.data(), count, MPI_UINT32_T, probe.MPI_SOURCE, WORKER_DHT_EVICTIONS,
WORKER_DHT_EVICTIONS, this->group_comm, NULL); this->group_comm, NULL);
return ret; return ret;
} }
std::vector<double> std::vector<double> poet::ChemistryModule::GetWorkerInterpolationWriteTimings() const {
poet::ChemistryModule::GetWorkerInterpolationWriteTimings() const {
int type = CHEM_PERF; int type = CHEM_PERF;
MPI_Bcast(&type, 1, MPI_INT, 0, this->group_comm); MPI_Bcast(&type, 1, MPI_INT, 0, this->group_comm);
return MasterGatherWorkerTimings(WORKER_IP_WRITE); return MasterGatherWorkerTimings(WORKER_IP_WRITE);
} }
std::vector<double> std::vector<double> poet::ChemistryModule::GetWorkerInterpolationReadTimings() const {
poet::ChemistryModule::GetWorkerInterpolationReadTimings() const {
int type = CHEM_PERF; int type = CHEM_PERF;
MPI_Bcast(&type, 1, MPI_INT, 0, this->group_comm); MPI_Bcast(&type, 1, MPI_INT, 0, this->group_comm);
return MasterGatherWorkerTimings(WORKER_IP_READ); return MasterGatherWorkerTimings(WORKER_IP_READ);
} }
std::vector<double> std::vector<double> poet::ChemistryModule::GetWorkerInterpolationGatherTimings() const {
poet::ChemistryModule::GetWorkerInterpolationGatherTimings() const {
int type = CHEM_PERF; int type = CHEM_PERF;
MPI_Bcast(&type, 1, MPI_INT, 0, this->group_comm); MPI_Bcast(&type, 1, MPI_INT, 0, this->group_comm);
return MasterGatherWorkerTimings(WORKER_IP_GATHER); return MasterGatherWorkerTimings(WORKER_IP_GATHER);
} }
std::vector<double> std::vector<double> poet::ChemistryModule::GetWorkerInterpolationFunctionCallTimings() const {
poet::ChemistryModule::GetWorkerInterpolationFunctionCallTimings() const {
int type = CHEM_PERF; int type = CHEM_PERF;
MPI_Bcast(&type, 1, MPI_INT, 0, this->group_comm); MPI_Bcast(&type, 1, MPI_INT, 0, this->group_comm);
return MasterGatherWorkerTimings(WORKER_IP_FC); return MasterGatherWorkerTimings(WORKER_IP_FC);
} }
std::vector<uint32_t> std::vector<uint32_t> poet::ChemistryModule::GetWorkerInterpolationCalls() const {
poet::ChemistryModule::GetWorkerInterpolationCalls() const {
int type = CHEM_PERF; int type = CHEM_PERF;
MPI_Bcast(&type, 1, MPI_INT, 0, this->group_comm); MPI_Bcast(&type, 1, MPI_INT, 0, this->group_comm);
type = WORKER_IP_CALLS; type = WORKER_IP_CALLS;
@ -141,8 +132,8 @@ poet::ChemistryModule::GetWorkerInterpolationCalls() const {
MPI_Get_count(&probe, MPI_UINT32_T, &count); MPI_Get_count(&probe, MPI_UINT32_T, &count);
std::vector<uint32_t> ret(count); std::vector<uint32_t> ret(count);
MPI_Recv(ret.data(), count, MPI_UINT32_T, probe.MPI_SOURCE, WORKER_IP_CALLS, MPI_Recv(ret.data(), count, MPI_UINT32_T, probe.MPI_SOURCE, WORKER_IP_CALLS, this->group_comm,
this->group_comm, NULL); NULL);
return ret; return ret;
} }
@ -159,14 +150,12 @@ std::vector<uint32_t> poet::ChemistryModule::GetWorkerPHTCacheHits() const {
MPI_Get_count(&probe, MPI_UINT32_T, &count); MPI_Get_count(&probe, MPI_UINT32_T, &count);
std::vector<uint32_t> ret(count); std::vector<uint32_t> ret(count);
MPI_Recv(ret.data(), count, MPI_UINT32_T, probe.MPI_SOURCE, type, MPI_Recv(ret.data(), count, MPI_UINT32_T, probe.MPI_SOURCE, type, this->group_comm, NULL);
this->group_comm, NULL);
return ret; return ret;
} }
inline std::vector<int> shuffleVector(const std::vector<int> &in_vector, inline std::vector<int> shuffleVector(const std::vector<int> &in_vector, uint32_t size_per_prop,
uint32_t size_per_prop,
uint32_t wp_count) { uint32_t wp_count) {
std::vector<int> out_buffer(in_vector.size()); std::vector<int> out_buffer(in_vector.size());
uint32_t write_i = 0; uint32_t write_i = 0;
@ -179,17 +168,14 @@ inline std::vector<int> shuffleVector(const std::vector<int> &in_vector,
return out_buffer; return out_buffer;
} }
inline std::vector<double> shuffleField(const std::vector<double> &in_field, inline std::vector<double> shuffleField(const std::vector<double> &in_field, uint32_t size_per_prop,
uint32_t size_per_prop, uint32_t species_count, uint32_t wp_count) {
uint32_t species_count,
uint32_t wp_count) {
std::vector<double> out_buffer(in_field.size()); std::vector<double> out_buffer(in_field.size());
uint32_t write_i = 0; uint32_t write_i = 0;
for (uint32_t i = 0; i < wp_count; i++) { for (uint32_t i = 0; i < wp_count; i++) {
for (uint32_t j = i; j < size_per_prop; j += wp_count) { for (uint32_t j = i; j < size_per_prop; j += wp_count) {
for (uint32_t k = 0; k < species_count; k++) { for (uint32_t k = 0; k < species_count; k++) {
out_buffer[(write_i * species_count) + k] = out_buffer[(write_i * species_count) + k] = in_field[(k * size_per_prop) + j];
in_field[(k * size_per_prop) + j];
} }
write_i++; write_i++;
} }
@ -197,16 +183,15 @@ inline std::vector<double> shuffleField(const std::vector<double> &in_field,
return out_buffer; return out_buffer;
} }
inline void unshuffleField(const std::vector<double> &in_buffer, inline void unshuffleField(const std::vector<double> &in_buffer, uint32_t size_per_prop,
uint32_t size_per_prop, uint32_t species_count, uint32_t species_count, uint32_t wp_count,
uint32_t wp_count, std::vector<double> &out_field) { std::vector<double> &out_field) {
uint32_t read_i = 0; uint32_t read_i = 0;
for (uint32_t i = 0; i < wp_count; i++) { for (uint32_t i = 0; i < wp_count; i++) {
for (uint32_t j = i; j < size_per_prop; j += wp_count) { for (uint32_t j = i; j < size_per_prop; j += wp_count) {
for (uint32_t k = 0; k < species_count; k++) { for (uint32_t k = 0; k < species_count; k++) {
out_field[(k * size_per_prop) + j] = out_field[(k * size_per_prop) + j] = in_buffer[(read_i * species_count) + k];
in_buffer[(read_i * species_count) + k];
} }
read_i++; read_i++;
} }
@ -232,11 +217,19 @@ inline void printProgressbar(int count_pkgs, int n_wp, int barWidth = 70) {
/* end visual progress */ /* end visual progress */
} }
inline void poet::ChemistryModule::MasterSendPkgs( /*
worker_list_t &w_list, workpointer_t &work_pointer,
workpointer_t &sur_pointer, int &pkg_to_send, int &count_pkgs, std::vector<std::vector<double>> extractSurCells(){
int &free_workers, double dt, uint32_t iteration,
const std::vector<uint32_t> &wp_sizes_vector) { }
*/
inline void poet::ChemistryModule::MasterSendPkgs(worker_list_t &w_list,
workpointer_t &work_pointer,
workpointer_t &sur_pointer, int &pkg_to_send,
int &count_pkgs, int &free_workers, double dt,
uint32_t iteration,
const std::vector<uint32_t> &wp_sizes_vector) {
/* declare variables */ /* declare variables */
int local_work_package_size; int local_work_package_size;
@ -250,9 +243,8 @@ inline void poet::ChemistryModule::MasterSendPkgs(
local_work_package_size = (int)wp_sizes_vector[count_pkgs]; local_work_package_size = (int)wp_sizes_vector[count_pkgs];
count_pkgs++; count_pkgs++;
uint32_t wp_start_index = uint32_t wp_start_index = std::accumulate(wp_sizes_vector.begin(),
std::accumulate(wp_sizes_vector.begin(), std::next(wp_sizes_vector.begin(), count_pkgs), 0);
std::next(wp_sizes_vector.begin(), count_pkgs), 0);
/* note current processed work package in workerlist */ /* note current processed work package in workerlist */
w_list[p].send_addr = work_pointer.base(); w_list[p].send_addr = work_pointer.base();
@ -290,8 +282,8 @@ inline void poet::ChemistryModule::MasterSendPkgs(
/* ATTENTION Worker p has rank p+1 */ /* ATTENTION Worker p has rank p+1 */
// MPI_Send(send_buffer, end_of_wp + BUFFER_OFFSET, MPI_DOUBLE, p + 1, // MPI_Send(send_buffer, end_of_wp + BUFFER_OFFSET, MPI_DOUBLE, p + 1,
// LOOP_WORK, this->group_comm); // LOOP_WORK, this->group_comm);
MPI_Send(send_buffer.data(), send_buffer.size(), MPI_DOUBLE, p + 1, MPI_Send(send_buffer.data(), send_buffer.size(), MPI_DOUBLE, p + 1, LOOP_WORK,
LOOP_WORK, this->group_comm); this->group_comm);
/* Mark that worker has work to do */ /* Mark that worker has work to do */
w_list[p].has_work = 1; w_list[p].has_work = 1;
@ -301,10 +293,8 @@ inline void poet::ChemistryModule::MasterSendPkgs(
} }
} }
inline void poet::ChemistryModule::MasterRecvPkgs(worker_list_t &w_list, inline void poet::ChemistryModule::MasterRecvPkgs(worker_list_t &w_list, int &pkg_to_recv,
int &pkg_to_recv, bool to_send, int &free_workers) {
bool to_send,
int &free_workers) {
/* declare most of the variables here */ /* declare most of the variables here */
int need_to_receive = 1; int need_to_receive = 1;
double idle_a, idle_b; double idle_a, idle_b;
@ -321,8 +311,7 @@ inline void poet::ChemistryModule::MasterRecvPkgs(worker_list_t &w_list,
// only of there are still packages to send and free workers are available // only of there are still packages to send and free workers are available
if (to_send && free_workers > 0) if (to_send && free_workers > 0)
// non blocking probing // non blocking probing
MPI_Iprobe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &need_to_receive, MPI_Iprobe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &need_to_receive, &probe_status);
&probe_status);
else { else {
idle_a = MPI_Wtime(); idle_a = MPI_Wtime();
// blocking probing // blocking probing
@ -341,8 +330,8 @@ inline void poet::ChemistryModule::MasterRecvPkgs(worker_list_t &w_list,
switch (probe_status.MPI_TAG) { switch (probe_status.MPI_TAG) {
case LOOP_WORK: { case LOOP_WORK: {
MPI_Get_count(&probe_status, MPI_DOUBLE, &size); MPI_Get_count(&probe_status, MPI_DOUBLE, &size);
MPI_Recv(w_list[p - 1].send_addr, size, MPI_DOUBLE, p, LOOP_WORK, MPI_Recv(w_list[p - 1].send_addr, size, MPI_DOUBLE, p, LOOP_WORK, this->group_comm,
this->group_comm, MPI_STATUS_IGNORE); MPI_STATUS_IGNORE);
// Only LOOP_WORK completes a work package // Only LOOP_WORK completes a work package
w_list[p - 1].has_work = 0; w_list[p - 1].has_work = 0;
pkg_to_recv -= 1; pkg_to_recv -= 1;
@ -354,21 +343,18 @@ inline void poet::ChemistryModule::MasterRecvPkgs(worker_list_t &w_list,
MPI_Get_count(&probe_status, MPI_DOUBLE, &size); MPI_Get_count(&probe_status, MPI_DOUBLE, &size);
recv_buffer.resize(size); recv_buffer.resize(size);
MPI_Recv(recv_buffer.data(), size, MPI_DOUBLE, p, LOOP_CTRL, MPI_Recv(recv_buffer.data(), size, MPI_DOUBLE, p, LOOP_CTRL, this->group_comm,
this->group_comm, MPI_STATUS_IGNORE); MPI_STATUS_IGNORE);
recv_ctrl_b = MPI_Wtime(); recv_ctrl_b = MPI_Wtime();
recv_ctrl_t += recv_ctrl_b - recv_ctrl_a; recv_ctrl_t += recv_ctrl_b - recv_ctrl_a;
// Collect PHREEQC rows for control cells // Collect PHREEQC rows for control cells
const std::size_t cells_per_batch = const std::size_t cells_per_batch = size / this->prop_count;
static_cast<std::size_t>(size) /
static_cast<std::size_t>(this->prop_count);
for (std::size_t i = 0; i < cells_per_batch; i++) { for (std::size_t i = 0; i < cells_per_batch; i++) {
std::vector<double> cell_output( std::vector<double> cell_output(recv_buffer.begin() + this->prop_count * i,
recv_buffer.begin() + this->prop_count * i, recv_buffer.begin() + this->prop_count * (i + 1));
recv_buffer.begin() + this->prop_count * (i + 1)); this->ctrl_batch.push_back(cell_output);
this->ctrl_batch.push_back(std::move(cell_output));
} }
break; break;
} }
@ -437,17 +423,18 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
ftype = CHEM_AI_BCAST_VALIDITY; ftype = CHEM_AI_BCAST_VALIDITY;
PropagateFunctionType(ftype); PropagateFunctionType(ftype);
this->ai_surrogate_validity_vector = this->ai_surrogate_validity_vector =
shuffleVector(this->ai_surrogate_validity_vector, this->n_cells, shuffleVector(this->ai_surrogate_validity_vector, this->n_cells, wp_sizes_vector.size());
wp_sizes_vector.size()); ChemBCast(&this->ai_surrogate_validity_vector.front(), this->n_cells, MPI_INT);
ChemBCast(&this->ai_surrogate_validity_vector.front(), this->n_cells,
MPI_INT);
} }
ftype = CHEM_CTRL_FLAGS; ftype = CHEM_CTRL_FLAGS;
PropagateFunctionType(ftype); PropagateFunctionType(ftype);
uint32_t ctrl_flags = buildCtrlFlags(this->dht_enabled, this->interp_enabled, uint32_t ctrl_flags = buildCtrlFlags(this->dht_enabled, this->interp_enabled, this->stab_enabled);
this->stab_enabled); ChemBCast(&ctrl_flags, 1, MPI_UINT32_T);
ChemBCast(&ctrl_flags, 1, MPI_INT); // std::cout << "[Master] Flags mask=" << ctrl_flags
// << " dht=" << this->dht_enabled
// << " ip=" << this->interp_enabled
// << " stab=" << this->stab_enabled << std::endl;
ftype = CHEM_WORK_LOOP; ftype = CHEM_WORK_LOOP;
PropagateFunctionType(ftype); PropagateFunctionType(ftype);
@ -462,8 +449,7 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
/* shuffle grid */ /* shuffle grid */
// grid.shuffleAndExport(mpi_buffer); // grid.shuffleAndExport(mpi_buffer);
std::vector<double> mpi_buffer = std::vector<double> mpi_buffer =
shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count, shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count, wp_sizes_vector.size());
wp_sizes_vector.size());
/* setup local variables */ /* setup local variables */
pkg_to_send = wp_sizes_vector.size(); pkg_to_send = wp_sizes_vector.size();
@ -493,8 +479,8 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
// while there are still packages to send // while there are still packages to send
if (pkg_to_send > 0) { if (pkg_to_send > 0) {
// send packages to all free workers ... // send packages to all free workers ...
MasterSendPkgs(worker_list, work_pointer, sur_pointer, pkg_to_send, MasterSendPkgs(worker_list, work_pointer, sur_pointer, pkg_to_send, i_pkgs, free_workers, dt,
i_pkgs, free_workers, dt, iteration, wp_sizes_vector); iteration, wp_sizes_vector);
} }
// ... and try to receive them from workers who has finished their work // ... and try to receive them from workers who has finished their work
MasterRecvPkgs(worker_list, pkg_to_recv, pkg_to_send > 0, free_workers); MasterRecvPkgs(worker_list, pkg_to_recv, pkg_to_send > 0, free_workers);
@ -513,8 +499,7 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
/* unshuffle grid */ /* unshuffle grid */
// grid.importAndUnshuffle(mpi_buffer); // grid.importAndUnshuffle(mpi_buffer);
std::vector<double> out_vec{mpi_buffer}; std::vector<double> out_vec{mpi_buffer};
unshuffleField(mpi_buffer, this->n_cells, this->prop_count, unshuffleField(mpi_buffer, this->n_cells, this->prop_count, wp_sizes_vector.size(), out_vec);
wp_sizes_vector.size(), out_vec);
chem_field = out_vec; chem_field = out_vec;
/* do master stuff */ /* do master stuff */
@ -523,7 +508,6 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
std::cout << "[Master] Processing " << this->ctrl_batch.size() std::cout << "[Master] Processing " << this->ctrl_batch.size()
<< " control cells for comparison." << std::endl; << " control cells for comparison." << std::endl;
std::vector<std::vector<double>> sur_batch; std::vector<std::vector<double>> sur_batch;
sur_batch.reserve(this->ctrl_batch.size()); sur_batch.reserve(this->ctrl_batch.size());
@ -534,9 +518,8 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
uint32_t curr_cell_id = mpi_buffer[this->prop_count * i]; uint32_t curr_cell_id = mpi_buffer[this->prop_count * i];
if (curr_cell_id == element[0]) { if (curr_cell_id == element[0]) {
std::vector<double> sur_output( std::vector<double> sur_output(mpi_buffer.begin() + this->prop_count * i,
mpi_buffer.begin() + this->prop_count * i, mpi_buffer.begin() + this->prop_count * (i + 1));
mpi_buffer.begin() + this->prop_count * (i + 1));
sur_batch.push_back(sur_output); sur_batch.push_back(sur_output);
break; break;
} }
@ -544,8 +527,8 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
} }
metrics_a = MPI_Wtime(); metrics_a = MPI_Wtime();
control->computeErrorMetrics(this->ctrl_batch, sur_batch, control->computeMetrics(this->ctrl_batch, sur_batch, prop_names, ctrl_cell_ids.size(),
prop_names, n_cells); ctrl_file_out_dir);
metrics_b = MPI_Wtime(); metrics_b = MPI_Wtime();
this->metrics_t += metrics_b - metrics_a; this->metrics_t += metrics_b - metrics_a;
@ -560,7 +543,7 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
this->seq_t += seq_d - seq_c; this->seq_t += seq_d - seq_c;
/* end time measurement of whole chemistry simulation */ /* end time measurement of whole chemistry simulation */
std::optional<uint32_t> target = control->getRollbackTarget(prop_names); std::optional<uint32_t> target = control->findRbTarget(prop_names);
int flush = target.has_value() ? 1 : 0; int flush = target.has_value() ? 1 : 0;
/* advise workers to end chemistry iteration */ /* advise workers to end chemistry iteration */
@ -583,12 +566,10 @@ void poet::ChemistryModule::MasterLoopBreak() {
MPI_Bcast(&type, 1, MPI_INT, 0, this->group_comm); MPI_Bcast(&type, 1, MPI_INT, 0, this->group_comm);
} }
std::vector<uint32_t> std::vector<uint32_t> poet::ChemistryModule::CalculateWPSizesVector(uint32_t n_cells,
poet::ChemistryModule::CalculateWPSizesVector(uint32_t n_cells, uint32_t wp_size) const {
uint32_t wp_size) const {
bool mod_pkgs = (n_cells % wp_size) != 0; bool mod_pkgs = (n_cells % wp_size) != 0;
uint32_t n_packages = uint32_t n_packages = (uint32_t)(n_cells / wp_size) + static_cast<int>(mod_pkgs);
(uint32_t)(n_cells / wp_size) + static_cast<int>(mod_pkgs);
std::vector<uint32_t> wp_sizes_vector(n_packages, 0); std::vector<uint32_t> wp_sizes_vector(n_packages, 0);

View File

@ -48,20 +48,19 @@ void poet::ChemistryModule::WorkerLoop() {
case CHEM_FIELD_INIT: { case CHEM_FIELD_INIT: {
ChemBCast(&this->prop_count, 1, MPI_UINT32_T); ChemBCast(&this->prop_count, 1, MPI_UINT32_T);
if (this->ai_surrogate_enabled) { if (this->ai_surrogate_enabled) {
this->ai_surrogate_validity_vector.resize( this->ai_surrogate_validity_vector.resize(this->n_cells); // resize statt reserve?
this->n_cells); // resize statt reserve?
} }
break; break;
} }
case CHEM_AI_BCAST_VALIDITY: { case CHEM_AI_BCAST_VALIDITY: {
// Receive the index vector of valid ai surrogate predictions // Receive the index vector of valid ai surrogate predictions
MPI_Bcast(&this->ai_surrogate_validity_vector.front(), this->n_cells, MPI_Bcast(&this->ai_surrogate_validity_vector.front(), this->n_cells, MPI_INT, 0,
MPI_INT, 0, this->group_comm); this->group_comm);
break; break;
} }
case CHEM_CTRL_FLAGS: { case CHEM_CTRL_FLAGS: {
int flags = 0; uint32_t flags = 0;
ChemBCast(&flags, 1, MPI_INT); ChemBCast(&flags, 1, MPI_UINT32_T);
this->dht_enabled = hasFlag(flags, DHT_ENABLE); this->dht_enabled = hasFlag(flags, DHT_ENABLE);
this->interp_enabled = hasFlag(flags, IP_ENABLE); this->interp_enabled = hasFlag(flags, IP_ENABLE);
this->stab_enabled = hasFlag(flags, STAB_ENABLE); this->stab_enabled = hasFlag(flags, STAB_ENABLE);
@ -94,8 +93,7 @@ void poet::ChemistryModule::WorkerLoop() {
} }
} }
void poet::ChemistryModule::WorkerProcessPkgs(struct worker_s &timings, void poet::ChemistryModule::WorkerProcessPkgs(struct worker_s &timings, uint32_t &iteration) {
uint32_t &iteration) {
MPI_Status probe_status; MPI_Status probe_status;
bool loop = true; bool loop = true;
@ -125,8 +123,7 @@ void poet::ChemistryModule::WorkerProcessPkgs(struct worker_s &timings,
} }
} }
void poet::ChemistryModule::copyPkgs(const WorkPackage &wp, void poet::ChemistryModule::copyPkgs(const WorkPackage &wp, std::vector<double> &mpi_buffer) {
std::vector<double> &mpi_buffer) {
for (std::size_t wp_i = 0; wp_i < wp.size; wp_i++) { for (std::size_t wp_i = 0; wp_i < wp.size; wp_i++) {
std::copy(wp.output[wp_i].begin(), wp.output[wp_i].end(), std::copy(wp.output[wp_i].begin(), wp.output[wp_i].end(),
@ -134,9 +131,9 @@ void poet::ChemistryModule::copyPkgs(const WorkPackage &wp,
} }
} }
void poet::ChemistryModule::processCtrlPkgs( void poet::ChemistryModule::processCtrlPkgs(std::vector<std::vector<double>> &input,
std::vector<std::vector<double>> &input, double current_sim_time, double dt, double current_sim_time, double dt,
struct worker_s &timings) { struct worker_s &timings) {
double phreeqc_start, phreeqc_end; double phreeqc_start, phreeqc_end;
@ -156,13 +153,12 @@ void poet::ChemistryModule::processCtrlPkgs(
copyPkgs(control_wp, mpi_buffer); copyPkgs(control_wp, mpi_buffer);
MPI_Request send_req; MPI_Request send_req;
MPI_Isend(mpi_buffer.data(), mpi_buffer.size(), MPI_DOUBLE, 0, LOOP_CTRL, MPI_Isend(mpi_buffer.data(), mpi_buffer.size(), MPI_DOUBLE, 0, LOOP_CTRL, MPI_COMM_WORLD,
MPI_COMM_WORLD, &send_req); &send_req);
MPI_Wait(&send_req, MPI_STATUS_IGNORE); MPI_Wait(&send_req, MPI_STATUS_IGNORE);
} }
void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, int double_count,
int double_count,
struct worker_s &timings) { struct worker_s &timings) {
static int counter = 1; static int counter = 1;
@ -179,12 +175,14 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
int flags; int flags;
std::vector<double> mpi_buffer(count); std::vector<double> mpi_buffer(count);
static int control_cells_processed = 0; static int ctrl_cells_processed = 0;
static std::vector<std::vector<double>> control_batch; static std::vector<std::vector<double>> ctrl_batch;
const int CL_INDEX = 7;
const double CL_THRESHOLD = 1e-10;
/* receive */ /* receive */
MPI_Recv(mpi_buffer.data(), count, MPI_DOUBLE, 0, LOOP_WORK, this->group_comm, MPI_Recv(mpi_buffer.data(), count, MPI_DOUBLE, 0, LOOP_WORK, this->group_comm, MPI_STATUS_IGNORE);
MPI_STATUS_IGNORE);
/* decrement count of work_package by BUFFER_OFFSET */ /* decrement count of work_package by BUFFER_OFFSET */
count -= BUFFER_OFFSET; count -= BUFFER_OFFSET;
@ -206,27 +204,21 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
// current work package start location in field // current work package start location in field
wp_start_index = mpi_buffer[count + 4]; wp_start_index = mpi_buffer[count + 4];
// read packed control flags
/*
flags = static_cast<int>(mpi_buffer[count + 5]);
this->interp_enabled = (flags & 1) != 0;
this->dht_enabled = (flags & 2) != 0;
this->warmup_enabled = (flags & 4) != 0;
this->control_enabled = (flags & 8) != 0;
*/
/*std::cout << "warmup_enabled is " << warmup_enabled << ", control_enabled is
"
<< control_enabled << ", dht_enabled is "
<< dht_enabled << ", interp_enabled is " << interp_enabled
<< std::endl;*/
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) { for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
s_curr_wp.input[wp_i] = s_curr_wp.input[wp_i] = std::vector<double>(mpi_buffer.begin() + this->prop_count * wp_i,
std::vector<double>(mpi_buffer.begin() + this->prop_count * wp_i, mpi_buffer.begin() + this->prop_count * (wp_i + 1));
mpi_buffer.begin() + this->prop_count * (wp_i + 1));
} }
/* skip simulation of cells cells where Cl concentration is below threshold */
/*
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
if (s_curr_wp.input[wp_i][CL_INDEX] < CL_THRESHOLD) {
s_curr_wp.mapping[wp_i] = CHEM_SKIP;
s_curr_wp.output[wp_i] = s_curr_wp.input[wp_i];
}
}
*/
// std::cout << this->comm_rank << ":" << counter++ << std::endl; // std::cout << this->comm_rank << ":" << counter++ << std::endl;
if (dht_enabled || interp_enabled || stab_enabled) { if (dht_enabled || interp_enabled || stab_enabled) {
dht->prepareKeys(s_curr_wp.input, dt); dht->prepareKeys(s_curr_wp.input, dt);
@ -253,28 +245,26 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
} }
} }
/* process cells to be monitored in a separate workpackage */ // if (!this->stab_enabled) { /* process cells to be monitored in a separate workpackage */
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) { for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
uint32_t cell_id = s_curr_wp.input[wp_i][0]; uint32_t cell_id = s_curr_wp.input[wp_i][0];
bool is_ctrl_cell = bool is_ctrl_cell = this->ctrl_cell_ids.find(cell_id) != this->ctrl_cell_ids.end();
this->ctrl_cell_ids.find(cell_id) != this->ctrl_cell_ids.end(); bool used_sur = (s_curr_wp.mapping[wp_i] != CHEM_PQC) && (s_curr_wp.mapping[wp_i] != CHEM_SKIP);
bool used_sur = s_curr_wp.mapping[wp_i] != CHEM_PQC;
if (is_ctrl_cell && used_sur) { if (is_ctrl_cell && used_sur) {
ctrl_batch.push_back(s_curr_wp.input[wp_i]);
ctrl_cells_processed++;
control_batch.push_back(s_curr_wp.input[wp_i]); if (ctrl_batch.size() == s_curr_wp.size ||
control_cells_processed++; ctrl_cells_processed == this->ctrl_cell_ids.size()) {
processCtrlPkgs(ctrl_batch, current_sim_time, dt, timings);
if (control_batch.size() == s_curr_wp.size || ctrl_batch.clear();
control_cells_processed == this->ctrl_cell_ids.size()) { ctrl_cells_processed = 0;
processCtrlPkgs(control_batch, current_sim_time, dt, timings);
control_batch.clear();
control_cells_processed = 0;
} }
} }
} }
// }
phreeqc_time_start = MPI_Wtime(); phreeqc_time_start = MPI_Wtime();
@ -286,8 +276,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
/* send results to master */ /* send results to master */
MPI_Request send_req; MPI_Request send_req;
MPI_Isend(mpi_buffer.data(), count, MPI_DOUBLE, 0, LOOP_WORK, MPI_COMM_WORLD, MPI_Isend(mpi_buffer.data(), count, MPI_DOUBLE, 0, LOOP_WORK, MPI_COMM_WORLD, &send_req);
&send_req);
if (dht_enabled || interp_enabled || stab_enabled) { if (dht_enabled || interp_enabled || stab_enabled) {
/* write results to DHT */ /* write results to DHT */
@ -305,17 +294,16 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
MPI_Wait(&send_req, MPI_STATUS_IGNORE); MPI_Wait(&send_req, MPI_STATUS_IGNORE);
} }
void poet::ChemistryModule::WorkerPostIter(MPI_Status &probe_status, void poet::ChemistryModule::WorkerPostIter(MPI_Status &probe_status, uint32_t iteration) {
uint32_t iteration) {
int size, flush_request = 0; int size, flush_request = 0;
MPI_Get_count(&probe_status, MPI_INT, &size); MPI_Get_count(&probe_status, MPI_INT, &size);
if (size == 1) { if (size == 1) {
MPI_Recv(&flush_request, size, MPI_INT, probe_status.MPI_SOURCE, LOOP_END, MPI_Recv(&flush_request, size, MPI_INT, probe_status.MPI_SOURCE, LOOP_END, this->group_comm,
this->group_comm, MPI_STATUS_IGNORE); MPI_STATUS_IGNORE);
} else { } else {
MPI_Recv(NULL, 0, MPI_INT, probe_status.MPI_SOURCE, LOOP_END, MPI_Recv(NULL, 0, MPI_INT, probe_status.MPI_SOURCE, LOOP_END, this->group_comm,
this->group_comm, MPI_STATUS_IGNORE); MPI_STATUS_IGNORE);
} }
if (this->dht_enabled) { if (this->dht_enabled) {
@ -334,13 +322,12 @@ void poet::ChemistryModule::WorkerPostIter(MPI_Status &probe_status,
interp->resetCounter(); interp->resetCounter();
interp->writePHTStats(); interp->writePHTStats();
if (this->dht_snaps_type == DHT_SNAPS_ITEREND) { if (this->dht_snaps_type == DHT_SNAPS_ITEREND) {
out << this->dht_file_out_dir << "/iter_" << std::setfill('0') out << this->dht_file_out_dir << "/iter_" << std::setfill('0') << std::setw(this->file_pad)
<< std::setw(this->file_pad) << iteration << ".pht"; << iteration << ".pht";
interp->dumpPHTState(out.str()); interp->dumpPHTState(out.str());
} }
const auto max_mean_idx = const auto max_mean_idx = DHT_get_used_idx_factor(this->interp->getDHTObject(), 1);
DHT_get_used_idx_factor(this->interp->getDHTObject(), 1);
if (max_mean_idx >= 2 || flush_request) { if (max_mean_idx >= 2 || flush_request) {
DHT_flush(this->interp->getDHTObject()); DHT_flush(this->interp->getDHTObject());
@ -360,33 +347,29 @@ void poet::ChemistryModule::WorkerPostSim(uint32_t iteration) {
} }
if (this->interp_enabled && this->dht_snaps_type >= DHT_SNAPS_ITEREND) { if (this->interp_enabled && this->dht_snaps_type >= DHT_SNAPS_ITEREND) {
std::stringstream out; std::stringstream out;
out << this->dht_file_out_dir << "/iter_" << std::setfill('0') out << this->dht_file_out_dir << "/iter_" << std::setfill('0') << std::setw(this->file_pad)
<< std::setw(this->file_pad) << iteration << ".pht"; << iteration << ".pht";
interp->dumpPHTState(out.str()); interp->dumpPHTState(out.str());
} }
} }
void poet::ChemistryModule::WorkerWriteDHTDump(uint32_t iteration) { void poet::ChemistryModule::WorkerWriteDHTDump(uint32_t iteration) {
std::stringstream out; std::stringstream out;
out << this->dht_file_out_dir << "/iter_" << std::setfill('0') out << this->dht_file_out_dir << "/iter_" << std::setfill('0') << std::setw(this->file_pad)
<< std::setw(this->file_pad) << iteration << ".dht"; << iteration << ".dht";
int res = dht->tableToFile(out.str().c_str()); int res = dht->tableToFile(out.str().c_str());
if (res != DHT_SUCCESS && this->comm_rank == 2) if (res != DHT_SUCCESS && this->comm_rank == 2)
std::cerr std::cerr << "CPP: Worker: Error in writing current state of DHT to file.\n";
<< "CPP: Worker: Error in writing current state of DHT to file.\n";
else if (this->comm_rank == 2) else if (this->comm_rank == 2)
std::cout << "CPP: Worker: Successfully written DHT to file " << out.str() std::cout << "CPP: Worker: Successfully written DHT to file " << out.str() << "\n";
<< "\n";
} }
void poet::ChemistryModule::WorkerReadDHTDump( void poet::ChemistryModule::WorkerReadDHTDump(const std::string &dht_input_file) {
const std::string &dht_input_file) {
int res = dht->fileToTable((char *)dht_input_file.c_str()); int res = dht->fileToTable((char *)dht_input_file.c_str());
if (res != DHT_SUCCESS) { if (res != DHT_SUCCESS) {
if (res == DHT_WRONG_FILE) { if (res == DHT_WRONG_FILE) {
if (this->comm_rank == 1) if (this->comm_rank == 1)
std::cerr std::cerr << "CPP: Worker: Wrong file layout! Continue with empty DHT ...\n";
<< "CPP: Worker: Wrong file layout! Continue with empty DHT ...\n";
} else { } else {
if (this->comm_rank == 1) if (this->comm_rank == 1)
std::cerr << "CPP: Worker: Error in loading current state of DHT from " std::cerr << "CPP: Worker: Error in loading current state of DHT from "
@ -394,13 +377,12 @@ void poet::ChemistryModule::WorkerReadDHTDump(
} }
} else { } else {
if (this->comm_rank == 2) if (this->comm_rank == 2)
std::cout << "CPP: Worker: Successfully loaded state of DHT from file " std::cout << "CPP: Worker: Successfully loaded state of DHT from file " << dht_input_file
<< dht_input_file << "\n"; << "\n";
} }
} }
void poet::ChemistryModule::WorkerRunWorkPackage(WorkPackage &work_package, void poet::ChemistryModule::WorkerRunWorkPackage(WorkPackage &work_package, double dSimTime,
double dSimTime,
double dTimestep) { double dTimestep) {
std::vector<std::vector<double>> inout_chem = work_package.input; std::vector<std::vector<double>> inout_chem = work_package.input;
@ -412,12 +394,13 @@ void poet::ChemistryModule::WorkerRunWorkPackage(WorkPackage &work_package,
} }
// HACK: remove the first element (cell_id) before sending to phreeqc // HACK: remove the first element (cell_id) before sending to phreeqc
inout_chem[wp_id].erase(inout_chem[wp_id].begin(), inout_chem[wp_id].erase(inout_chem[wp_id].begin(), inout_chem[wp_id].begin() + 1);
inout_chem[wp_id].begin() + 1);
} }
this->pqc_runner->run(inout_chem, dTimestep, to_ignore); this->pqc_runner->run(inout_chem, dTimestep, to_ignore);
//std::cout << "Ignored " << to_ignore.size() << " cells out of " << wp_size << "." << std::endl;
for (std::size_t wp_id = 0; wp_id < work_package.size; wp_id++) { for (std::size_t wp_id = 0; wp_id < work_package.size; wp_id++) {
if (work_package.mapping[wp_id] == CHEM_PQC) { if (work_package.mapping[wp_id] == CHEM_PQC) {
// HACK: as we removed the first element (cell_id) before sending to // HACK: as we removed the first element (cell_id) before sending to
@ -429,32 +412,26 @@ void poet::ChemistryModule::WorkerRunWorkPackage(WorkPackage &work_package,
} }
} }
void poet::ChemistryModule::WorkerPerfToMaster(int type, void poet::ChemistryModule::WorkerPerfToMaster(int type, const struct worker_s &timings) {
const struct worker_s &timings) {
switch (type) { switch (type) {
case WORKER_PHREEQC: { case WORKER_PHREEQC: {
MPI_Gather(&timings.phreeqc_t, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0, MPI_Gather(&timings.phreeqc_t, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0, this->group_comm);
this->group_comm);
break; break;
} }
case WORKER_CTRL_ITER: { case WORKER_CTRL_ITER: {
MPI_Gather(&timings.ctrl_t, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0, MPI_Gather(&timings.ctrl_t, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0, this->group_comm);
this->group_comm);
break; break;
} }
case WORKER_DHT_GET: { case WORKER_DHT_GET: {
MPI_Gather(&timings.dht_get, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0, MPI_Gather(&timings.dht_get, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0, this->group_comm);
this->group_comm);
break; break;
} }
case WORKER_DHT_FILL: { case WORKER_DHT_FILL: {
MPI_Gather(&timings.dht_fill, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0, MPI_Gather(&timings.dht_fill, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0, this->group_comm);
this->group_comm);
break; break;
} }
case WORKER_IDLE: { case WORKER_IDLE: {
MPI_Gather(&timings.idle_t, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0, MPI_Gather(&timings.idle_t, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0, this->group_comm);
this->group_comm);
break; break;
} }
case WORKER_IP_WRITE: { case WORKER_IP_WRITE: {
@ -490,15 +467,14 @@ void poet::ChemistryModule::WorkerMetricsToMaster(int type) {
MPI_Comm &group_comm = this->group_comm; MPI_Comm &group_comm = this->group_comm;
auto reduce_and_send = [&worker_rank, &worker_comm, &group_comm]( auto reduce_and_send = [&worker_rank, &worker_comm,
std::vector<std::uint32_t> &send_buffer, int tag) { &group_comm](std::vector<std::uint32_t> &send_buffer, int tag) {
std::vector<uint32_t> to_master(send_buffer.size()); std::vector<uint32_t> to_master(send_buffer.size());
MPI_Reduce(send_buffer.data(), to_master.data(), send_buffer.size(), MPI_Reduce(send_buffer.data(), to_master.data(), send_buffer.size(), MPI_UINT32_T, MPI_SUM, 0,
MPI_UINT32_T, MPI_SUM, 0, worker_comm); worker_comm);
if (worker_rank == 0) { if (worker_rank == 0) {
MPI_Send(to_master.data(), to_master.size(), MPI_UINT32_T, 0, tag, MPI_Send(to_master.data(), to_master.size(), MPI_UINT32_T, 0, tag, group_comm);
group_comm);
} }
}; };

View File

@ -10,33 +10,33 @@ poet::ControlModule::ControlModule(const ControlConfig &config_) : config(config
void poet::ControlModule::beginIteration(ChemistryModule &chem, uint32_t &iter, void poet::ControlModule::beginIteration(ChemistryModule &chem, uint32_t &iter,
const bool &dht_enabled, const bool &interp_enabled) { const bool &dht_enabled, const bool &interp_enabled) {
global_iteration = iter; global_iter = iter;
double prep_a, prep_b; double prep_a, prep_b;
prep_a = MPI_Wtime(); prep_a = MPI_Wtime();
updateStabilizationPhase(chem, dht_enabled, interp_enabled); updateSurrState(chem, dht_enabled, interp_enabled);
prep_b = MPI_Wtime(); prep_b = MPI_Wtime();
this->prep_t += prep_b - prep_a; this->prep_t += prep_b - prep_a;
} }
/* Disables dht and/or interp during stabilization phase */ /* Disables dht and/or interp during stabilization phase */
void poet::ControlModule::updateStabilizationPhase(ChemistryModule &chem, bool dht_enabled, void poet::ControlModule::updateSurrState(ChemistryModule &chem, bool dht_enabled,
bool interp_enabled) { bool interp_enabled) {
bool in_warmup = (global_iteration <= config.stab_interval); bool in_warmup = (global_iter <= config.stab_interval);
bool rb_limit_reached = (rollback_count >= 3); bool rb_limit_reached = (rb_count >= config.rb_limit);
if (rollback_enabled && disable_surr_counter > 0) { if (rb_enabled && stab_countdown > 0) {
--disable_surr_counter; --stab_countdown;
std::cout << "Rollback counter: " << std::to_string(disable_surr_counter) << std::endl; std::cout << "Rollback counter: " << std::to_string(stab_countdown) << std::endl;
if (disable_surr_counter == 0) { if (stab_countdown == 0) {
rollback_enabled = false; rb_enabled = false;
} }
flush_request = false; flush_request = false;
} }
/* disable surrogates during warmup, active rollback or after limit */ /* disable surrogates during warmup, active rollback or after limit */
if (in_warmup || rollback_enabled || rb_limit_reached) { if (in_warmup || rb_enabled || rb_limit_reached) {
chem.SetStabEnabled(!rb_limit_reached); chem.SetStabEnabled(!rb_limit_reached);
chem.SetDhtEnabled(false); chem.SetDhtEnabled(false);
chem.SetInterpEnabled(false); chem.SetInterpEnabled(false);
@ -58,13 +58,12 @@ void poet::ControlModule::updateStabilizationPhase(ChemistryModule &chem, bool d
void poet::ControlModule::writeCheckpoint(DiffusionModule &diffusion, uint32_t &iter, void poet::ControlModule::writeCheckpoint(DiffusionModule &diffusion, uint32_t &iter,
const std::string &out_dir) { const std::string &out_dir) {
if (global_iteration % config.checkpoint_interval == 0) { if (global_iter % config.chkpt_interval == 0) {
double w_check_a = MPI_Wtime(); double w_check_a = MPI_Wtime();
write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5", write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5",
{.field = diffusion.getField(), .iteration = iter}); {.field = diffusion.getField(), .iteration = iter});
double w_check_b = MPI_Wtime(); double w_check_b = MPI_Wtime();
this->w_check_t += w_check_b - w_check_a; this->w_check_t += w_check_b - w_check_a;
last_checkpoint_written = iter;
} }
} }
@ -80,37 +79,39 @@ void poet::ControlModule::readCheckpoint(DiffusionModule &diffusion, uint32_t &c
r_check_t += r_check_b - r_check_a; r_check_t += r_check_b - r_check_a;
} }
void poet::ControlModule::writeErrorMetrics(uint32_t &iter, const std::string &out_dir, void poet::ControlModule::writeMetrics(uint32_t &iter, const std::string &out_dir,
const std::vector<std::string> &species) { const std::vector<std::string> &species) {
if (rollback_count >= 3) { if (rb_count >= config.rb_limit || global_iter <= config.stab_interval) {
return;
}
if (rb_enabled) {
return; return;
} }
double stats_a = MPI_Wtime(); double stats_a = MPI_Wtime();
writeSpeciesStatsToCSV(metrics_history, species, out_dir, "species_overview.csv"); writeSpeciesStatsToCSV(s_history, species, out_dir, "species_overview.csv");
write_metrics(cell_metrics_history, species, out_dir, "metrics_overview.hdf5"); write_metrics(c_history, species, out_dir, "metrics_overview.hdf5");
double stats_b = MPI_Wtime(); double stats_b = MPI_Wtime();
this->stats_t += stats_b - stats_a; this->stats_t += stats_b - stats_a;
} }
uint32_t poet::ControlModule::getRollbackIter() { uint32_t poet::ControlModule::calcRbIter() {
uint32_t last_iter = ((global_iter - 1) / config.chkpt_interval) * config.chkpt_interval;
uint32_t last_iter =
((global_iteration - 1) / config.checkpoint_interval) * config.checkpoint_interval;
return last_iter; return last_iter;
} }
std::optional<uint32_t> std::optional<uint32_t> poet::ControlModule::findRbTarget(const std::vector<std::string> &species) {
poet::ControlModule::getRollbackTarget(const std::vector<std::string> &species) {
double r_check_a, r_check_b; double r_check_a, r_check_b;
/* Skip threshold checking if already in stabilization phase*/ /* Skip threshold checking if already in stabilization phase*/
if (metrics_history.empty() || rollback_enabled) { if (s_history.empty() || rb_enabled) {
return std::nullopt; return std::nullopt;
} }
const auto &s_hist = metrics_history.back(); const auto &s_hist = s_history.back();
/* skipping cell_id and id */ /* skipping cell_id and id */
for (size_t sp_i = 2; sp_i < species.size(); sp_i++) { for (size_t sp_i = 2; sp_i < species.size(); sp_i++) {
@ -123,7 +124,7 @@ poet::ControlModule::getRollbackTarget(const std::vector<std::string> &species)
} }
if (s_hist.mape[sp_i] > config.mape_threshold[sp_i]) { if (s_hist.mape[sp_i] > config.mape_threshold[sp_i]) {
const auto &c_hist = cell_metrics_history.back(); const auto &c_hist = c_history.back();
auto max_it = auto max_it =
std::max_element(c_hist.mape.begin(), c_hist.mape.end(), std::max_element(c_hist.mape.begin(), c_hist.mape.end(),
[sp_i](const auto &a, const auto &b) { return a[sp_i] < b[sp_i]; }); [sp_i](const auto &a, const auto &b) { return a[sp_i] < b[sp_i]; });
@ -132,7 +133,6 @@ poet::ControlModule::getRollbackTarget(const std::vector<std::string> &species)
uint32_t cell_id = c_hist.id[max_idx]; uint32_t cell_id = c_hist.id[max_idx];
double cell_mape = (*max_it)[sp_i]; double cell_mape = (*max_it)[sp_i];
rollback_enabled = true;
flush_request = true; flush_request = true;
std::cout << "Threshold exceeded for " << species[sp_i] std::cout << "Threshold exceeded for " << species[sp_i]
@ -141,18 +141,17 @@ poet::ControlModule::getRollbackTarget(const std::vector<std::string> &species)
<< ". Worst cell: ID=" << std::to_string(cell_id) << ". Worst cell: ID=" << std::to_string(cell_id)
<< " with MAPE=" << std::to_string(cell_mape) << std::endl; << " with MAPE=" << std::to_string(cell_mape) << std::endl;
return getRollbackIter(); return calcRbIter();
} }
} }
return std::nullopt; return std::nullopt;
} }
void poet::ControlModule::computeErrorMetrics(std::vector<std::vector<double>> &reference_values, void poet::ControlModule::computeMetrics(std::vector<std::vector<double>> &reference_values,
std::vector<std::vector<double>> &surrogate_values, std::vector<std::vector<double>> &surrogate_values,
const std::vector<std::string> &species, const std::vector<std::string> &species,
const uint32_t size_per_prop) { const uint32_t size_per_prop, const std::string &out_dir) {
if (rb_count >= config.rb_limit) {
if (rollback_count >= 3) {
return; return;
} }
@ -160,8 +159,8 @@ void poet::ControlModule::computeErrorMetrics(std::vector<std::vector<double>> &
const uint32_t n_species = species.size(); const uint32_t n_species = species.size();
const double ZERO_ABS = config.zero_abs; const double ZERO_ABS = config.zero_abs;
CellErrorMetrics c_metrics(n_cells, n_species, global_iteration, rollback_count); CellMetrics c_metrics(n_cells, n_species, global_iter, rb_count);
SpeciesErrorMetrics s_metrics(n_species, global_iteration, rollback_count); SpeciesMetrics s_metrics(n_species, global_iter, rb_count);
std::vector<double> species_err_sum(n_species, 0.0); std::vector<double> species_err_sum(n_species, 0.0);
std::vector<double> species_sqr_sum(n_species, 0.0); std::vector<double> species_sqr_sum(n_species, 0.0);
@ -177,6 +176,15 @@ void poet::ControlModule::computeErrorMetrics(std::vector<std::vector<double>> &
const double sur_value = surrogate_values[cell_i][sp_i]; const double sur_value = surrogate_values[cell_i][sp_i];
if (std::isnan(ref_value) || std::isnan(sur_value)) { if (std::isnan(ref_value) || std::isnan(sur_value)) {
// Initialize to 0 for NaN cases to avoid uninitialized values
c_metrics.mape[cell_i][sp_i] = 0.0;
c_metrics.rrmse[cell_i][sp_i] = 0.0;
std::cout << "WARNING: NaN detected - Cell=" << reference_values[cell_i][0]
<< ", Species=" << species[sp_i]
<< ", Ref=" << (std::isnan(ref_value) ? "NaN" : std::to_string(ref_value))
<< ", Sur=" << (std::isnan(sur_value) ? "NaN" : std::to_string(sur_value))
<< std::endl;
continue; continue;
} }
if (std::abs(ref_value) < ZERO_ABS) { if (std::abs(ref_value) < ZERO_ABS) {
@ -187,6 +195,10 @@ void poet::ControlModule::computeErrorMetrics(std::vector<std::vector<double>> &
c_metrics.mape[cell_i][sp_i] = 100.0; c_metrics.mape[cell_i][sp_i] = 100.0;
c_metrics.rrmse[cell_i][sp_i] = 1.0; c_metrics.rrmse[cell_i][sp_i] = 1.0;
} else {
// Both values are near zero, initialize to 0
c_metrics.mape[cell_i][sp_i] = 0.0;
c_metrics.rrmse[cell_i][sp_i] = 0.0;
} }
} else { } else {
double alpha = 1.0 - (sur_value / ref_value); double alpha = 1.0 - (sur_value / ref_value);
@ -232,32 +244,36 @@ void poet::ControlModule::computeErrorMetrics(std::vector<std::vector<double>> &
c_metrics.mape = std::move(sorted_mape); c_metrics.mape = std::move(sorted_mape);
c_metrics.rrmse = std::move(sorted_rrmse); c_metrics.rrmse = std::move(sorted_rrmse);
metrics_history.push_back(s_metrics); s_history.push_back(s_metrics);
cell_metrics_history.push_back(c_metrics); c_history.push_back(c_metrics);
writeMetrics(global_iter, out_dir, species);
} }
void poet::ControlModule::processCheckpoint(DiffusionModule &diffusion, uint32_t &current_iter, void poet::ControlModule::processCheckpoint(DiffusionModule &diffusion, uint32_t &current_iter,
const std::string &out_dir, const std::string &out_dir,
const std::vector<std::string> &species) { const std::vector<std::string> &species) {
if (rb_count >= config.rb_limit) {
// Use max_rollbacks from config, default to 3 if not set
// uint32_t max_rollbacks =
// (config.max_rollbacks > 0) ? config.max_rollbacks : 3;
if (rollback_count >= 3) {
return; return;
} }
if (flush_request && rollback_count < 3) { if (flush_request && rb_count < config.rb_limit) {
uint32_t target = getRollbackIter(); uint32_t target = calcRbIter();
readCheckpoint(diffusion, current_iter, target, out_dir); readCheckpoint(diffusion, current_iter, target, out_dir);
rollback_enabled = true; rb_enabled = true;
rollback_count++; rb_count++;
disable_surr_counter = config.stab_interval; stab_countdown = config.stab_interval;
std::cout << "Restored checkpoint " << std::to_string(target) << ", surrogate disabled for " std::cout << "Restored checkpoint " << std::to_string(target) << ", surrogate disabled for "
<< std::to_string(config.stab_interval) << std::endl; << std::to_string(config.stab_interval) << std::endl;
} else { } else {
writeCheckpoint(diffusion, global_iteration, out_dir); writeCheckpoint(diffusion, global_iter, out_dir);
} }
} }
/* Returns true while rb_count is still below config.rb_limit and false once
 * the limit has been reached.
 * NOTE(review): judging by the name, callers presumably use this to decide
 * whether control flags must still be broadcast -- confirm at the call site. */
bool poet::ControlModule::needsFlagBcast() const {
  return rb_count < config.rb_limit;
}

View File

@ -19,35 +19,32 @@ class DiffusionModule;
struct ControlConfig { struct ControlConfig {
uint32_t stab_interval = 0; uint32_t stab_interval = 0;
uint32_t checkpoint_interval = 0; // How often to write metrics files uint32_t chkpt_interval = 0;
//uint32_t max_rb = 0; // Maximum number of rollbacks allowed uint32_t rb_limit = 0;
double zero_abs = 0.0; double zero_abs = 0.0;
std::vector<double> mape_threshold; std::vector<double> mape_threshold;
}; };
struct CellErrorMetrics { struct CellMetrics {
std::vector<std::uint32_t> id; std::vector<std::uint32_t> id;
std::vector<std::vector<double>> mape; std::vector<std::vector<double>> mape;
std::vector<std::vector<double>> rrmse; std::vector<std::vector<double>> rrmse;
uint32_t iteration = 0; uint32_t iteration = 0;
uint32_t rollback_count = 0; uint32_t rb_count = 0;
CellErrorMetrics(uint32_t n_cells, uint32_t n_species, uint32_t iter, CellMetrics(uint32_t n_cells, uint32_t n_species, uint32_t iter, uint32_t rb_count)
uint32_t rb_count)
: mape(n_cells, std::vector<double>(n_species, 0.0)), : mape(n_cells, std::vector<double>(n_species, 0.0)),
rrmse(n_cells, std::vector<double>(n_species, 0.0)), iteration(iter), rrmse(n_cells, std::vector<double>(n_species, 0.0)), iteration(iter), rb_count(rb_count) {}
rollback_count(rb_count) {}
}; };
struct SpeciesErrorMetrics { struct SpeciesMetrics {
std::vector<double> mape; std::vector<double> mape;
std::vector<double> rrmse; std::vector<double> rrmse;
uint32_t iteration = 0; uint32_t iteration = 0;
uint32_t rollback_count = 0; uint32_t rb_count = 0;
SpeciesErrorMetrics(uint32_t n_species, uint32_t iter, uint32_t rb_count) SpeciesMetrics(uint32_t n_species, uint32_t iter, uint32_t count)
: mape(n_species, 0.0), rrmse(n_species, 0.0), iteration(iter), : mape(n_species, 0.0), rrmse(n_species, 0.0), iteration(iter), rb_count(count) {}
rollback_count(rb_count) {}
}; };
class ControlModule { class ControlModule {
@ -55,75 +52,63 @@ class ControlModule {
public: public:
explicit ControlModule(const ControlConfig &config); explicit ControlModule(const ControlConfig &config);
void beginIteration(ChemistryModule &chem, uint32_t &iter, void beginIteration(ChemistryModule &chem, uint32_t &iter, const bool &dht_enabled,
const bool &dht_enabled, const bool &interp_enaled); const bool &interp_enaled);
void writeErrorMetrics(uint32_t &iter, const std::string &out_dir, void writeMetrics(uint32_t &iter, const std::string &out_dir,
const std::vector<std::string> &species); const std::vector<std::string> &species);
std::optional<uint32_t> getRollbackTarget(); void computeMetrics(std::vector<std::vector<double>> &reference_values,
std::vector<std::vector<double>> &surrogate_values,
void computeErrorMetrics(std::vector<std::vector<double>> &reference_values, const std::vector<std::string> &species, const uint32_t size_per_prop,
std::vector<std::vector<double>> &surrogate_values, const std::string &out_dir);
const std::vector<std::string> &species,
const uint32_t size_per_prop);
void processCheckpoint(DiffusionModule &diffusion, uint32_t &current_iter, void processCheckpoint(DiffusionModule &diffusion, uint32_t &current_iter,
const std::string &out_dir, const std::string &out_dir, const std::vector<std::string> &species);
const std::vector<std::string> &species);
std::optional<uint32_t> std::optional<uint32_t> findRbTarget(const std::vector<std::string> &species);
getRollbackTarget(const std::vector<std::string> &species);
bool shouldBcastFlags();
bool getFlushRequest() const { return flush_request; } bool getFlushRequest() const { return flush_request; }
void clearFlushRequest() { flush_request = false; } void clearFlushRequest() { flush_request = false; }
auto getGlobalIteration() const noexcept { return global_iteration; } auto getGlobalIteration() const noexcept { return global_iter; }
// void setChemistryModule(poet::ChemistryModule *c) { chem = c; } // void setChemistryModule(poet::ChemistryModule *c) { chem = c; }
std::vector<double> getMapeThreshold() const { std::vector<double> getMapeThreshold() const { return this->config.mape_threshold; }
return this->config.mape_threshold;
}
std::vector<uint32_t> getCtrlCellIds() const { return this->ctrl_cell_ids; } std::vector<uint32_t> getCtrlCellIds() const { return this->ctrl_cell_ids; }
bool needsFlagBcast() const;
/* Profiling getters */ /* Profiling getters */
auto getUpdateCtrlLogicTime() const { return prep_t; } auto getCtrlLogicTime() const { return prep_t; }
auto getWriteCheckpointTime() const { return w_check_t; } auto getChkptWriteTime() const { return w_check_t; }
auto getReadCheckpointTime() const { return r_check_t; } auto getChkptReadTime() const { return r_check_t; }
auto getWriteMetricsTime() const { return stats_t; } auto getMetricsWriteTime() const { return stats_t; }
private: private:
void updateStabilizationPhase(ChemistryModule &chem, bool dht_enabled, void updateSurrState(ChemistryModule &chem, bool dht_enabled, bool interp_enabled);
bool interp_enabled);
void readCheckpoint(DiffusionModule &diffusion, uint32_t &current_iter, void readCheckpoint(DiffusionModule &diffusion, uint32_t &current_iter, uint32_t rollback_iter,
uint32_t rollback_iter, const std::string &out_dir); const std::string &out_dir);
void writeCheckpoint(DiffusionModule &diffusion, uint32_t &iter, void writeCheckpoint(DiffusionModule &diffusion, uint32_t &iter, const std::string &out_dir);
const std::string &out_dir);
uint32_t getRollbackIter(); uint32_t calcRbIter();
ControlConfig config; ControlConfig config;
std::uint32_t global_iteration = 0; std::uint32_t global_iter = 0;
std::uint32_t rollback_count = 0; std::uint32_t rb_count = 0;
std::uint32_t disable_surr_counter = 0; std::uint32_t rb_limit = 0;
std::uint32_t stab_countdown = 0;
std::vector<uint32_t> ctrl_cell_ids; std::vector<uint32_t> ctrl_cell_ids;
std::uint32_t last_checkpoint_written = 0;
std::uint32_t penalty_interval = 0;
bool rollback_enabled = false; bool rb_enabled = false;
bool flush_request = false; bool flush_request = false;
bool stab_phase_ended = false;
bool bcast_flags = false; std::vector<CellMetrics> c_history;
std::vector<SpeciesMetrics> s_history;
std::vector<CellErrorMetrics> cell_metrics_history;
std::vector<SpeciesErrorMetrics> metrics_history;
double prep_t = 0.; double prep_t = 0.;
double r_check_t = 0.; double r_check_t = 0.;

View File

@ -5,8 +5,8 @@
#include <vector> #include <vector>
namespace poet { namespace poet {
struct SpeciesErrorMetrics; struct SpeciesMetrics;
struct CellErrorMetrics; struct CellMetrics;
} }
int write_checkpoint(const std::string &dir_path, const std::string &file_name, int write_checkpoint(const std::string &dir_path, const std::string &file_name,
@ -15,6 +15,6 @@ int write_checkpoint(const std::string &dir_path, const std::string &file_name,
int read_checkpoint(const std::string &dir_path, const std::string &file_name, int read_checkpoint(const std::string &dir_path, const std::string &file_name,
struct Checkpoint_s &checkpoint); struct Checkpoint_s &checkpoint);
int write_metrics(const std::vector<poet::CellErrorMetrics> &metrics_history, int write_metrics(const std::vector<poet::CellMetrics> &metrics_history,
const std::vector<std::string> &species_names, const std::vector<std::string> &species_names,
const std::string &dir_path, const std::string &file_name); const std::string &dir_path, const std::string &file_name);

View File

@ -13,7 +13,7 @@
namespace fs = std::filesystem; namespace fs = std::filesystem;
int write_metrics(const std::vector<poet::CellErrorMetrics> &metrics_history, int write_metrics(const std::vector<poet::CellMetrics> &metrics_history,
const std::vector<std::string> &species_names, const std::vector<std::string> &species_names,
const std::string &dir_path, const std::string &file_name) { const std::string &dir_path, const std::string &file_name) {
@ -23,7 +23,8 @@ int write_metrics(const std::vector<poet::CellErrorMetrics> &metrics_history,
} }
fs::path file_path = fs::path(dir_path) / file_name; fs::path file_path = fs::path(dir_path) / file_name;
H5Easy::File file(file_path, H5Easy::File::Truncate); // Use a std::string path to avoid filesystem path conversion issues.
H5Easy::File file(file_path.string(), H5Easy::File::Truncate);
for (size_t idx = 0; idx < metrics_history.size(); ++idx) { for (size_t idx = 0; idx < metrics_history.size(); ++idx) {
const auto &metrics = metrics_history[idx]; const auto &metrics = metrics_history[idx];
@ -33,38 +34,44 @@ int write_metrics(const std::vector<poet::CellErrorMetrics> &metrics_history,
std::to_string(metrics.rollback_count); std::to_string(metrics.rollback_count);
*/ */
std::string grp = "iter_" + std::to_string(metrics.iteration) + "_rb_" + std::string grp = "iter_" + std::to_string(metrics.iteration) + "_rb_" +
std::to_string(metrics.rollback_count); std::to_string(metrics.rb_count);
size_t n_cells = metrics.id.size(); size_t n_cells = metrics.id.size();
size_t n_species = metrics.mape[0].size(); // Use provided species_names as the source of truth to avoid OOB when mape is empty.
size_t n_species = species_names.size();
// Create a scalar dataset "meta" and attach attributes to it (no explicit groups).
H5Easy::dump(file, grp + "/meta", 0, H5Easy::DumpMode::Overwrite); H5Easy::dump(file, grp + "/meta", 0, H5Easy::DumpMode::Overwrite);
// Attach attributes
H5Easy::dumpAttribute(file, grp + "/meta", "species_names", species_names, H5Easy::dumpAttribute(file, grp + "/meta", "species_names", species_names,
H5Easy::DumpMode::Overwrite); H5Easy::DumpMode::Overwrite);
H5Easy::dumpAttribute(file, grp + "/meta", "iteration", metrics.iteration, H5Easy::dumpAttribute(file, grp + "/meta", "iteration", metrics.iteration,
H5Easy::DumpMode::Overwrite); H5Easy::DumpMode::Overwrite);
H5Easy::dumpAttribute(file, grp + "/meta", "rollback_count", H5Easy::dumpAttribute(file, grp + "/meta", "rollback_count",
metrics.rollback_count, H5Easy::DumpMode::Overwrite); metrics.rb_count, H5Easy::DumpMode::Overwrite);
H5Easy::dumpAttribute(file, grp + "/meta", "n_cells", n_cells, H5Easy::dumpAttribute(file, grp + "/meta", "n_cells", n_cells,
H5Easy::DumpMode::Overwrite); H5Easy::DumpMode::Overwrite);
H5Easy::dumpAttribute(file, grp + "/meta", "n_species", n_species, H5Easy::dumpAttribute(file, grp + "/meta", "n_species", n_species,
H5Easy::DumpMode::Overwrite); H5Easy::DumpMode::Overwrite);
// ───────────────────────────────────────────── // Dump only if data is non-empty to avoid corrupting the file on failures.
// 2. Real datasets if (!metrics.mape.empty()) {
// ───────────────────────────────────────────── H5Easy::dump(file, grp + "/mape", metrics.mape,
H5Easy::dump(file, grp + "/mape", metrics.mape, H5Easy::DumpMode::Overwrite);
H5Easy::DumpMode::Overwrite); }
H5Easy::dump(file, grp + "/rrmse", metrics.rrmse, if (!metrics.rrmse.empty()) {
H5Easy::DumpMode::Overwrite); H5Easy::dump(file, grp + "/rrmse", metrics.rrmse,
H5Easy::DumpMode::Overwrite);
}
} }
// Ensure all buffers are flushed to disk.
file.flush();
return 0; return 0;
} }
void writeCellStatsToCSV(const std::vector<poet::CellErrorMetrics> &all_stats, void writeCellStatsToCSV(const std::vector<poet::CellMetrics> &all_stats,
const std::vector<std::string> &species_names, const std::vector<std::string> &species_names,
const std::string &out_dir, const std::string &out_dir,
const std::string &filename) { const std::string &filename) {
@ -81,15 +88,16 @@ void writeCellStatsToCSV(const std::vector<poet::CellErrorMetrics> &all_stats,
<< "\n" << "\n"
<< std::string(90, '-') << "\n"; << std::string(90, '-') << "\n";
// Data rows // Data rows (fix column ordering: include rb_count before Species)
for (const auto &metrics : all_stats) { for (const auto &metrics : all_stats) {
for (size_t cell_idx = 0; cell_idx < metrics.id.size(); ++cell_idx) { for (size_t cell_idx = 0; cell_idx < metrics.id.size(); ++cell_idx) {
for (size_t sp_idx = 0; sp_idx < species_names.size(); ++sp_idx) { for (size_t sp_idx = 0; sp_idx < species_names.size(); ++sp_idx) {
out << std::left << std::setw(15) << metrics.id[cell_idx] out << std::left << std::setw(15) << metrics.id[cell_idx]
<< std::setw(15) << metrics.iteration << std::setw(15) << std::setw(15) << metrics.iteration
<< species_names[sp_idx] << std::setw(15) << std::setw(15) << metrics.rb_count
<< metrics.mape[cell_idx][sp_idx] << std::setw(15) << std::setw(15) << species_names[sp_idx]
<< metrics.rrmse[cell_idx][sp_idx] << "\n"; << std::setw(15) << metrics.mape[cell_idx][sp_idx]
<< std::setw(15) << metrics.rrmse[cell_idx][sp_idx] << "\n";
} }
} }
out << "\n"; out << "\n";
@ -100,7 +108,7 @@ void writeCellStatsToCSV(const std::vector<poet::CellErrorMetrics> &all_stats,
} }
void writeSpeciesStatsToCSV( void writeSpeciesStatsToCSV(
const std::vector<poet::SpeciesErrorMetrics> &all_stats, const std::vector<poet::SpeciesMetrics> &all_stats,
const std::vector<std::string> &species_names, const std::string &out_dir, const std::vector<std::string> &species_names, const std::string &out_dir,
const std::string &filename) { const std::string &filename) {
std::ofstream out(std::filesystem::path(out_dir) / filename); std::ofstream out(std::filesystem::path(out_dir) / filename);
@ -120,7 +128,7 @@ void writeSpeciesStatsToCSV(
for (const auto &metrics : all_stats) { for (const auto &metrics : all_stats) {
for (size_t sp_idx = 0; sp_idx < species_names.size(); ++sp_idx) { for (size_t sp_idx = 0; sp_idx < species_names.size(); ++sp_idx) {
out << std::left << std::setw(15) << metrics.iteration << std::setw(15) out << std::left << std::setw(15) << metrics.iteration << std::setw(15)
<< metrics.rollback_count << std::setw(15) << species_names[sp_idx] << metrics.rb_count << std::setw(15) << species_names[sp_idx]
<< std::setw(15) << metrics.mape[sp_idx] << std::setw(15) << std::setw(15) << metrics.mape[sp_idx] << std::setw(15)
<< metrics.rrmse[sp_idx] << "\n"; << metrics.rrmse[sp_idx] << "\n";
} }

View File

@ -2,11 +2,11 @@
#include <vector> #include <vector>
/// Writes per-species error metrics as a text table to the file `filename`
/// located inside `out_dir`.
///
/// For every entry of `all_stats` one row per species is emitted, holding the
/// iteration, the rollback count, the species name and that species' MAPE and
/// RRMSE values. Despite the "CSV" in the name, rows are written as
/// left-aligned, 15-character-wide columns rather than comma-separated fields.
///
/// @param all_stats      Sequence of per-species metric records to dump.
/// @param species_names  Species labels; the i-th name labels the i-th MAPE and
///                       RRMSE value of every record (records are indexed up
///                       to species_names.size(), with no bounds check).
/// @param out_dir        Directory the output file is created in.
/// @param filename       File name, appended to `out_dir`.
void writeSpeciesStatsToCSV( void writeSpeciesStatsToCSV(
const std::vector<poet::SpeciesErrorMetrics> &all_stats, const std::vector<poet::SpeciesMetrics> &all_stats,
const std::vector<std::string> &species_names, const std::string &out_dir, const std::vector<std::string> &species_names, const std::string &out_dir,
const std::string &filename); const std::string &filename);
/// Writes per-cell error metrics as a text table.
///
/// For every entry of `all_stats`, one row is emitted per (cell, species)
/// pair, holding the cell id, the iteration, the species name and the MAPE and
/// RRMSE value of that cell/species combination; a blank line separates
/// consecutive entries. Rows use left-aligned, 15-character-wide columns
/// rather than comma-separated fields.
///
/// @param all_stats      Sequence of per-cell metric records to dump.
/// @param species_names  Species labels; the j-th name labels column j of each
///                       cell's MAPE and RRMSE values (indexed up to
///                       species_names.size(), with no bounds check).
/// @param out_dir        Output directory.
/// @param filename       Output file name.
///                       NOTE(review): presumably opened as out_dir/filename,
///                       as in writeSpeciesStatsToCSV; confirm in the
///                       definition.
void writeCellStatsToCSV(const std::vector<poet::CellErrorMetrics> &all_stats, void writeCellStatsToCSV(const std::vector<poet::CellMetrics> &all_stats,
const std::vector<std::string> &species_names, const std::vector<std::string> &species_names,
const std::string &out_dir, const std::string &out_dir,
const std::string &filename); const std::string &filename);

View File

@ -99,9 +99,8 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
"Print progress bar during chemical simulation"); "Print progress bar during chemical simulation");
/*Parse work package size*/ /*Parse work package size*/
app.add_option( app.add_option("-w,--work-package-size", params.work_package_size,
"-w,--work-package-size", params.work_package_size, "Work package size to distribute to each worker for chemistry module")
"Work package size to distribute to each worker for chemistry module")
->check(CLI::PositiveNumber) ->check(CLI::PositiveNumber)
->default_val(RuntimeParameters::WORK_PACKAGE_SIZE_DEFAULT); ->default_val(RuntimeParameters::WORK_PACKAGE_SIZE_DEFAULT);
@ -112,21 +111,17 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
// cout << "CPP: DHT is " << ( dht_enabled ? "ON" : "OFF" ) << '\n'; // cout << "CPP: DHT is " << ( dht_enabled ? "ON" : "OFF" ) << '\n';
dht_group dht_group->add_option("--dht-size", params.dht_size, "DHT size per process in Megabyte")
->add_option("--dht-size", params.dht_size,
"DHT size per process in Megabyte")
->check(CLI::PositiveNumber) ->check(CLI::PositiveNumber)
->default_val(RuntimeParameters::DHT_SIZE_DEFAULT); ->default_val(RuntimeParameters::DHT_SIZE_DEFAULT);
// cout << "CPP: DHT size per process (Byte) = " << dht_size_per_process << // cout << "CPP: DHT size per process (Byte) = " << dht_size_per_process <<
// endl; // endl;
dht_group->add_option( dht_group->add_option("--dht-snaps", params.dht_snaps,
"--dht-snaps", params.dht_snaps, "Save snapshots of DHT to disk: \n0 = disabled (default)\n1 = After "
"Save snapshots of DHT to disk: \n0 = disabled (default)\n1 = After " "simulation\n2 = After each iteration");
"simulation\n2 = After each iteration");
auto *interp_group = auto *interp_group = app.add_option_group("Interpolation", "Interpolation related options");
app.add_option_group("Interpolation", "Interpolation related options");
interp_group->add_flag("--interp", params.use_interp, "Enable interpolation"); interp_group->add_flag("--interp", params.use_interp, "Enable interpolation");
interp_group interp_group
@ -140,38 +135,31 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
->check(CLI::PositiveNumber) ->check(CLI::PositiveNumber)
->default_val(RuntimeParameters::INTERP_MIN_ENTRIES_DEFAULT); ->default_val(RuntimeParameters::INTERP_MIN_ENTRIES_DEFAULT);
interp_group interp_group
->add_option( ->add_option("--interp-bucket-entries", params.interp_bucket_entries,
"--interp-bucket-entries", params.interp_bucket_entries, "Maximum number of entries in each bucket of the interpolation table")
"Maximum number of entries in each bucket of the interpolation table")
->check(CLI::PositiveNumber) ->check(CLI::PositiveNumber)
->default_val(RuntimeParameters::INTERP_BUCKET_ENTRIES_DEFAULT); ->default_val(RuntimeParameters::INTERP_BUCKET_ENTRIES_DEFAULT);
app.add_flag("--ai-surrogate", params.use_ai_surrogate, app.add_flag("--ai-surrogate", params.use_ai_surrogate,
"Enable AI surrogate for chemistry module"); "Enable AI surrogate for chemistry module");
app.add_flag("--rds", params.as_rds, app.add_flag("--rds", params.as_rds, "Save output as .rds file instead of default .qs2");
"Save output as .rds file instead of default .qs2");
app.add_flag("--qs", params.as_qs, app.add_flag("--qs", params.as_qs, "Save output as .qs file instead of default .qs2");
"Save output as .qs file instead of default .qs2");
std::string init_file; std::string init_file;
std::string runtime_file; std::string runtime_file;
app.add_option("runtime_file", runtime_file, app.add_option("runtime_file", runtime_file, "Runtime R script defining the simulation")
"Runtime R script defining the simulation")
->required() ->required()
->check(CLI::ExistingFile); ->check(CLI::ExistingFile);
app.add_option( app.add_option("init_file", init_file,
"init_file", init_file, "Initial R script defining the simulation, produced by poet_init")
"Initial R script defining the simulation, produced by poet_init")
->required() ->required()
->check(CLI::ExistingFile); ->check(CLI::ExistingFile);
app.add_option("out_dir", params.out_dir, app.add_option("out_dir", params.out_dir, "Output directory of the simulation")->required();
"Output directory of the simulation")
->required();
try { try {
app.parse(argc, argv); app.parse(argc, argv);
@ -202,8 +190,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
// << simparams.dht_significant_digits); // << simparams.dht_significant_digits);
// MSG("DHT logarithm before rounding: " // MSG("DHT logarithm before rounding: "
// << (simparams.dht_log ? "ON" : "OFF")); // << (simparams.dht_log ? "ON" : "OFF"));
MSG("DHT size per process (Megabyte) = " + MSG("DHT size per process (Megabyte) = " + std::to_string(params.dht_size));
std::to_string(params.dht_size));
MSG("DHT save snapshots is " + BOOL_PRINT(params.dht_snaps)); MSG("DHT save snapshots is " + BOOL_PRINT(params.dht_snaps));
// MSG("DHT load file is " + chem_params.dht_file); // MSG("DHT load file is " + chem_params.dht_file);
} }
@ -212,8 +199,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
MSG("PHT interpolation enabled: " + BOOL_PRINT(params.use_interp)); MSG("PHT interpolation enabled: " + BOOL_PRINT(params.use_interp));
MSG("PHT interp-size = " + std::to_string(params.interp_size)); MSG("PHT interp-size = " + std::to_string(params.interp_size));
MSG("PHT interp-min = " + std::to_string(params.interp_min_entries)); MSG("PHT interp-min = " + std::to_string(params.interp_min_entries));
MSG("PHT interp-bucket-entries = " + MSG("PHT interp-bucket-entries = " + std::to_string(params.interp_bucket_entries));
std::to_string(params.interp_bucket_entries));
} }
} }
// chem_params.dht_outdir = out_dir; // chem_params.dht_outdir = out_dir;
@ -248,17 +234,15 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
// MDL add "out_ext" for output format to R setup // MDL add "out_ext" for output format to R setup
(*global_rt_setup)["out_ext"] = params.out_ext; (*global_rt_setup)["out_ext"] = params.out_ext;
params.timesteps = params.timesteps = Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("timesteps"));
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("timesteps")); params.chkpt_interval = Rcpp::as<uint32_t>(global_rt_setup->operator[]("chkpt_interval"));
params.checkpoint_interval = params.rb_limit = Rcpp::as<uint32_t>(global_rt_setup->operator[]("rb_limit"));
Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval")); params.stab_interval = Rcpp::as<uint32_t>(global_rt_setup->operator[]("stab_interval"));
params.stab_interval =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("stab_interval"));
params.zero_abs = Rcpp::as<double>(global_rt_setup->operator[]("zero_abs")); params.zero_abs = Rcpp::as<double>(global_rt_setup->operator[]("zero_abs"));
params.mape_threshold = Rcpp::as<std::vector<double>>( params.mape_threshold =
global_rt_setup->operator[]("mape_threshold")); Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("mape_threshold"));
params.ctrl_cell_ids = Rcpp::as<std::vector<uint32_t>>( params.ctrl_cell_ids =
global_rt_setup->operator[]("ctrl_cell_ids")); Rcpp::as<std::vector<uint32_t>>(global_rt_setup->operator[]("ctrl_cell_ids"));
} catch (const std::exception &e) { } catch (const std::exception &e) {
ERRMSG("Error while parsing R scripts: " + std::string(e.what())); ERRMSG("Error while parsing R scripts: " + std::string(e.what()));
@ -274,22 +258,20 @@ void call_master_iter_end(RInside &R, const Field &trans, const Field &chem) {
R["TMP"] = Rcpp::wrap(trans.AsVector()); R["TMP"] = Rcpp::wrap(trans.AsVector());
R["TMP_PROPS"] = Rcpp::wrap(trans.GetProps()); R["TMP_PROPS"] = Rcpp::wrap(trans.GetProps());
R.parseEval(std::string("state_T <- setNames(data.frame(matrix(TMP, nrow=" + R.parseEval(std::string("state_T <- setNames(data.frame(matrix(TMP, nrow=" +
std::to_string(trans.GetRequestedVecSize()) + std::to_string(trans.GetRequestedVecSize()) + ")), TMP_PROPS)"));
")), TMP_PROPS)"));
R["TMP"] = Rcpp::wrap(chem.AsVector()); R["TMP"] = Rcpp::wrap(chem.AsVector());
R["TMP_PROPS"] = Rcpp::wrap(chem.GetProps()); R["TMP_PROPS"] = Rcpp::wrap(chem.GetProps());
R.parseEval(std::string("state_C <- setNames(data.frame(matrix(TMP, nrow=" + R.parseEval(std::string("state_C <- setNames(data.frame(matrix(TMP, nrow=" +
std::to_string(chem.GetRequestedVecSize()) + std::to_string(chem.GetRequestedVecSize()) + ")), TMP_PROPS)"));
")), TMP_PROPS)"));
R["setup"] = *global_rt_setup; R["setup"] = *global_rt_setup;
R.parseEval("setup <- master_iteration_end(setup, state_T, state_C)"); R.parseEval("setup <- master_iteration_end(setup, state_T, state_C)");
*global_rt_setup = R["setup"]; *global_rt_setup = R["setup"];
} }
static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params, static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
DiffusionModule &diffusion, DiffusionModule &diffusion, ChemistryModule &chem,
ChemistryModule &chem, ControlModule &control) { ControlModule &control) {
/* Iteration Count is dynamic, retrieving value from R (is only needed by /* Iteration Count is dynamic, retrieving value from R (is only needed by
* master for the following loop) */ * master for the following loop) */
@ -313,8 +295,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
std::cout << std::endl; std::cout << std::endl;
/* displaying iteration number, with C++ and R iterator */ /* displaying iteration number, with C++ and R iterator */
MSG("Going through iteration " + std::to_string(iter) + "/" + MSG("Going through iteration " + std::to_string(iter) + "/" + std::to_string(maxiter));
std::to_string(maxiter));
MSG("Current time step is " + std::to_string(dt)); MSG("Current time step is " + std::to_string(dt));
@ -328,10 +309,9 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
double ai_start_t = MPI_Wtime(); double ai_start_t = MPI_Wtime();
// Save current values from the tug field as predictor for the ai step // Save current values from the tug field as predictor for the ai step
R["TMP"] = Rcpp::wrap(chem.getField().AsVector()); R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
R.parseEval( R.parseEval(std::string("predictors <- setNames(data.frame(matrix(TMP, nrow=" +
std::string("predictors <- setNames(data.frame(matrix(TMP, nrow=" + std::to_string(chem.getField().GetRequestedVecSize()) +
std::to_string(chem.getField().GetRequestedVecSize()) + ")), TMP_PROPS)"));
")), TMP_PROPS)"));
R.parseEval("predictors <- predictors[ai_surrogate_species]"); R.parseEval("predictors <- predictors[ai_surrogate_species]");
// Apply preprocessing // Apply preprocessing
@ -340,8 +320,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
// Predict // Predict
MSG("AI Prediction"); MSG("AI Prediction");
R.parseEval( R.parseEval("aipreds_scaled <- prediction_step(model, predictors_scaled)");
"aipreds_scaled <- prediction_step(model, predictors_scaled)");
// Apply postprocessing // Apply postprocessing
MSG("AI Postprocessing"); MSG("AI Postprocessing");
@ -349,22 +328,19 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
// Validate prediction and write valid predictions to chem field // Validate prediction and write valid predictions to chem field
MSG("AI Validation"); MSG("AI Validation");
R.parseEval( R.parseEval("validity_vector <- validate_predictions(predictors, aipreds)");
"validity_vector <- validate_predictions(predictors, aipreds)");
MSG("AI Marking accepted"); MSG("AI Marking accepted");
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector")); chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
MSG("AI TempField"); MSG("AI TempField");
std::vector<std::vector<double>> RTempField = std::vector<std::vector<double>> RTempField = R.parseEval("set_valid_predictions(predictors,\
R.parseEval("set_valid_predictions(predictors,\
aipreds,\ aipreds,\
validity_vector)"); validity_vector)");
MSG("AI Set Field"); MSG("AI Set Field");
Field predictions_field = Field predictions_field =
Field(R.parseEval("nrow(predictors)"), RTempField, Field(R.parseEval("nrow(predictors)"), RTempField, R.parseEval("colnames(predictors)"));
R.parseEval("colnames(predictors)"));
MSG("AI Update"); MSG("AI Update");
chem.getField().update(predictions_field); chem.getField().update(predictions_field);
@ -379,10 +355,9 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
double ai_start_t = MPI_Wtime(); double ai_start_t = MPI_Wtime();
R["TMP"] = Rcpp::wrap(chem.getField().AsVector()); R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
R.parseEval( R.parseEval(std::string("targets <- setNames(data.frame(matrix(TMP, nrow=" +
std::string("targets <- setNames(data.frame(matrix(TMP, nrow=" + std::to_string(chem.getField().GetRequestedVecSize()) +
std::to_string(chem.getField().GetRequestedVecSize()) + ")), TMP_PROPS)"));
")), TMP_PROPS)"));
R.parseEval("targets <- targets[ai_surrogate_species]"); R.parseEval("targets <- targets[ai_surrogate_species]");
// TODO: Check how to get the correct columns // TODO: Check how to get the correct columns
@ -411,12 +386,9 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
diffusion.getField().update(chem.getField()); diffusion.getField().update(chem.getField());
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" + MSG("End of *coupling* iteration " + std::to_string(iter) + "/" + std::to_string(maxiter));
std::to_string(maxiter));
control.writeErrorMetrics(iter, params.out_dir, chem.getField().GetProps()); control.processCheckpoint(diffusion, iter, params.out_dir, chem.getField().GetProps());
control.processCheckpoint(diffusion, iter, params.out_dir,
chem.getField().GetProps());
// MSG(); // MSG();
} // END SIMULATION LOOP } // END SIMULATION LOOP
@ -436,10 +408,10 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
Rcpp::List ctrl_profiling; Rcpp::List ctrl_profiling;
ctrl_profiling["compute_metrics_master"] = chem.GetMasterCtrlMetricsTime(); ctrl_profiling["compute_metrics_master"] = chem.GetMasterCtrlMetricsTime();
ctrl_profiling["unshuffle_field_master"] = chem.GetMasterUnshuffleTime(); ctrl_profiling["unshuffle_field_master"] = chem.GetMasterUnshuffleTime();
ctrl_profiling["w_checkpoint_master"] = control.getWriteCheckpointTime(); ctrl_profiling["w_checkpoint_master"] = control.getChkptWriteTime();
ctrl_profiling["r_checkpoint_master"] = control.getReadCheckpointTime(); ctrl_profiling["r_checkpoint_master"] = control.getChkptReadTime();
ctrl_profiling["write_stats"] = control.getWriteMetricsTime(); ctrl_profiling["write_stats"] = control.getMetricsWriteTime();
ctrl_profiling["ctrl_logic_master"] = control.getUpdateCtrlLogicTime(); ctrl_profiling["ctrl_logic_master"] = control.getCtrlLogicTime();
ctrl_profiling["recv_data_master"] = chem.GetMasterRecvCtrlDataTime(); ctrl_profiling["recv_data_master"] = chem.GetMasterRecvCtrlDataTime();
ctrl_profiling["worker"] = Rcpp::wrap(chem.GetWorkerControlTimings()); ctrl_profiling["worker"] = Rcpp::wrap(chem.GetWorkerControlTimings());
@ -451,16 +423,11 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
//} //}
if (params.use_interp) { if (params.use_interp) {
chem_profiling["interp_w"] = chem_profiling["interp_w"] = Rcpp::wrap(chem.GetWorkerInterpolationWriteTimings());
Rcpp::wrap(chem.GetWorkerInterpolationWriteTimings()); chem_profiling["interp_r"] = Rcpp::wrap(chem.GetWorkerInterpolationReadTimings());
chem_profiling["interp_r"] = chem_profiling["interp_g"] = Rcpp::wrap(chem.GetWorkerInterpolationGatherTimings());
Rcpp::wrap(chem.GetWorkerInterpolationReadTimings()); chem_profiling["interp_fc"] = Rcpp::wrap(chem.GetWorkerInterpolationFunctionCallTimings());
chem_profiling["interp_g"] = chem_profiling["interp_calls"] = Rcpp::wrap(chem.GetWorkerInterpolationCalls());
Rcpp::wrap(chem.GetWorkerInterpolationGatherTimings());
chem_profiling["interp_fc"] =
Rcpp::wrap(chem.GetWorkerInterpolationFunctionCallTimings());
chem_profiling["interp_calls"] =
Rcpp::wrap(chem.GetWorkerInterpolationCalls());
chem_profiling["interp_cached"] = Rcpp::wrap(chem.GetWorkerPHTCacheHits()); chem_profiling["interp_cached"] = Rcpp::wrap(chem.GetWorkerPHTCacheHits());
} }
@ -475,8 +442,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
return profiling; return profiling;
} }
static void getControlCellIds(std::vector<std::uint32_t> &ids, int root, static void getControlCellIds(std::vector<std::uint32_t> &ids, int root, MPI_Comm comm) {
MPI_Comm comm) {
std::uint32_t n_ids = 0; std::uint32_t n_ids = 0;
int rank; int rank;
MPI_Comm_rank(comm, &rank); MPI_Comm_rank(comm, &rank);
@ -498,8 +464,7 @@ static void getControlCellIds(std::vector<std::uint32_t> &ids, int root,
} }
} }
std::vector<std::string> getSpeciesNames(const Field &&field, int root, std::vector<std::string> getSpeciesNames(const Field &&field, int root, MPI_Comm comm) {
MPI_Comm comm) {
std::uint32_t n_elements; std::uint32_t n_elements;
std::uint32_t n_string_size; std::uint32_t n_string_size;
@ -516,8 +481,8 @@ std::vector<std::string> getSpeciesNames(const Field &&field, int root,
for (std::uint32_t i = 0; i < n_elements; i++) { for (std::uint32_t i = 0; i < n_elements; i++) {
n_string_size = field.GetProps()[i].size(); n_string_size = field.GetProps()[i].size();
MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD); MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
MPI_Bcast(const_cast<char *>(field.GetProps()[i].c_str()), n_string_size, MPI_Bcast(const_cast<char *>(field.GetProps()[i].c_str()), n_string_size, MPI_CHAR, root,
MPI_CHAR, root, MPI_COMM_WORLD); MPI_COMM_WORLD);
} }
return field.GetProps(); return field.GetProps();
@ -631,8 +596,8 @@ int main(int argc, char *argv[]) {
MPI_Barrier(MPI_COMM_WORLD); MPI_Barrier(MPI_COMM_WORLD);
ChemistryModule chemistry(run_params.work_package_size, ChemistryModule chemistry(run_params.work_package_size, init_list.getChemistryInit(),
init_list.getChemistryInit(), MPI_COMM_WORLD); MPI_COMM_WORLD);
const ChemistryModule::SurrogateSetup surr_setup = { const ChemistryModule::SurrogateSetup surr_setup = {
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD), getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
@ -661,8 +626,8 @@ int main(int argc, char *argv[]) {
// // if (MY_RANK == 0) { // get timestep vector from // // if (MY_RANK == 0) { // get timestep vector from
// // grid_init function ... // // // grid_init function ... //
*global_rt_setup = master_init_R(*global_rt_setup, run_params.out_dir, *global_rt_setup =
init_list.getInitialGrid().asSEXP()); master_init_R(*global_rt_setup, run_params.out_dir, init_list.getInitialGrid().asSEXP());
// MDL: store all parameters // MDL: store all parameters
// MSG("Calling R Function to store calling parameters"); // MSG("Calling R Function to store calling parameters");
@ -674,8 +639,7 @@ int main(int argc, char *argv[]) {
/* Incorporate ai surrogate from R */ /* Incorporate ai surrogate from R */
R.parseEvalQ(ai_surrogate_r_library); R.parseEvalQ(ai_surrogate_r_library);
/* Use dht species for model input and output */ /* Use dht species for model input and output */
R["ai_surrogate_species"] = R["ai_surrogate_species"] = init_list.getChemistryInit().dht_species.getNames();
init_list.getChemistryInit().dht_species.getNames();
const std::string ai_surrogate_input_script = const std::string ai_surrogate_input_script =
init_list.getChemistryInit().ai_surrogate_input_script; init_list.getChemistryInit().ai_surrogate_input_script;
@ -692,20 +656,17 @@ int main(int argc, char *argv[]) {
// MPI_Barrier(MPI_COMM_WORLD); // MPI_Barrier(MPI_COMM_WORLD);
DiffusionModule diffusion(init_list.getDiffusionInit(), DiffusionModule diffusion(init_list.getDiffusionInit(), init_list.getInitialGrid());
init_list.getInitialGrid());
ControlConfig config(run_params.stab_interval, ControlConfig config(run_params.stab_interval, run_params.chkpt_interval, run_params.rb_limit,
run_params.checkpoint_interval, run_params.zero_abs, run_params.zero_abs, run_params.mape_threshold);
run_params.mape_threshold);
ControlModule control(config); ControlModule control(config);
chemistry.masterSetField(init_list.getInitialGrid()); chemistry.masterSetField(init_list.getInitialGrid());
chemistry.SetControlModule(&control); chemistry.SetControlModule(&control);
Rcpp::List profiling = Rcpp::List profiling = RunMasterLoop(R, run_params, diffusion, chemistry, control);
RunMasterLoop(R, run_params, diffusion, chemistry, control);
MSG("finished simulation loop"); MSG("finished simulation loop");
@ -718,8 +679,8 @@ int main(int argc, char *argv[]) {
"'/timings.', setup$out_ext));"; "'/timings.', setup$out_ext));";
R.parseEval(r_vis_code); R.parseEval(r_vis_code);
MSG("Done! Results are stored as R objects into <" + run_params.out_dir + MSG("Done! Results are stored as R objects into <" + run_params.out_dir + "/timings." +
"/timings." + run_params.out_ext); run_params.out_ext);
} }
} }

View File

@ -51,8 +51,8 @@ struct RuntimeParameters {
bool print_progress = false; bool print_progress = false;
std::uint32_t stab_interval = 0; std::uint32_t stab_interval = 0;
std::uint32_t checkpoint_interval = 0; std::uint32_t chkpt_interval = 0;
std::uint32_t max_rb = 0; std::uint32_t rb_limit = 0;
double zero_abs = 0.0; double zero_abs = 0.0;
std::vector<double> mape_threshold; std::vector<double> mape_threshold;
std::vector<uint32_t> ctrl_cell_ids; std::vector<uint32_t> ctrl_cell_ids;