mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-16 12:54:50 +01:00
feat: Add control logic for surrogate validation with selective PHREEQC re-execution, error metrics (MAPE/RRMSE), and automatic checkpoint/rollback
This commit is contained in:
parent
1b2d942960
commit
6c5b86cccc
9
.gitignore
vendored
9
.gitignore
vendored
@ -152,9 +152,12 @@ lib/
|
|||||||
include/
|
include/
|
||||||
|
|
||||||
# But keep these specific files
|
# But keep these specific files
|
||||||
!bin/barite_fgcs_2.pqi
|
!bin/compare_qs2.R
|
||||||
!bin/barite_fgcs_2.qs2
|
!bin/barite_fgcs_3.pqi
|
||||||
!bin/barite_fgcs_2.R
|
!bin/barite_fgcs_4_rt.R
|
||||||
|
!bin/barite_fgcs_4.R
|
||||||
|
!bin/barite_fgcs_4.qs2
|
||||||
|
!bin/db_barite.dat
|
||||||
!bin/dol.pqi
|
!bin/dol.pqi
|
||||||
!bin/dolo_fgcs_3.qs2
|
!bin/dolo_fgcs_3.qs2
|
||||||
!bin/dolo_fgcs_3.R
|
!bin/dolo_fgcs_3.R
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -33,6 +33,7 @@ add_library(POETLib
|
|||||||
Chemistry/SurrogateModels/HashFunctions.cpp
|
Chemistry/SurrogateModels/HashFunctions.cpp
|
||||||
Chemistry/SurrogateModels/InterpolationModule.cpp
|
Chemistry/SurrogateModels/InterpolationModule.cpp
|
||||||
Chemistry/SurrogateModels/ProximityHashTable.cpp
|
Chemistry/SurrogateModels/ProximityHashTable.cpp
|
||||||
|
Control/ControlModule.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
set(POET_TUG_APPROACH "Implicit" CACHE STRING "tug numerical approach to use")
|
set(POET_TUG_APPROACH "Implicit" CACHE STRING "tug numerical approach to use")
|
||||||
|
|||||||
@ -18,6 +18,7 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <mpi.h>
|
#include <mpi.h>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace poet {
|
namespace poet {
|
||||||
@ -259,13 +260,17 @@ public:
|
|||||||
|
|
||||||
void SetControlModule(poet::ControlModule *ctrl) { control_module = ctrl; }
|
void SetControlModule(poet::ControlModule *ctrl) { control_module = ctrl; }
|
||||||
|
|
||||||
void SetDhtEnabled(bool enabled) { dht_enabled = enabled; }
|
void SetDhtEnabled(bool enabled) { this->dht_enabled = enabled; }
|
||||||
bool GetDhtEnabled() const { return dht_enabled; }
|
bool GetDhtEnabled() const { return this->dht_enabled; }
|
||||||
|
|
||||||
void SetInterpEnabled(bool enabled) { interp_enabled = enabled; }
|
void SetInterpEnabled(bool enabled) { this->interp_enabled = enabled; }
|
||||||
bool GetInterpEnabled() const { return interp_enabled; }
|
bool GetInterpEnabled() const { return interp_enabled; }
|
||||||
|
|
||||||
void SetWarmupEnabled(bool enabled) { warmup_enabled = enabled; }
|
void SetWarmupEnabled(bool enabled) { this->warmup_enabled = enabled; }
|
||||||
|
|
||||||
|
void SetControlCellIds(const std::vector<uint32_t> &ids) {
|
||||||
|
this->ctrl_cell_ids = std::unordered_set<uint32_t>(ids.begin(), ids.end());
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void initializeDHT(uint32_t size_mb,
|
void initializeDHT(uint32_t size_mb,
|
||||||
@ -280,13 +285,13 @@ protected:
|
|||||||
|
|
||||||
enum {
|
enum {
|
||||||
CHEM_FIELD_INIT,
|
CHEM_FIELD_INIT,
|
||||||
//CHEM_DHT_ENABLE,
|
// CHEM_DHT_ENABLE,
|
||||||
CHEM_DHT_SIGNIF_VEC,
|
CHEM_DHT_SIGNIF_VEC,
|
||||||
CHEM_DHT_SNAPS,
|
CHEM_DHT_SNAPS,
|
||||||
CHEM_DHT_READ_FILE,
|
CHEM_DHT_READ_FILE,
|
||||||
//CHEM_WARMUP_PHASE, // Control flag
|
// CHEM_WARMUP_PHASE, // Control flag
|
||||||
//CHEM_CTRL_ENABLE, // Control flag
|
// CHEM_CTRL_ENABLE, // Control flag
|
||||||
//CHEM_IP_ENABLE,
|
// CHEM_IP_ENABLE,
|
||||||
CHEM_IP_MIN_ENTRIES,
|
CHEM_IP_MIN_ENTRIES,
|
||||||
CHEM_IP_SIGNIF_VEC,
|
CHEM_IP_SIGNIF_VEC,
|
||||||
CHEM_WORK_LOOP,
|
CHEM_WORK_LOOP,
|
||||||
@ -323,6 +328,7 @@ protected:
|
|||||||
double dht_fill = 0.;
|
double dht_fill = 0.;
|
||||||
double idle_t = 0.;
|
double idle_t = 0.;
|
||||||
double ctrl_t = 0.;
|
double ctrl_t = 0.;
|
||||||
|
double ctrl_phreeqc_t = 0.;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct worker_info_s {
|
struct worker_info_s {
|
||||||
@ -364,6 +370,10 @@ protected:
|
|||||||
void WorkerRunWorkPackage(WorkPackage &work_package, double dSimTime,
|
void WorkerRunWorkPackage(WorkPackage &work_package, double dSimTime,
|
||||||
double dTimestep);
|
double dTimestep);
|
||||||
|
|
||||||
|
void ProcessControlWorkPackage(std::vector<std::vector<double>> &input,
|
||||||
|
double current_sim_time, double dt,
|
||||||
|
struct worker_s &timings);
|
||||||
|
|
||||||
std::vector<uint32_t> CalculateWPSizesVector(uint32_t n_cells,
|
std::vector<uint32_t> CalculateWPSizesVector(uint32_t n_cells,
|
||||||
uint32_t wp_size) const;
|
uint32_t wp_size) const;
|
||||||
std::vector<double> shuffleField(const std::vector<double> &in_field,
|
std::vector<double> shuffleField(const std::vector<double> &in_field,
|
||||||
@ -444,7 +454,8 @@ protected:
|
|||||||
bool control_enabled{false};
|
bool control_enabled{false};
|
||||||
bool warmup_enabled{false};
|
bool warmup_enabled{false};
|
||||||
|
|
||||||
// std::vector<double> sur_shuffled;
|
std::unordered_set<uint32_t> ctrl_cell_ids;
|
||||||
|
std::vector<std::vector<double>> control_batch;
|
||||||
};
|
};
|
||||||
} // namespace poet
|
} // namespace poet
|
||||||
|
|
||||||
|
|||||||
@ -341,32 +341,33 @@ inline void poet::ChemistryModule::MasterRecvPkgs(worker_list_t &w_list,
|
|||||||
MPI_Get_count(&probe_status, MPI_DOUBLE, &size);
|
MPI_Get_count(&probe_status, MPI_DOUBLE, &size);
|
||||||
MPI_Recv(w_list[p - 1].send_addr, size, MPI_DOUBLE, p, LOOP_WORK,
|
MPI_Recv(w_list[p - 1].send_addr, size, MPI_DOUBLE, p, LOOP_WORK,
|
||||||
this->group_comm, MPI_STATUS_IGNORE);
|
this->group_comm, MPI_STATUS_IGNORE);
|
||||||
handled = true;
|
// Only LOOP_WORK completes a work package
|
||||||
|
w_list[p - 1].has_work = 0;
|
||||||
|
pkg_to_recv -= 1;
|
||||||
|
free_workers++;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case LOOP_CTRL: {
|
case LOOP_CTRL: {
|
||||||
recv_ctrl_a = MPI_Wtime();
|
recv_ctrl_a = MPI_Wtime();
|
||||||
/* layout of buffer is [phreeqc][surrogate] */
|
|
||||||
MPI_Get_count(&probe_status, MPI_DOUBLE, &size);
|
MPI_Get_count(&probe_status, MPI_DOUBLE, &size);
|
||||||
recv_buffer.resize(size);
|
recv_buffer.resize(size);
|
||||||
MPI_Recv(recv_buffer.data(), size, MPI_DOUBLE, p, LOOP_CTRL,
|
MPI_Recv(recv_buffer.data(), size, MPI_DOUBLE, p, LOOP_CTRL,
|
||||||
this->group_comm, MPI_STATUS_IGNORE);
|
this->group_comm, MPI_STATUS_IGNORE);
|
||||||
|
|
||||||
int half = size / 2;
|
|
||||||
std::copy(recv_buffer.begin(), recv_buffer.begin() + half,
|
|
||||||
w_list[p - 1].send_addr);
|
|
||||||
|
|
||||||
/*
|
|
||||||
if (w_list[p - 1].surrogate_addr == nullptr) {
|
|
||||||
throw std::runtime_error("MasterRecvPkgs: surrogate_addr is null");
|
|
||||||
}*/
|
|
||||||
|
|
||||||
std::copy(recv_buffer.begin() + (size / 2), recv_buffer.begin() + size,
|
|
||||||
w_list[p - 1].surrogate_addr);
|
|
||||||
recv_ctrl_b = MPI_Wtime();
|
recv_ctrl_b = MPI_Wtime();
|
||||||
recv_ctrl_t += recv_ctrl_b - recv_ctrl_a;
|
recv_ctrl_t += recv_ctrl_b - recv_ctrl_a;
|
||||||
|
|
||||||
handled = true;
|
// Collect PHREEQC rows for control cells
|
||||||
|
const std::size_t cells_per_batch =
|
||||||
|
static_cast<std::size_t>(size) /
|
||||||
|
static_cast<std::size_t>(this->prop_count);
|
||||||
|
for (std::size_t i = 0; i < cells_per_batch; i++) {
|
||||||
|
std::vector<double> cell_output(
|
||||||
|
recv_buffer.begin() + this->prop_count * i,
|
||||||
|
recv_buffer.begin() + this->prop_count * (i + 1));
|
||||||
|
this->control_batch.push_back(std::move(cell_output));
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
default: {
|
default: {
|
||||||
@ -374,11 +375,6 @@ inline void poet::ChemistryModule::MasterRecvPkgs(worker_list_t &w_list,
|
|||||||
std::to_string(probe_status.MPI_TAG));
|
std::to_string(probe_status.MPI_TAG));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (handled) {
|
|
||||||
w_list[p - 1].has_work = 0;
|
|
||||||
pkg_to_recv -= 1;
|
|
||||||
free_workers++;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -450,11 +446,6 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
|||||||
|
|
||||||
MPI_Barrier(this->group_comm);
|
MPI_Barrier(this->group_comm);
|
||||||
|
|
||||||
this->control_enabled = this->control_module->getControlIntervalEnabled();
|
|
||||||
if (this->control_enabled) {
|
|
||||||
this->mpi_surr_buffer.assign(this->n_cells * this->prop_count, 0.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
static uint32_t iteration = 0;
|
static uint32_t iteration = 0;
|
||||||
|
|
||||||
/* start time measurement of sequential part */
|
/* start time measurement of sequential part */
|
||||||
@ -466,16 +457,14 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
|||||||
shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count,
|
shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count,
|
||||||
wp_sizes_vector.size());
|
wp_sizes_vector.size());
|
||||||
|
|
||||||
//this->mpi_surr_buffer.resize(mpi_buffer.size());
|
// this->mpi_surr_buffer.resize(mpi_buffer.size());
|
||||||
|
|
||||||
/* setup local variables */
|
/* setup local variables */
|
||||||
pkg_to_send = wp_sizes_vector.size();
|
pkg_to_send = wp_sizes_vector.size();
|
||||||
pkg_to_recv = wp_sizes_vector.size();
|
pkg_to_recv = wp_sizes_vector.size();
|
||||||
|
|
||||||
workpointer_t work_pointer = mpi_buffer.begin();
|
workpointer_t work_pointer = mpi_buffer.begin();
|
||||||
workpointer_t sur_pointer = this->mpi_surr_buffer.begin();
|
workpointer_t sur_pointer = mpi_buffer.begin();
|
||||||
//(this->control_enabled ? this->mpi_surr_buffer.begin()
|
|
||||||
// : mpi_buffer.end());
|
|
||||||
worker_list_t worker_list(this->comm_size - 1);
|
worker_list_t worker_list(this->comm_size - 1);
|
||||||
|
|
||||||
free_workers = this->comm_size - 1;
|
free_workers = this->comm_size - 1;
|
||||||
@ -523,28 +512,36 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
|||||||
chem_field = out_vec;
|
chem_field = out_vec;
|
||||||
|
|
||||||
/* do master stuff */
|
/* do master stuff */
|
||||||
if (this->control_enabled) {
|
if (!this->control_batch.empty()) {
|
||||||
std::cout << "[Master] Control logic enabled for this iteration."
|
std::cout << "[Master] Processing " << this->control_batch.size()
|
||||||
<< std::endl;
|
<< " control cells for comparison." << std::endl;
|
||||||
std::vector<double> sur_unshuffled{mpi_surr_buffer};
|
|
||||||
|
|
||||||
shuf_a = MPI_Wtime();
|
/* using mpi-buffer because we need cell-major layout*/
|
||||||
unshuffleField(this->mpi_surr_buffer, this->n_cells, this->prop_count,
|
std::vector<std::vector<double>> surrogate_batch;
|
||||||
wp_sizes_vector.size(), sur_unshuffled);
|
surrogate_batch.reserve(this->control_batch.size());
|
||||||
shuf_b = MPI_Wtime();
|
|
||||||
this->shuf_t += shuf_b - shuf_a;
|
|
||||||
|
|
||||||
size_t N = out_vec.size();
|
for (const auto &element : this->control_batch) {
|
||||||
if (N != sur_unshuffled.size()) {
|
|
||||||
std::cerr << "[MASTER DBG] size mismatch out_vec=" << N
|
for (size_t i = 0; i < this->n_cells; i++) {
|
||||||
<< " sur_unshuffled=" << sur_unshuffled.size() << std::endl;
|
uint32_t curr_cell_id = mpi_buffer[this->prop_count * i];
|
||||||
|
|
||||||
|
if (curr_cell_id == element[0]) {
|
||||||
|
std::vector<double> surrogate_output(
|
||||||
|
mpi_buffer.begin() + this->prop_count * i,
|
||||||
|
mpi_buffer.begin() + this->prop_count * (i + 1));
|
||||||
|
surrogate_batch.push_back(surrogate_output);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
metrics_a = MPI_Wtime();
|
metrics_a = MPI_Wtime();
|
||||||
control_module->computeSpeciesErrorMetrics(out_vec, sur_unshuffled,
|
control_module->computeSpeciesErrorMetrics(this->control_batch, surrogate_batch, 1);
|
||||||
this->n_cells);
|
|
||||||
metrics_b = MPI_Wtime();
|
metrics_b = MPI_Wtime();
|
||||||
this->metrics_t += metrics_b - metrics_a;
|
this->metrics_t += metrics_b - metrics_a;
|
||||||
|
|
||||||
|
// Clear for next control iteration
|
||||||
|
this->control_batch.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* start time measurement of master chemistry */
|
/* start time measurement of master chemistry */
|
||||||
|
|||||||
@ -59,37 +59,6 @@ void poet::ChemistryModule::WorkerLoop() {
|
|||||||
MPI_INT, 0, this->group_comm);
|
MPI_INT, 0, this->group_comm);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
/*
|
|
||||||
case CHEM_WARMUP_PHASE: {
|
|
||||||
int warmup_flag = 0;
|
|
||||||
ChemBCast(&warmup_flag, 1, MPI_INT);
|
|
||||||
this->warmup_enabled = (warmup_flag == 1);
|
|
||||||
//std::cout << "Warmup phase is " << this->warmup_enabled << std::endl;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case CHEM_DHT_ENABLE: {
|
|
||||||
int dht_flag = 0;
|
|
||||||
ChemBCast(&dht_flag, 1, MPI_INT);
|
|
||||||
this->dht_enabled = (dht_flag == 1);
|
|
||||||
//std::cout << "DHT_enabled is " << this->dht_enabled << std::endl;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case CHEM_IP_ENABLE: {
|
|
||||||
int interp_flag = 0;
|
|
||||||
ChemBCast(&interp_flag, 1, MPI_INT);
|
|
||||||
this->interp_enabled = (interp_flag == 1);
|
|
||||||
;
|
|
||||||
std::cout << "Interp_enabled is " << this->interp_enabled << std::endl;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case CHEM_CTRL_ENABLE: {
|
|
||||||
int control_flag = 0;
|
|
||||||
ChemBCast(&control_flag, 1, MPI_INT);
|
|
||||||
this->control_enabled = (control_flag == 1);
|
|
||||||
std::cout << "Control_enabled is " << this->control_enabled << std::endl;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
case CHEM_WORK_LOOP: {
|
case CHEM_WORK_LOOP: {
|
||||||
WorkerProcessPkgs(timings, iteration);
|
WorkerProcessPkgs(timings, iteration);
|
||||||
break;
|
break;
|
||||||
@ -147,6 +116,36 @@ void poet::ChemistryModule::WorkerProcessPkgs(struct worker_s &timings,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void poet::ChemistryModule::ProcessControlWorkPackage(
|
||||||
|
std::vector<std::vector<double>> &input, double current_sim_time, double dt,
|
||||||
|
struct worker_s &timings) {
|
||||||
|
|
||||||
|
double phreeqc_start, phreeqc_end;
|
||||||
|
|
||||||
|
if (input.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
poet::WorkPackage control_wp(input.size());
|
||||||
|
control_wp.input = input;
|
||||||
|
std::vector<double> mpi_buffer(control_wp.size * this->prop_count);
|
||||||
|
|
||||||
|
phreeqc_start = MPI_Wtime();
|
||||||
|
WorkerRunWorkPackage(control_wp, current_sim_time, dt);
|
||||||
|
phreeqc_end = MPI_Wtime();
|
||||||
|
|
||||||
|
timings.ctrl_phreeqc_t += phreeqc_end - phreeqc_start;
|
||||||
|
|
||||||
|
for (std::size_t wp_i = 0; wp_i < control_wp.size; wp_i++) {
|
||||||
|
std::copy(control_wp.output[wp_i].begin(), control_wp.output[wp_i].end(),
|
||||||
|
mpi_buffer.begin() + this->prop_count * wp_i);
|
||||||
|
}
|
||||||
|
|
||||||
|
MPI_Request send_req;
|
||||||
|
MPI_Isend(mpi_buffer.data(), mpi_buffer.size(), MPI_DOUBLE, 0, LOOP_CTRL,
|
||||||
|
MPI_COMM_WORLD, &send_req);
|
||||||
|
MPI_Wait(&send_req, MPI_STATUS_IGNORE);
|
||||||
|
}
|
||||||
|
|
||||||
void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||||
int double_count,
|
int double_count,
|
||||||
struct worker_s &timings) {
|
struct worker_s &timings) {
|
||||||
@ -165,6 +164,9 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
|||||||
int flags;
|
int flags;
|
||||||
std::vector<double> mpi_buffer(count);
|
std::vector<double> mpi_buffer(count);
|
||||||
|
|
||||||
|
static int control_cells_processed = 0;
|
||||||
|
static std::vector<std::vector<double>> control_batch;
|
||||||
|
|
||||||
/* receive */
|
/* receive */
|
||||||
MPI_Recv(mpi_buffer.data(), count, MPI_DOUBLE, 0, LOOP_WORK, this->group_comm,
|
MPI_Recv(mpi_buffer.data(), count, MPI_DOUBLE, 0, LOOP_WORK, this->group_comm,
|
||||||
MPI_STATUS_IGNORE);
|
MPI_STATUS_IGNORE);
|
||||||
@ -234,86 +236,45 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* if control iteration: create copy surrogate results (output and mappings)
|
/* process cells to be monitored in a seperate workpackage */
|
||||||
and then set them to zero, give this to phreeqc */
|
|
||||||
|
|
||||||
|
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
|
||||||
|
uint32_t cell_id = s_curr_wp.input[wp_i][0];
|
||||||
|
if (this->ctrl_cell_ids.find(cell_id) != this->ctrl_cell_ids.end() &&
|
||||||
|
s_curr_wp.mapping[wp_i] != CHEM_PQC) {
|
||||||
|
|
||||||
|
control_batch.push_back(s_curr_wp.input[wp_i]);
|
||||||
|
control_cells_processed++;
|
||||||
|
|
||||||
poet::WorkPackage s_curr_wp_control = s_curr_wp;
|
if (control_batch.size() == s_curr_wp.size ||
|
||||||
|
control_cells_processed == this->ctrl_cell_ids.size()) {
|
||||||
|
ProcessControlWorkPackage(control_batch, current_sim_time, dt, timings);
|
||||||
|
control_batch.clear();
|
||||||
/*
|
control_cells_processed = 0;
|
||||||
if (control_enabled) {
|
}
|
||||||
ctrl_cp_start = MPI_Wtime();
|
|
||||||
for (std::size_t wp_i = 0; wp_i < s_curr_wp_control.size; wp_i++) {
|
|
||||||
s_curr_wp_control.output[wp_i] = std::vector<double>(this->prop_count, 0.0);
|
|
||||||
s_curr_wp_control.mapping[wp_i] = CHEM_PQC;
|
|
||||||
}
|
}
|
||||||
ctrl_cp_end = MPI_Wtime();
|
|
||||||
timings.ctrl_t += ctrl_cp_end - ctrl_cp_start;
|
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
||||||
phreeqc_time_start = MPI_Wtime();
|
phreeqc_time_start = MPI_Wtime();
|
||||||
|
|
||||||
WorkerRunWorkPackage(control_enabled ? s_curr_wp_control : s_curr_wp,
|
WorkerRunWorkPackage(s_curr_wp, current_sim_time, dt);
|
||||||
current_sim_time, dt);
|
|
||||||
|
|
||||||
phreeqc_time_end = MPI_Wtime();
|
phreeqc_time_end = MPI_Wtime();
|
||||||
|
|
||||||
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
|
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
|
||||||
std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(),
|
std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(),
|
||||||
mpi_buffer.begin() + this->prop_count * wp_i);
|
mpi_buffer.begin() + this->prop_count * wp_i);
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
if (control_enabled) {
|
|
||||||
ctrl_start = MPI_Wtime();
|
|
||||||
std::size_t sur_wp_offset = s_curr_wp.size * this->prop_count;
|
|
||||||
|
|
||||||
mpi_buffer.resize(count + sur_wp_offset);
|
|
||||||
|
|
||||||
for (std::size_t wp_i = 0; wp_i < s_curr_wp_control.size; wp_i++) {
|
|
||||||
std::copy(s_curr_wp_control.output[wp_i].begin(),
|
|
||||||
s_curr_wp_control.output[wp_i].end(),
|
|
||||||
mpi_buffer.begin() + this->prop_count * wp_i);
|
|
||||||
}
|
|
||||||
|
|
||||||
// s_curr_wp only contains the interpolated data
|
|
||||||
// copy surrogate output after the the pqc output, mpi_buffer[pqc][interp]
|
|
||||||
|
|
||||||
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
|
|
||||||
// only copy if surrogate was used
|
|
||||||
if (s_curr_wp.mapping[wp_i] != CHEM_PQC) {
|
|
||||||
std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(),
|
|
||||||
mpi_buffer.begin() + sur_wp_offset + this->prop_count * wp_i);
|
|
||||||
} else {
|
|
||||||
// if pqc was used, copy pqc results again
|
|
||||||
std::copy(s_curr_wp_control.output[wp_i].begin(),
|
|
||||||
s_curr_wp_control.output[wp_i].end(),
|
|
||||||
mpi_buffer.begin() + sur_wp_offset + this->prop_count * wp_i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
count += sur_wp_offset;
|
|
||||||
ctrl_end = MPI_Wtime();
|
|
||||||
timings.ctrl_t += ctrl_end - ctrl_start;
|
|
||||||
} else {
|
|
||||||
|
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
||||||
/* send results to master */
|
/* send results to master */
|
||||||
MPI_Request send_req;
|
MPI_Request send_req;
|
||||||
|
MPI_Isend(mpi_buffer.data(), count, MPI_DOUBLE, 0, LOOP_WORK, MPI_COMM_WORLD,
|
||||||
int mpi_tag = control_enabled ? LOOP_CTRL : LOOP_WORK;
|
|
||||||
MPI_Isend(mpi_buffer.data(), count, MPI_DOUBLE, 0, mpi_tag, MPI_COMM_WORLD,
|
|
||||||
&send_req);
|
&send_req);
|
||||||
|
|
||||||
if (dht_enabled || interp_enabled || warmup_enabled) {
|
if (dht_enabled || interp_enabled || warmup_enabled) {
|
||||||
/* write results to DHT */
|
/* write results to DHT */
|
||||||
dht_fill_start = MPI_Wtime();
|
dht_fill_start = MPI_Wtime();
|
||||||
dht->fillDHT(control_enabled ? s_curr_wp_control : s_curr_wp);
|
dht->fillDHT(s_curr_wp);
|
||||||
dht_fill_end = MPI_Wtime();
|
dht_fill_end = MPI_Wtime();
|
||||||
|
|
||||||
if (interp_enabled || warmup_enabled) {
|
if (interp_enabled || warmup_enabled) {
|
||||||
|
|||||||
@ -13,20 +13,10 @@ void poet::ControlModule::updateControlIteration(const uint32_t &iter,
|
|||||||
double prep_a, prep_b;
|
double prep_a, prep_b;
|
||||||
|
|
||||||
prep_a = MPI_Wtime();
|
prep_a = MPI_Wtime();
|
||||||
if (control_interval == 0) {
|
|
||||||
control_interval_enabled = false;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
global_iteration = iter;
|
global_iteration = iter;
|
||||||
initiateWarmupPhase(dht_enabled, interp_enabled);
|
initiateWarmupPhase(dht_enabled, interp_enabled);
|
||||||
|
|
||||||
control_interval_enabled =
|
|
||||||
(control_interval > 0 && iter % control_interval == 0);
|
|
||||||
|
|
||||||
if (control_interval_enabled) {
|
|
||||||
MSG("[Control] Control interval enabled at iteration " +
|
|
||||||
std::to_string(iter));
|
|
||||||
}
|
|
||||||
prep_b = MPI_Wtime();
|
prep_b = MPI_Wtime();
|
||||||
this->prep_t += prep_b - prep_a;
|
this->prep_t += prep_b - prep_a;
|
||||||
}
|
}
|
||||||
@ -35,20 +25,19 @@ void poet::ControlModule::initiateWarmupPhase(bool dht_enabled,
|
|||||||
bool interp_enabled) {
|
bool interp_enabled) {
|
||||||
|
|
||||||
// user requested DHT/INTEP? keep them disabled but enable warmup-phase so
|
// user requested DHT/INTEP? keep them disabled but enable warmup-phase so
|
||||||
if (global_iteration <= control_interval || rollback_enabled) {
|
if (rollback_enabled) {
|
||||||
chem->SetWarmupEnabled(true);
|
chem->SetWarmupEnabled(true);
|
||||||
chem->SetDhtEnabled(false);
|
chem->SetDhtEnabled(false);
|
||||||
chem->SetInterpEnabled(false);
|
chem->SetInterpEnabled(false);
|
||||||
MSG("Warmup enabled until next control interval at iteration " +
|
|
||||||
std::to_string(control_interval) + ".");
|
|
||||||
|
|
||||||
if (rollback_enabled) {
|
MSG("Warmup enabled until next control interval at iteration " +
|
||||||
if (sur_disabled_counter > 0) {
|
std::to_string(penalty_interval) + ".");
|
||||||
--sur_disabled_counter;
|
|
||||||
MSG("Rollback counter: " + std::to_string(sur_disabled_counter));
|
if (sur_disabled_counter > 0) {
|
||||||
} else {
|
--sur_disabled_counter;
|
||||||
rollback_enabled = false;
|
MSG("Rollback counter: " + std::to_string(sur_disabled_counter));
|
||||||
}
|
} else {
|
||||||
|
rollback_enabled = false;
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -58,23 +47,23 @@ void poet::ControlModule::initiateWarmupPhase(bool dht_enabled,
|
|||||||
chem->SetInterpEnabled(interp_enabled);
|
chem->SetInterpEnabled(interp_enabled);
|
||||||
}
|
}
|
||||||
|
|
||||||
void poet::ControlModule::applyControlLogic(ChemistryModule &chem,
|
void poet::ControlModule::applyControlLogic(DiffusionModule &diffusion,
|
||||||
uint32_t &iter) {
|
uint32_t &iter) {
|
||||||
if (!control_interval_enabled) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
writeCheckpointAndMetrics(chem, iter);
|
|
||||||
|
|
||||||
if (checkAndRollback(chem, iter) && rollback_count < 4) {
|
writeCheckpointAndMetrics(diffusion, iter);
|
||||||
|
|
||||||
|
if (checkAndRollback(diffusion, iter) && rollback_count < 3) {
|
||||||
rollback_enabled = true;
|
rollback_enabled = true;
|
||||||
rollback_count++;
|
rollback_count++;
|
||||||
sur_disabled_counter = control_interval;
|
sur_disabled_counter = penalty_interval;
|
||||||
|
|
||||||
MSG("Interpolation disabled for the next " +
|
MSG("Interpolation disabled for the next " +
|
||||||
std::to_string(control_interval) + ".");
|
std::to_string(penalty_interval) + ".");
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void poet::ControlModule::writeCheckpointAndMetrics(ChemistryModule &chem,
|
void poet::ControlModule::writeCheckpointAndMetrics(DiffusionModule &diffusion,
|
||||||
uint32_t iter) {
|
uint32_t iter) {
|
||||||
|
|
||||||
double w_check_a, w_check_b, stats_a, stats_b;
|
double w_check_a, w_check_b, stats_a, stats_b;
|
||||||
@ -82,7 +71,7 @@ void poet::ControlModule::writeCheckpointAndMetrics(ChemistryModule &chem,
|
|||||||
|
|
||||||
w_check_a = MPI_Wtime();
|
w_check_a = MPI_Wtime();
|
||||||
write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5",
|
write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5",
|
||||||
{.field = chem.getField(), .iteration = iter});
|
{.field = diffusion.getField(), .iteration = iter});
|
||||||
w_check_b = MPI_Wtime();
|
w_check_b = MPI_Wtime();
|
||||||
this->w_check_t += w_check_b - w_check_a;
|
this->w_check_t += w_check_b - w_check_a;
|
||||||
|
|
||||||
@ -93,7 +82,7 @@ void poet::ControlModule::writeCheckpointAndMetrics(ChemistryModule &chem,
|
|||||||
this->stats_t += stats_b - stats_a;
|
this->stats_t += stats_b - stats_a;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool poet::ControlModule::checkAndRollback(ChemistryModule &chem,
|
bool poet::ControlModule::checkAndRollback(DiffusionModule &diffusion,
|
||||||
uint32_t &iter) {
|
uint32_t &iter) {
|
||||||
double r_check_a, r_check_b;
|
double r_check_a, r_check_b;
|
||||||
|
|
||||||
@ -119,7 +108,7 @@ bool poet::ControlModule::checkAndRollback(ChemistryModule &chem,
|
|||||||
" → rolling back to iteration " + std::to_string(rollback_iter));
|
" → rolling back to iteration " + std::to_string(rollback_iter));
|
||||||
|
|
||||||
r_check_a = MPI_Wtime();
|
r_check_a = MPI_Wtime();
|
||||||
Checkpoint_s checkpoint_read{.field = chem.getField()};
|
Checkpoint_s checkpoint_read{.field = diffusion.getField()};
|
||||||
read_checkpoint(out_dir,
|
read_checkpoint(out_dir,
|
||||||
"checkpoint" + std::to_string(rollback_iter) + ".hdf5",
|
"checkpoint" + std::to_string(rollback_iter) + ".hdf5",
|
||||||
checkpoint_read);
|
checkpoint_read);
|
||||||
@ -135,8 +124,9 @@ bool poet::ControlModule::checkAndRollback(ChemistryModule &chem,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void poet::ControlModule::computeSpeciesErrorMetrics(
|
void poet::ControlModule::computeSpeciesErrorMetrics(
|
||||||
const std::vector<double> &reference_values,
|
std::vector<std::vector<double>> &reference_values,
|
||||||
const std::vector<double> &surrogate_values, const uint32_t size_per_prop) {
|
std::vector<std::vector<double>> &surrogate_values,
|
||||||
|
const uint32_t size_per_prop) {
|
||||||
|
|
||||||
SpeciesErrorMetrics metrics(this->species_names.size(), global_iteration,
|
SpeciesErrorMetrics metrics(this->species_names.size(), global_iteration,
|
||||||
rollback_count);
|
rollback_count);
|
||||||
@ -148,25 +138,16 @@ void poet::ControlModule::computeSpeciesErrorMetrics(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::size_t expected =
|
// Loop over species (rows in the data structure)
|
||||||
static_cast<std::size_t>(this->species_names.size()) * size_per_prop;
|
for (size_t species_idx = 0; species_idx < reference_values.size(); species_idx++) {
|
||||||
if (reference_values.size() < expected) {
|
|
||||||
std::cerr << "[CTRL ERROR] input vectors too small: expected >= "
|
|
||||||
<< expected << " entries, got " << reference_values.size()
|
|
||||||
<< "\n";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (uint32_t i = 0; i < this->species_names.size(); ++i) {
|
|
||||||
double err_sum = 0.0;
|
double err_sum = 0.0;
|
||||||
double sqr_err_sum = 0.0;
|
double sqr_err_sum = 0.0;
|
||||||
uint32_t base_idx = i * size_per_prop;
|
uint32_t count = 0;
|
||||||
|
|
||||||
int count = 0;
|
// Loop over control cells (columns in the data structure)
|
||||||
|
for (size_t cell_idx = 0; cell_idx < size_per_prop; cell_idx++) {
|
||||||
for (uint32_t j = 0; j < size_per_prop; ++j) {
|
const double ref_value = reference_values[species_idx][cell_idx];
|
||||||
const double ref_value = reference_values[base_idx + j];
|
const double sur_value = surrogate_values[species_idx][cell_idx];
|
||||||
const double sur_value = surrogate_values[base_idx + j];
|
|
||||||
const double ZERO_ABS = 1e-13;
|
const double ZERO_ABS = 1e-13;
|
||||||
|
|
||||||
if (std::isnan(ref_value) || std::isnan(sur_value)) {
|
if (std::isnan(ref_value) || std::isnan(sur_value)) {
|
||||||
@ -177,17 +158,28 @@ void poet::ControlModule::computeSpeciesErrorMetrics(
|
|||||||
if (std::abs(sur_value) >= ZERO_ABS) {
|
if (std::abs(sur_value) >= ZERO_ABS) {
|
||||||
err_sum += 1.0;
|
err_sum += 1.0;
|
||||||
sqr_err_sum += 1.0;
|
sqr_err_sum += 1.0;
|
||||||
|
count++;
|
||||||
}
|
}
|
||||||
|
// Both zero: skip (don't increment count)
|
||||||
}
|
}
|
||||||
// Both zero: skip
|
|
||||||
else {
|
else {
|
||||||
double alpha = 1.0 - (sur_value / ref_value);
|
double alpha = 1.0 - (sur_value / ref_value);
|
||||||
err_sum += std::abs(alpha);
|
err_sum += std::abs(alpha);
|
||||||
sqr_err_sum += alpha * alpha;
|
sqr_err_sum += alpha * alpha;
|
||||||
|
count++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
metrics.mape[i] = 100.0 * (err_sum / size_per_prop);
|
|
||||||
metrics.rrmse[i] = std::sqrt(sqr_err_sum / size_per_prop);
|
// Store metrics for this species after processing all cells
|
||||||
|
if (count > 0) {
|
||||||
|
metrics.mape[species_idx] = 100.0 * (err_sum / size_per_prop);
|
||||||
|
metrics.rrmse[species_idx] = std::sqrt(sqr_err_sum / size_per_prop);
|
||||||
|
} else {
|
||||||
|
metrics.mape[species_idx] = 0.0;
|
||||||
|
metrics.rrmse[species_idx] = 0.0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Push metrics to history once after processing all species
|
||||||
metricsHistory.push_back(metrics);
|
metricsHistory.push_back(metrics);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
#include "Base/Macros.hpp"
|
#include "Base/Macros.hpp"
|
||||||
#include "Chemistry/ChemistryModule.hpp"
|
#include "Chemistry/ChemistryModule.hpp"
|
||||||
|
#include "Transport/DiffusionModule.hpp"
|
||||||
#include "poet.hpp"
|
#include "poet.hpp"
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
@ -12,6 +13,7 @@
|
|||||||
namespace poet {
|
namespace poet {
|
||||||
|
|
||||||
class ChemistryModule;
|
class ChemistryModule;
|
||||||
|
class DiffusionModule;
|
||||||
|
|
||||||
class ControlModule {
|
class ControlModule {
|
||||||
|
|
||||||
@ -27,7 +29,7 @@ public:
|
|||||||
|
|
||||||
void initiateWarmupPhase(bool dht_enabled, bool interp_enabled);
|
void initiateWarmupPhase(bool dht_enabled, bool interp_enabled);
|
||||||
|
|
||||||
bool checkAndRollback(ChemistryModule &chem, uint32_t &iter);
|
bool checkAndRollback(DiffusionModule &diffusion, uint32_t &iter);
|
||||||
|
|
||||||
struct SpeciesErrorMetrics {
|
struct SpeciesErrorMetrics {
|
||||||
std::vector<double> mape;
|
std::vector<double> mape;
|
||||||
@ -40,16 +42,16 @@ public:
|
|||||||
rollback_count(counter) {}
|
rollback_count(counter) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
void computeSpeciesErrorMetrics(const std::vector<double> &reference_values,
|
void computeSpeciesErrorMetrics(
|
||||||
const std::vector<double> &surrogate_values,
|
std::vector<std::vector<double>> &reference_values,
|
||||||
const uint32_t size_per_prop);
|
std::vector<std::vector<double>> &surrogate_values,
|
||||||
|
const uint32_t size_per_prop);
|
||||||
|
|
||||||
std::vector<SpeciesErrorMetrics> metricsHistory;
|
std::vector<SpeciesErrorMetrics> metricsHistory;
|
||||||
|
|
||||||
struct ControlSetup {
|
struct ControlSetup {
|
||||||
std::string out_dir;
|
std::string out_dir;
|
||||||
std::uint32_t checkpoint_interval;
|
std::uint32_t checkpoint_interval;
|
||||||
std::uint32_t control_interval;
|
|
||||||
std::vector<std::string> species_names;
|
std::vector<std::string> species_names;
|
||||||
std::vector<double> mape_threshold;
|
std::vector<double> mape_threshold;
|
||||||
std::vector<uint32_t> ctrl_cell_ids;
|
std::vector<uint32_t> ctrl_cell_ids;
|
||||||
@ -58,26 +60,19 @@ public:
|
|||||||
void enableControlLogic(const ControlSetup &setup) {
|
void enableControlLogic(const ControlSetup &setup) {
|
||||||
this->out_dir = setup.out_dir;
|
this->out_dir = setup.out_dir;
|
||||||
this->checkpoint_interval = setup.checkpoint_interval;
|
this->checkpoint_interval = setup.checkpoint_interval;
|
||||||
this->control_interval = setup.control_interval;
|
|
||||||
this->species_names = setup.species_names;
|
this->species_names = setup.species_names;
|
||||||
this->mape_threshold = setup.mape_threshold;
|
this->mape_threshold = setup.mape_threshold;
|
||||||
this->ctrl_cell_ids = setup.ctrl_cell_ids;
|
this->ctrl_cell_ids = setup.ctrl_cell_ids;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool getControlIntervalEnabled() const {
|
void applyControlLogic(DiffusionModule &diffusion, uint32_t &iter);
|
||||||
return this->control_interval_enabled;
|
|
||||||
}
|
|
||||||
|
|
||||||
void applyControlLogic(ChemistryModule &chem, uint32_t &iter);
|
void writeCheckpointAndMetrics(DiffusionModule &diffusion, uint32_t iter);
|
||||||
|
|
||||||
void writeCheckpointAndMetrics(ChemistryModule &chem, uint32_t iter);
|
|
||||||
|
|
||||||
auto getGlobalIteration() const noexcept { return global_iteration; }
|
auto getGlobalIteration() const noexcept { return global_iteration; }
|
||||||
|
|
||||||
void setChemistryModule(poet::ChemistryModule *c) { chem = c; }
|
void setChemistryModule(poet::ChemistryModule *c) { chem = c; }
|
||||||
|
|
||||||
auto getControlInterval() const { return this->control_interval; }
|
|
||||||
|
|
||||||
std::vector<double> getMapeThreshold() const { return this->mape_threshold; }
|
std::vector<double> getMapeThreshold() const { return this->mape_threshold; }
|
||||||
|
|
||||||
std::vector<uint32_t> getCtrlCellIds() const { return this->ctrl_cell_ids; }
|
std::vector<uint32_t> getCtrlCellIds() const { return this->ctrl_cell_ids; }
|
||||||
@ -94,12 +89,11 @@ public:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
bool rollback_enabled = false;
|
bool rollback_enabled = false;
|
||||||
bool control_interval_enabled = false;
|
|
||||||
|
|
||||||
poet::ChemistryModule *chem = nullptr;
|
poet::ChemistryModule *chem = nullptr;
|
||||||
|
|
||||||
|
std::uint32_t penalty_interval = 50;
|
||||||
std::uint32_t checkpoint_interval = 0;
|
std::uint32_t checkpoint_interval = 0;
|
||||||
std::uint32_t control_interval = 0;
|
|
||||||
std::uint32_t global_iteration = 0;
|
std::uint32_t global_iteration = 0;
|
||||||
std::uint32_t rollback_count = 0;
|
std::uint32_t rollback_count = 0;
|
||||||
std::uint32_t sur_disabled_counter = 0;
|
std::uint32_t sur_disabled_counter = 0;
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include "Datatypes.hpp"
|
#include "Datatypes.hpp"
|
||||||
|
|
||||||
int write_checkpoint(const std::string &file_path, struct Checkpoint_s &&checkpoint);
|
|
||||||
|
|
||||||
int read_checkpoint(const std::string &file_path, struct Checkpoint_s &checkpoint);
|
int write_checkpoint(const std::string &dir_path, const std::string &file_name, struct Checkpoint_s &&checkpoint);
|
||||||
|
|
||||||
|
int read_checkpoint(const std::string &dir_path, const std::string &file_name, struct Checkpoint_s &checkpoint);
|
||||||
|
|||||||
@ -1,13 +1,20 @@
|
|||||||
#include "IO/Datatypes.hpp"
|
#include "IO/Datatypes.hpp"
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <highfive/H5Easy.hpp>
|
#include <highfive/H5Easy.hpp>
|
||||||
|
#include <filesystem>
|
||||||
|
|
||||||
int write_checkpoint(const std::string &file_path, struct Checkpoint_s &&checkpoint){
|
namespace fs = std::filesystem;
|
||||||
|
|
||||||
|
int write_checkpoint(const std::string &dir_path, const std::string &file_name, struct Checkpoint_s &&checkpoint){
|
||||||
|
|
||||||
|
if (!fs::exists(dir_path)) {
|
||||||
|
std::cerr << "Directory does not exist: " << dir_path << std::endl;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
fs::path file_path = fs::path(dir_path) / file_name;
|
||||||
// TODO: errorhandling
|
// TODO: errorhandling
|
||||||
H5Easy::File file(file_path, H5Easy::File::Overwrite);
|
H5Easy::File file(file_path, H5Easy::File::Overwrite);
|
||||||
|
|
||||||
|
|
||||||
H5Easy::dump(file, "/MetaParam/Iterations", checkpoint.iteration);
|
H5Easy::dump(file, "/MetaParam/Iterations", checkpoint.iteration);
|
||||||
H5Easy::dump(file, "/Grid/Names", checkpoint.field.GetProps());
|
H5Easy::dump(file, "/Grid/Names", checkpoint.field.GetProps());
|
||||||
H5Easy::dump(file, "/Grid/Chemistry", checkpoint.field.As2DVector());
|
H5Easy::dump(file, "/Grid/Chemistry", checkpoint.field.As2DVector());
|
||||||
@ -15,9 +22,16 @@ int write_checkpoint(const std::string &file_path, struct Checkpoint_s &&checkpo
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int read_checkpoint(const std::string &file_path, struct Checkpoint_s &checkpoint){
|
int read_checkpoint(const std::string &dir_path, const std::string &file_name, struct Checkpoint_s &checkpoint){
|
||||||
|
|
||||||
H5Easy::File file(file_path, H5Easy::File::ReadOnly);
|
fs::path file_path = fs::path(dir_path) / file_name;
|
||||||
|
|
||||||
|
if (!fs::exists(file_path)) {
|
||||||
|
std::cerr << "File does not exist: " << file_path << std::endl;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
H5Easy::File file(file_path, H5Easy::File::ReadOnly);
|
||||||
|
|
||||||
checkpoint.iteration = H5Easy::load<uint32_t>(file, "/MetaParam/Iterations");
|
checkpoint.iteration = H5Easy::load<uint32_t>(file, "/MetaParam/Iterations");
|
||||||
|
|
||||||
|
|||||||
795
src/poet.cpp
795
src/poet.cpp
@ -252,476 +252,473 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) {
|
|||||||
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("timesteps"));
|
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("timesteps"));
|
||||||
params.checkpoint_interval =
|
params.checkpoint_interval =
|
||||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval"));
|
Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval"));
|
||||||
params.control_interval =
|
|
||||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("control_interval"));
|
|
||||||
params.mape_threshold = Rcpp::as<std::vector<double>>(
|
params.mape_threshold = Rcpp::as<std::vector<double>>(
|
||||||
global_rt_setup->operator[]("mape_threshold"));
|
global_rt_setup->operator[]("mape_threshold"));
|
||||||
params.ctrl_cell_ids = Rcpp::as<std::vector<uint32_t>>(
|
params.ctrl_cell_ids = Rcpp::as<std::vector<uint32_t>>(
|
||||||
global_rt_setup->operator[]("ctrl_cell_ids"));
|
global_rt_setup->operator[]("ctrl_cell_ids"));
|
||||||
|
|
||||||
catch (const std::exception &e) {
|
} catch (const std::exception &e) {
|
||||||
ERRMSG("Error while parsing R scripts: " + std::string(e.what()));
|
ERRMSG("Error while parsing R scripts: " + std::string(e.what()));
|
||||||
return ParseRet::PARSER_ERROR;
|
return ParseRet::PARSER_ERROR;
|
||||||
}
|
|
||||||
|
|
||||||
return ParseRet::PARSER_OK;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// HACK: this is a step back as the order and also the count of fields is
|
return ParseRet::PARSER_OK;
|
||||||
// predefined, but it will change in the future
|
}
|
||||||
void call_master_iter_end(RInside & R, const Field &trans,
|
|
||||||
const Field &chem) {
|
|
||||||
R["TMP"] = Rcpp::wrap(trans.AsVector());
|
|
||||||
R["TMP_PROPS"] = Rcpp::wrap(trans.GetProps());
|
|
||||||
R.parseEval(std::string("state_T <- setNames(data.frame(matrix(TMP, nrow=" +
|
|
||||||
std::to_string(trans.GetRequestedVecSize()) +
|
|
||||||
")), TMP_PROPS)"));
|
|
||||||
|
|
||||||
R["TMP"] = Rcpp::wrap(chem.AsVector());
|
// HACK: this is a step back as the order and also the count of fields is
|
||||||
R["TMP_PROPS"] = Rcpp::wrap(chem.GetProps());
|
// predefined, but it will change in the future
|
||||||
R.parseEval(std::string("state_C <- setNames(data.frame(matrix(TMP, nrow=" +
|
void call_master_iter_end(RInside &R, const Field &trans, const Field &chem) {
|
||||||
std::to_string(chem.GetRequestedVecSize()) +
|
R["TMP"] = Rcpp::wrap(trans.AsVector());
|
||||||
")), TMP_PROPS)"));
|
R["TMP_PROPS"] = Rcpp::wrap(trans.GetProps());
|
||||||
R["setup"] = *global_rt_setup;
|
R.parseEval(std::string("state_T <- setNames(data.frame(matrix(TMP, nrow=" +
|
||||||
R.parseEval("setup <- master_iteration_end(setup, state_T, state_C)");
|
std::to_string(trans.GetRequestedVecSize()) +
|
||||||
*global_rt_setup = R["setup"];
|
")), TMP_PROPS)"));
|
||||||
|
|
||||||
|
R["TMP"] = Rcpp::wrap(chem.AsVector());
|
||||||
|
R["TMP_PROPS"] = Rcpp::wrap(chem.GetProps());
|
||||||
|
R.parseEval(std::string("state_C <- setNames(data.frame(matrix(TMP, nrow=" +
|
||||||
|
std::to_string(chem.GetRequestedVecSize()) +
|
||||||
|
")), TMP_PROPS)"));
|
||||||
|
R["setup"] = *global_rt_setup;
|
||||||
|
R.parseEval("setup <- master_iteration_end(setup, state_T, state_C)");
|
||||||
|
*global_rt_setup = R["setup"];
|
||||||
|
}
|
||||||
|
|
||||||
|
static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||||
|
DiffusionModule &diffusion,
|
||||||
|
ChemistryModule &chem, ControlModule &control) {
|
||||||
|
|
||||||
|
/* Iteration Count is dynamic, retrieving value from R (is only needed by
|
||||||
|
* master for the following loop) */
|
||||||
|
uint32_t maxiter = params.timesteps.size();
|
||||||
|
|
||||||
|
if (params.print_progress) {
|
||||||
|
chem.setProgressBarPrintout(true);
|
||||||
}
|
}
|
||||||
|
R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps());
|
||||||
|
|
||||||
static Rcpp::List RunMasterLoop(
|
/* SIMULATION LOOP */
|
||||||
RInsidePOET & R, RuntimeParameters & params, DiffusionModule & diffusion,
|
|
||||||
ChemistryModule & chem, ControlModule & control) {
|
|
||||||
|
|
||||||
/* Iteration Count is dynamic, retrieving value from R (is only needed by
|
double dSimTime{0};
|
||||||
* master for the following loop) */
|
for (uint32_t iter = 1; iter < maxiter + 1; iter++) {
|
||||||
uint32_t maxiter = params.timesteps.size();
|
control.updateControlIteration(iter, params.use_dht, params.use_interp);
|
||||||
|
|
||||||
if (params.print_progress) {
|
double start_t = MPI_Wtime();
|
||||||
chem.setProgressBarPrintout(true);
|
|
||||||
}
|
|
||||||
R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps());
|
|
||||||
|
|
||||||
/* SIMULATION LOOP */
|
const double &dt = params.timesteps[iter - 1];
|
||||||
|
|
||||||
double dSimTime{0};
|
|
||||||
for (uint32_t iter = 1; iter < maxiter + 1; iter++) {
|
|
||||||
control.updateControlIteration(iter, params.use_dht, params.use_interp);
|
|
||||||
|
|
||||||
double start_t = MPI_Wtime();
|
|
||||||
|
|
||||||
const double &dt = params.timesteps[iter - 1];
|
|
||||||
|
|
||||||
std::cout << std::endl;
|
|
||||||
|
|
||||||
/* displaying iteration number, with C++ and R iterator */
|
|
||||||
MSG("Going through iteration " + std::to_string(iter) + "/" +
|
|
||||||
std::to_string(maxiter));
|
|
||||||
|
|
||||||
MSG("Current time step is " + std::to_string(dt));
|
|
||||||
|
|
||||||
/* run transport */
|
|
||||||
diffusion.simulate(dt);
|
|
||||||
|
|
||||||
chem.getField().update(diffusion.getField());
|
|
||||||
|
|
||||||
// MSG("Chemistry start");
|
|
||||||
if (params.use_ai_surrogate) {
|
|
||||||
double ai_start_t = MPI_Wtime();
|
|
||||||
// Save current values from the tug field as predictor for the ai step
|
|
||||||
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
|
|
||||||
R.parseEval(
|
|
||||||
std::string("predictors <- setNames(data.frame(matrix(TMP, nrow=" +
|
|
||||||
std::to_string(chem.getField().GetRequestedVecSize()) +
|
|
||||||
")), TMP_PROPS)"));
|
|
||||||
R.parseEval("predictors <- predictors[ai_surrogate_species]");
|
|
||||||
|
|
||||||
// Apply preprocessing
|
|
||||||
MSG("AI Preprocessing");
|
|
||||||
R.parseEval("predictors_scaled <- preprocess(predictors)");
|
|
||||||
|
|
||||||
// Predict
|
|
||||||
MSG("AI Prediction");
|
|
||||||
R.parseEval(
|
|
||||||
"aipreds_scaled <- prediction_step(model, predictors_scaled)");
|
|
||||||
|
|
||||||
// Apply postprocessing
|
|
||||||
MSG("AI Postprocessing");
|
|
||||||
R.parseEval("aipreds <- postprocess(aipreds_scaled)");
|
|
||||||
|
|
||||||
// Validate prediction and write valid predictions to chem field
|
|
||||||
MSG("AI Validation");
|
|
||||||
R.parseEval(
|
|
||||||
"validity_vector <- validate_predictions(predictors, aipreds)");
|
|
||||||
|
|
||||||
MSG("AI Marking accepted");
|
|
||||||
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
|
|
||||||
|
|
||||||
MSG("AI TempField");
|
|
||||||
std::vector<std::vector<double>> RTempField =
|
|
||||||
R.parseEval("set_valid_predictions(predictors,\
|
|
||||||
aipreds,\
|
|
||||||
validity_vector)");
|
|
||||||
|
|
||||||
MSG("AI Set Field");
|
|
||||||
Field predictions_field =
|
|
||||||
Field(R.parseEval("nrow(predictors)"), RTempField,
|
|
||||||
R.parseEval("colnames(predictors)"));
|
|
||||||
|
|
||||||
MSG("AI Update");
|
|
||||||
chem.getField().update(predictions_field);
|
|
||||||
double ai_end_t = MPI_Wtime();
|
|
||||||
R["ai_prediction_time"] = ai_end_t - ai_start_t;
|
|
||||||
}
|
|
||||||
|
|
||||||
chem.simulate(dt);
|
|
||||||
|
|
||||||
/* AI surrogate iterative training*/
|
|
||||||
if (params.use_ai_surrogate) {
|
|
||||||
double ai_start_t = MPI_Wtime();
|
|
||||||
|
|
||||||
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
|
|
||||||
R.parseEval(
|
|
||||||
std::string("targets <- setNames(data.frame(matrix(TMP, nrow=" +
|
|
||||||
std::to_string(chem.getField().GetRequestedVecSize()) +
|
|
||||||
")), TMP_PROPS)"));
|
|
||||||
R.parseEval("targets <- targets[ai_surrogate_species]");
|
|
||||||
|
|
||||||
// TODO: Check how to get the correct columns
|
|
||||||
R.parseEval("target_scaled <- preprocess(targets)");
|
|
||||||
|
|
||||||
MSG("AI: incremental training");
|
|
||||||
R.parseEval("model <- training_step(model, predictors_scaled, "
|
|
||||||
"target_scaled, validity_vector)");
|
|
||||||
double ai_end_t = MPI_Wtime();
|
|
||||||
R["ai_training_time"] = ai_end_t - ai_start_t;
|
|
||||||
}
|
|
||||||
|
|
||||||
// MPI_Barrier(MPI_COMM_WORLD);
|
|
||||||
double end_t = MPI_Wtime();
|
|
||||||
dSimTime += end_t - start_t;
|
|
||||||
R["totaltime"] = dSimTime;
|
|
||||||
|
|
||||||
// MDL master_iteration_end just writes on disk state_T and
|
|
||||||
// state_C after every iteration if the cmdline option
|
|
||||||
// --ignore-results is not given (and thus the R variable
|
|
||||||
// store_result is TRUE)
|
|
||||||
call_master_iter_end(R, diffusion.getField(), chem.getField());
|
|
||||||
|
|
||||||
// TODO: write checkpoint
|
|
||||||
// checkpoint struct --> field and iteration
|
|
||||||
|
|
||||||
diffusion.getField().update(chem.getField());
|
|
||||||
|
|
||||||
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" +
|
|
||||||
std::to_string(maxiter));
|
|
||||||
|
|
||||||
control.applyControlLogic(chem, iter);
|
|
||||||
// MSG();
|
|
||||||
} // END SIMULATION LOOP
|
|
||||||
|
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
|
|
||||||
Rcpp::List chem_profiling;
|
/* displaying iteration number, with C++ and R iterator */
|
||||||
chem_profiling["simtime"] = chem.GetChemistryTime();
|
MSG("Going through iteration " + std::to_string(iter) + "/" +
|
||||||
chem_profiling["loop"] = chem.GetMasterLoopTime();
|
std::to_string(maxiter));
|
||||||
chem_profiling["sequential"] = chem.GetMasterSequentialTime();
|
|
||||||
chem_profiling["idle_master"] = chem.GetMasterIdleTime();
|
|
||||||
chem_profiling["idle_worker"] = Rcpp::wrap(chem.GetWorkerIdleTimings());
|
|
||||||
chem_profiling["phreeqc_time"] = Rcpp::wrap(chem.GetWorkerPhreeqcTimings());
|
|
||||||
|
|
||||||
Rcpp::List diffusion_profiling;
|
MSG("Current time step is " + std::to_string(dt));
|
||||||
diffusion_profiling["simtime"] = diffusion.getTransportTime();
|
|
||||||
|
|
||||||
if (params.use_dht) {
|
/* run transport */
|
||||||
chem_profiling["dht_hits"] = Rcpp::wrap(chem.GetWorkerDHTHits());
|
diffusion.simulate(dt);
|
||||||
chem_profiling["dht_evictions"] =
|
|
||||||
Rcpp::wrap(chem.GetWorkerDHTEvictions());
|
chem.getField().update(diffusion.getField());
|
||||||
chem_profiling["dht_get_time"] =
|
|
||||||
Rcpp::wrap(chem.GetWorkerDHTGetTimings());
|
// MSG("Chemistry start");
|
||||||
chem_profiling["dht_fill_time"] =
|
if (params.use_ai_surrogate) {
|
||||||
Rcpp::wrap(chem.GetWorkerDHTFillTimings());
|
double ai_start_t = MPI_Wtime();
|
||||||
|
// Save current values from the tug field as predictor for the ai step
|
||||||
|
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
|
||||||
|
R.parseEval(
|
||||||
|
std::string("predictors <- setNames(data.frame(matrix(TMP, nrow=" +
|
||||||
|
std::to_string(chem.getField().GetRequestedVecSize()) +
|
||||||
|
")), TMP_PROPS)"));
|
||||||
|
R.parseEval("predictors <- predictors[ai_surrogate_species]");
|
||||||
|
|
||||||
|
// Apply preprocessing
|
||||||
|
MSG("AI Preprocessing");
|
||||||
|
R.parseEval("predictors_scaled <- preprocess(predictors)");
|
||||||
|
|
||||||
|
// Predict
|
||||||
|
MSG("AI Prediction");
|
||||||
|
R.parseEval(
|
||||||
|
"aipreds_scaled <- prediction_step(model, predictors_scaled)");
|
||||||
|
|
||||||
|
// Apply postprocessing
|
||||||
|
MSG("AI Postprocessing");
|
||||||
|
R.parseEval("aipreds <- postprocess(aipreds_scaled)");
|
||||||
|
|
||||||
|
// Validate prediction and write valid predictions to chem field
|
||||||
|
MSG("AI Validation");
|
||||||
|
R.parseEval(
|
||||||
|
"validity_vector <- validate_predictions(predictors, aipreds)");
|
||||||
|
|
||||||
|
MSG("AI Marking accepted");
|
||||||
|
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
|
||||||
|
|
||||||
|
MSG("AI TempField");
|
||||||
|
std::vector<std::vector<double>> RTempField =
|
||||||
|
R.parseEval("set_valid_predictions(predictors,\
|
||||||
|
aipreds,\
|
||||||
|
validity_vector)");
|
||||||
|
|
||||||
|
MSG("AI Set Field");
|
||||||
|
Field predictions_field =
|
||||||
|
Field(R.parseEval("nrow(predictors)"), RTempField,
|
||||||
|
R.parseEval("colnames(predictors)"));
|
||||||
|
|
||||||
|
MSG("AI Update");
|
||||||
|
chem.getField().update(predictions_field);
|
||||||
|
double ai_end_t = MPI_Wtime();
|
||||||
|
R["ai_prediction_time"] = ai_end_t - ai_start_t;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.use_interp) {
|
chem.simulate(dt);
|
||||||
chem_profiling["interp_w"] =
|
|
||||||
Rcpp::wrap(chem.GetWorkerInterpolationWriteTimings());
|
/* AI surrogate iterative training*/
|
||||||
chem_profiling["interp_r"] =
|
if (params.use_ai_surrogate) {
|
||||||
Rcpp::wrap(chem.GetWorkerInterpolationReadTimings());
|
double ai_start_t = MPI_Wtime();
|
||||||
chem_profiling["interp_g"] =
|
|
||||||
Rcpp::wrap(chem.GetWorkerInterpolationGatherTimings());
|
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
|
||||||
chem_profiling["interp_fc"] =
|
R.parseEval(
|
||||||
Rcpp::wrap(chem.GetWorkerInterpolationFunctionCallTimings());
|
std::string("targets <- setNames(data.frame(matrix(TMP, nrow=" +
|
||||||
chem_profiling["interp_calls"] =
|
std::to_string(chem.getField().GetRequestedVecSize()) +
|
||||||
Rcpp::wrap(chem.GetWorkerInterpolationCalls());
|
")), TMP_PROPS)"));
|
||||||
chem_profiling["interp_cached"] =
|
R.parseEval("targets <- targets[ai_surrogate_species]");
|
||||||
Rcpp::wrap(chem.GetWorkerPHTCacheHits());
|
|
||||||
|
// TODO: Check how to get the correct columns
|
||||||
|
R.parseEval("target_scaled <- preprocess(targets)");
|
||||||
|
|
||||||
|
MSG("AI: incremental training");
|
||||||
|
R.parseEval("model <- training_step(model, predictors_scaled, "
|
||||||
|
"target_scaled, validity_vector)");
|
||||||
|
double ai_end_t = MPI_Wtime();
|
||||||
|
R["ai_training_time"] = ai_end_t - ai_start_t;
|
||||||
}
|
}
|
||||||
|
|
||||||
Rcpp::List profiling;
|
// MPI_Barrier(MPI_COMM_WORLD);
|
||||||
profiling["simtime"] = dSimTime;
|
double end_t = MPI_Wtime();
|
||||||
profiling["chemistry"] = chem_profiling;
|
dSimTime += end_t - start_t;
|
||||||
profiling["diffusion"] = diffusion_profiling;
|
R["totaltime"] = dSimTime;
|
||||||
|
|
||||||
chem.MasterLoopBreak();
|
// MDL master_iteration_end just writes on disk state_T and
|
||||||
|
// state_C after every iteration if the cmdline option
|
||||||
|
// --ignore-results is not given (and thus the R variable
|
||||||
|
// store_result is TRUE)
|
||||||
|
call_master_iter_end(R, diffusion.getField(), chem.getField());
|
||||||
|
|
||||||
return profiling;
|
// TODO: write checkpoint
|
||||||
|
// checkpoint struct --> field and iteration
|
||||||
|
|
||||||
|
diffusion.getField().update(chem.getField());
|
||||||
|
|
||||||
|
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" +
|
||||||
|
std::to_string(maxiter));
|
||||||
|
|
||||||
|
control.applyControlLogic(diffusion, iter);
|
||||||
|
// MSG();
|
||||||
|
} // END SIMULATION LOOP
|
||||||
|
|
||||||
|
std::cout << std::endl;
|
||||||
|
|
||||||
|
Rcpp::List chem_profiling;
|
||||||
|
chem_profiling["simtime"] = chem.GetChemistryTime();
|
||||||
|
chem_profiling["loop"] = chem.GetMasterLoopTime();
|
||||||
|
chem_profiling["sequential"] = chem.GetMasterSequentialTime();
|
||||||
|
chem_profiling["idle_master"] = chem.GetMasterIdleTime();
|
||||||
|
chem_profiling["idle_worker"] = Rcpp::wrap(chem.GetWorkerIdleTimings());
|
||||||
|
chem_profiling["phreeqc_time"] = Rcpp::wrap(chem.GetWorkerPhreeqcTimings());
|
||||||
|
|
||||||
|
Rcpp::List diffusion_profiling;
|
||||||
|
diffusion_profiling["simtime"] = diffusion.getTransportTime();
|
||||||
|
|
||||||
|
if (params.use_dht) {
|
||||||
|
chem_profiling["dht_hits"] = Rcpp::wrap(chem.GetWorkerDHTHits());
|
||||||
|
chem_profiling["dht_evictions"] = Rcpp::wrap(chem.GetWorkerDHTEvictions());
|
||||||
|
chem_profiling["dht_get_time"] = Rcpp::wrap(chem.GetWorkerDHTGetTimings());
|
||||||
|
chem_profiling["dht_fill_time"] =
|
||||||
|
Rcpp::wrap(chem.GetWorkerDHTFillTimings());
|
||||||
}
|
}
|
||||||
|
|
||||||
static void getControlCellIds(const vector<uint32_t> &ids, int root,
|
if (params.use_interp) {
|
||||||
MPI_Comm comm) {
|
chem_profiling["interp_w"] =
|
||||||
std::uint32_t n_ids = 0;
|
Rcpp::wrap(chem.GetWorkerInterpolationWriteTimings());
|
||||||
int rank;
|
chem_profiling["interp_r"] =
|
||||||
MPI_Comm_rank(comm, &rank);
|
Rcpp::wrap(chem.GetWorkerInterpolationReadTimings());
|
||||||
bool is_master = root == rank;
|
chem_profiling["interp_g"] =
|
||||||
|
Rcpp::wrap(chem.GetWorkerInterpolationGatherTimings());
|
||||||
if (is_master) {
|
chem_profiling["interp_fc"] =
|
||||||
n_ids = ids.size();
|
Rcpp::wrap(chem.GetWorkerInterpolationFunctionCallTimings());
|
||||||
}
|
chem_profiling["interp_calls"] =
|
||||||
// broadcast size of id vector
|
Rcpp::wrap(chem.GetWorkerInterpolationCalls());
|
||||||
MPI_Bcast(n_ids, 1, MPI_UINT32_T, root, comm);
|
chem_profiling["interp_cached"] = Rcpp::wrap(chem.GetWorkerPHTCacheHits());
|
||||||
|
|
||||||
// worker
|
|
||||||
if (!is_master) {
|
|
||||||
ids.resize(n_ids);
|
|
||||||
}
|
|
||||||
// broadcast control cell ids
|
|
||||||
if (n_ids > 0) {
|
|
||||||
MPI_Bcast(ids.data(), n_ids, MPI_UINT32_T, root, comm);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Rcpp::List profiling;
|
||||||
|
profiling["simtime"] = dSimTime;
|
||||||
|
profiling["chemistry"] = chem_profiling;
|
||||||
|
profiling["diffusion"] = diffusion_profiling;
|
||||||
|
|
||||||
std::vector<std::string> getSpeciesNames(const Field &&field, int root,
|
chem.MasterLoopBreak();
|
||||||
MPI_Comm comm) {
|
|
||||||
std::uint32_t n_elements;
|
|
||||||
std::uint32_t n_string_size;
|
|
||||||
|
|
||||||
int rank;
|
return profiling;
|
||||||
MPI_Comm_rank(comm, &rank);
|
}
|
||||||
|
|
||||||
const bool is_master = root == rank;
|
static void getControlCellIds(std::vector<std::uint32_t> &ids, int root,
|
||||||
|
MPI_Comm comm) {
|
||||||
|
std::uint32_t n_ids = 0;
|
||||||
|
int rank;
|
||||||
|
MPI_Comm_rank(comm, &rank);
|
||||||
|
bool is_master = root == rank;
|
||||||
|
|
||||||
// first, the master sends all the species names iterative
|
if (is_master) {
|
||||||
if (is_master) {
|
n_ids = ids.size();
|
||||||
n_elements = field.GetProps().size();
|
}
|
||||||
MPI_Bcast(&n_elements, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
|
// broadcast size of id vector
|
||||||
|
MPI_Bcast(&n_ids, 1, MPI_UINT32_T, root, comm);
|
||||||
|
|
||||||
for (std::uint32_t i = 0; i < n_elements; i++) {
|
// worker
|
||||||
n_string_size = field.GetProps()[i].size();
|
if (!is_master) {
|
||||||
MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
|
ids.resize(n_ids);
|
||||||
MPI_Bcast(const_cast<char *>(field.GetProps()[i].c_str()),
|
}
|
||||||
n_string_size, MPI_CHAR, root, MPI_COMM_WORLD);
|
// broadcast control cell ids
|
||||||
}
|
if (n_ids > 0) {
|
||||||
|
MPI_Bcast(ids.data(), n_ids, MPI_UINT32_T, root, comm);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return field.GetProps();
|
std::vector<std::string> getSpeciesNames(const Field &&field, int root,
|
||||||
}
|
MPI_Comm comm) {
|
||||||
|
std::uint32_t n_elements;
|
||||||
|
std::uint32_t n_string_size;
|
||||||
|
|
||||||
// now all the worker stuff
|
int rank;
|
||||||
MPI_Bcast(&n_elements, 1, MPI_UINT32_T, root, comm);
|
MPI_Comm_rank(comm, &rank);
|
||||||
|
|
||||||
std::vector<std::string> species_names_out(n_elements);
|
const bool is_master = root == rank;
|
||||||
|
|
||||||
|
// first, the master sends all the species names iterative
|
||||||
|
if (is_master) {
|
||||||
|
n_elements = field.GetProps().size();
|
||||||
|
MPI_Bcast(&n_elements, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
|
||||||
|
|
||||||
for (std::uint32_t i = 0; i < n_elements; i++) {
|
for (std::uint32_t i = 0; i < n_elements; i++) {
|
||||||
|
n_string_size = field.GetProps()[i].size();
|
||||||
MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
|
MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
|
||||||
|
MPI_Bcast(const_cast<char *>(field.GetProps()[i].c_str()), n_string_size,
|
||||||
char recv_buf[n_string_size];
|
MPI_CHAR, root, MPI_COMM_WORLD);
|
||||||
|
|
||||||
MPI_Bcast(recv_buf, n_string_size, MPI_CHAR, root, MPI_COMM_WORLD);
|
|
||||||
|
|
||||||
species_names_out[i] = std::string(recv_buf, n_string_size);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return species_names_out;
|
return field.GetProps();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::array<double, 2> getBaseTotals(Field && field, int root, MPI_Comm comm) {
|
// now all the worker stuff
|
||||||
std::array<double, 2> base_totals;
|
MPI_Bcast(&n_elements, 1, MPI_UINT32_T, root, comm);
|
||||||
|
|
||||||
int rank;
|
std::vector<std::string> species_names_out(n_elements);
|
||||||
MPI_Comm_rank(comm, &rank);
|
|
||||||
|
|
||||||
const bool is_master = root == rank;
|
for (std::uint32_t i = 0; i < n_elements; i++) {
|
||||||
|
MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
|
||||||
|
|
||||||
if (is_master) {
|
char recv_buf[n_string_size];
|
||||||
const auto h_col = field["H"];
|
|
||||||
const auto o_col = field["O"];
|
|
||||||
|
|
||||||
base_totals[0] = *std::min_element(h_col.begin(), h_col.end());
|
MPI_Bcast(recv_buf, n_string_size, MPI_CHAR, root, MPI_COMM_WORLD);
|
||||||
base_totals[1] = *std::min_element(o_col.begin(), o_col.end());
|
|
||||||
MPI_Bcast(base_totals.data(), 2, MPI_DOUBLE, root, MPI_COMM_WORLD);
|
|
||||||
return base_totals;
|
|
||||||
}
|
|
||||||
|
|
||||||
MPI_Bcast(base_totals.data(), 2, MPI_DOUBLE, root, comm);
|
species_names_out[i] = std::string(recv_buf, n_string_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
return species_names_out;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::array<double, 2> getBaseTotals(Field &&field, int root, MPI_Comm comm) {
|
||||||
|
std::array<double, 2> base_totals;
|
||||||
|
|
||||||
|
int rank;
|
||||||
|
MPI_Comm_rank(comm, &rank);
|
||||||
|
|
||||||
|
const bool is_master = root == rank;
|
||||||
|
|
||||||
|
if (is_master) {
|
||||||
|
const auto h_col = field["H"];
|
||||||
|
const auto o_col = field["O"];
|
||||||
|
|
||||||
|
base_totals[0] = *std::min_element(h_col.begin(), h_col.end());
|
||||||
|
base_totals[1] = *std::min_element(o_col.begin(), o_col.end());
|
||||||
|
MPI_Bcast(base_totals.data(), 2, MPI_DOUBLE, root, MPI_COMM_WORLD);
|
||||||
return base_totals;
|
return base_totals;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool getHasID(Field && field, int root, MPI_Comm comm) {
|
MPI_Bcast(base_totals.data(), 2, MPI_DOUBLE, root, comm);
|
||||||
bool has_id;
|
|
||||||
|
|
||||||
int rank;
|
return base_totals;
|
||||||
MPI_Comm_rank(comm, &rank);
|
}
|
||||||
|
|
||||||
const bool is_master = root == rank;
|
bool getHasID(Field &&field, int root, MPI_Comm comm) {
|
||||||
|
bool has_id;
|
||||||
|
|
||||||
if (is_master) {
|
int rank;
|
||||||
const auto ID_field = field["ID"];
|
MPI_Comm_rank(comm, &rank);
|
||||||
|
|
||||||
std::set<double> unique_IDs(ID_field.begin(), ID_field.end());
|
const bool is_master = root == rank;
|
||||||
|
|
||||||
has_id = unique_IDs.size() > 1;
|
if (is_master) {
|
||||||
|
const auto ID_field = field["ID"];
|
||||||
|
|
||||||
MPI_Bcast(&has_id, 1, MPI_C_BOOL, root, MPI_COMM_WORLD);
|
std::set<double> unique_IDs(ID_field.begin(), ID_field.end());
|
||||||
|
|
||||||
return has_id;
|
has_id = unique_IDs.size() > 1;
|
||||||
}
|
|
||||||
|
|
||||||
MPI_Bcast(&has_id, 1, MPI_C_BOOL, root, comm);
|
MPI_Bcast(&has_id, 1, MPI_C_BOOL, root, MPI_COMM_WORLD);
|
||||||
|
|
||||||
return has_id;
|
return has_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(int argc, char *argv[]) {
|
MPI_Bcast(&has_id, 1, MPI_C_BOOL, root, comm);
|
||||||
int world_size;
|
|
||||||
|
|
||||||
MPI_Init(&argc, &argv);
|
return has_id;
|
||||||
|
}
|
||||||
|
|
||||||
{
|
int main(int argc, char *argv[]) {
|
||||||
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
|
int world_size;
|
||||||
MPI_Comm_rank(MPI_COMM_WORLD, &MY_RANK);
|
|
||||||
|
|
||||||
RInsidePOET &R = RInsidePOET::getInstance();
|
MPI_Init(&argc, &argv);
|
||||||
|
|
||||||
if (MY_RANK == 0) {
|
{
|
||||||
MSG("Running POET version " + std::string(poet_version));
|
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
|
||||||
}
|
MPI_Comm_rank(MPI_COMM_WORLD, &MY_RANK);
|
||||||
|
|
||||||
init_global_functions(R);
|
RInsidePOET &R = RInsidePOET::getInstance();
|
||||||
|
|
||||||
RuntimeParameters run_params;
|
|
||||||
|
|
||||||
if (parseInitValues(argc, argv, run_params) != 0) {
|
|
||||||
MPI_Finalize();
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// switch (parseInitValues(argc, argv, run_params)) {
|
|
||||||
// case ParseRet::PARSER_ERROR:
|
|
||||||
// case ParseRet::PARSER_HELP:
|
|
||||||
// MPI_Finalize();
|
|
||||||
// return 0;
|
|
||||||
// case ParseRet::PARSER_OK:
|
|
||||||
// break;
|
|
||||||
// }
|
|
||||||
|
|
||||||
InitialList init_list(R);
|
|
||||||
init_list.importList(run_params.init_params, MY_RANK != 0);
|
|
||||||
|
|
||||||
MSG("RInside initialized on process " + std::to_string(MY_RANK));
|
|
||||||
|
|
||||||
std::cout << std::flush;
|
|
||||||
|
|
||||||
MPI_Barrier(MPI_COMM_WORLD);
|
|
||||||
|
|
||||||
ChemistryModule chemistry(run_params.work_package_size,
|
|
||||||
init_list.getChemistryInit(), MPI_COMM_WORLD);
|
|
||||||
ControlModule control;
|
|
||||||
chemistry.SetControlModule(&control);
|
|
||||||
control.setChemistryModule(&chemistry);
|
|
||||||
|
|
||||||
const ChemistryModule::SurrogateSetup surr_setup = {
|
|
||||||
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
|
||||||
getBaseTotals(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
|
||||||
getHasID(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
|
||||||
run_params.use_dht,
|
|
||||||
run_params.dht_size,
|
|
||||||
run_params.dht_snaps,
|
|
||||||
run_params.out_dir,
|
|
||||||
run_params.use_interp,
|
|
||||||
run_params.interp_bucket_entries,
|
|
||||||
run_params.interp_size,
|
|
||||||
run_params.interp_min_entries,
|
|
||||||
run_params.use_ai_surrogate};
|
|
||||||
|
|
||||||
chemistry.masterEnableSurrogates(surr_setup);
|
|
||||||
|
|
||||||
const ControlModule::ControlSetup ctrl_setup = {
|
|
||||||
run_params.out_dir, // added
|
|
||||||
run_params.checkpoint_interval, run_params.control_interval,
|
|
||||||
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
|
||||||
run_params.mape_threshold};
|
|
||||||
|
|
||||||
control.enableControlLogic(ctrl_setup);
|
|
||||||
|
|
||||||
if (MY_RANK > 0) {
|
|
||||||
chemistry.WorkerLoop();
|
|
||||||
} else {
|
|
||||||
// R.parseEvalQ("mysetup <- setup");
|
|
||||||
// // if (MY_RANK == 0) { // get timestep vector from
|
|
||||||
// // grid_init function ... //
|
|
||||||
|
|
||||||
*global_rt_setup = master_init_R(*global_rt_setup, run_params.out_dir,
|
|
||||||
init_list.getInitialGrid().asSEXP());
|
|
||||||
|
|
||||||
// MDL: store all parameters
|
|
||||||
// MSG("Calling R Function to store calling parameters");
|
|
||||||
// R.parseEvalQ("StoreSetup(setup=mysetup)");
|
|
||||||
R["out_ext"] = run_params.out_ext;
|
|
||||||
R["out_dir"] = run_params.out_dir;
|
|
||||||
|
|
||||||
if (run_params.use_ai_surrogate) {
|
|
||||||
/* Incorporate ai surrogate from R */
|
|
||||||
R.parseEvalQ(ai_surrogate_r_library);
|
|
||||||
/* Use dht species for model input and output */
|
|
||||||
R["ai_surrogate_species"] =
|
|
||||||
init_list.getChemistryInit().dht_species.getNames();
|
|
||||||
|
|
||||||
const std::string ai_surrogate_input_script =
|
|
||||||
init_list.getChemistryInit().ai_surrogate_input_script;
|
|
||||||
|
|
||||||
MSG("AI: sourcing user-provided script");
|
|
||||||
R.parseEvalQ(ai_surrogate_input_script);
|
|
||||||
|
|
||||||
MSG("AI: initialize AI model");
|
|
||||||
R.parseEval("model <- initiate_model()");
|
|
||||||
R.parseEval("gpu_info()");
|
|
||||||
}
|
|
||||||
|
|
||||||
MSG("Init done on process with rank " + std::to_string(MY_RANK));
|
|
||||||
|
|
||||||
// MPI_Barrier(MPI_COMM_WORLD);
|
|
||||||
|
|
||||||
DiffusionModule diffusion(init_list.getDiffusionInit(),
|
|
||||||
init_list.getInitialGrid());
|
|
||||||
|
|
||||||
chemistry.masterSetField(init_list.getInitialGrid());
|
|
||||||
|
|
||||||
Rcpp::List profiling =
|
|
||||||
RunMasterLoop(R, run_params, diffusion, chemistry, control);
|
|
||||||
|
|
||||||
MSG("finished simulation loop");
|
|
||||||
|
|
||||||
R["profiling"] = profiling;
|
|
||||||
R["setup"] = *global_rt_setup;
|
|
||||||
R["setup$out_ext"] = run_params.out_ext;
|
|
||||||
|
|
||||||
std::string r_vis_code;
|
|
||||||
r_vis_code = "SaveRObj(x = profiling, path = paste0(out_dir, "
|
|
||||||
"'/timings.', setup$out_ext));";
|
|
||||||
R.parseEval(r_vis_code);
|
|
||||||
|
|
||||||
MSG("Done! Results are stored as R objects into <" +
|
|
||||||
run_params.out_dir + "/timings." + run_params.out_ext);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
MSG("finished, cleanup of process " + std::to_string(MY_RANK));
|
|
||||||
|
|
||||||
MPI_Finalize();
|
|
||||||
|
|
||||||
if (MY_RANK == 0) {
|
if (MY_RANK == 0) {
|
||||||
MSG("done, bye!");
|
MSG("Running POET version " + std::string(poet_version));
|
||||||
}
|
}
|
||||||
|
|
||||||
exit(0);
|
init_global_functions(R);
|
||||||
|
|
||||||
|
RuntimeParameters run_params;
|
||||||
|
|
||||||
|
if (parseInitValues(argc, argv, run_params) != 0) {
|
||||||
|
MPI_Finalize();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// switch (parseInitValues(argc, argv, run_params)) {
|
||||||
|
// case ParseRet::PARSER_ERROR:
|
||||||
|
// case ParseRet::PARSER_HELP:
|
||||||
|
// MPI_Finalize();
|
||||||
|
// return 0;
|
||||||
|
// case ParseRet::PARSER_OK:
|
||||||
|
// break;
|
||||||
|
// }
|
||||||
|
|
||||||
|
InitialList init_list(R);
|
||||||
|
init_list.importList(run_params.init_params, MY_RANK != 0);
|
||||||
|
|
||||||
|
MSG("RInside initialized on process " + std::to_string(MY_RANK));
|
||||||
|
|
||||||
|
std::cout << std::flush;
|
||||||
|
|
||||||
|
MPI_Barrier(MPI_COMM_WORLD);
|
||||||
|
|
||||||
|
ChemistryModule chemistry(run_params.work_package_size,
|
||||||
|
init_list.getChemistryInit(), MPI_COMM_WORLD);
|
||||||
|
ControlModule control;
|
||||||
|
chemistry.SetControlModule(&control);
|
||||||
|
control.setChemistryModule(&chemistry);
|
||||||
|
|
||||||
|
const ChemistryModule::SurrogateSetup surr_setup = {
|
||||||
|
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
||||||
|
getBaseTotals(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
||||||
|
getHasID(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
||||||
|
run_params.use_dht,
|
||||||
|
run_params.dht_size,
|
||||||
|
run_params.dht_snaps,
|
||||||
|
run_params.out_dir,
|
||||||
|
run_params.use_interp,
|
||||||
|
run_params.interp_bucket_entries,
|
||||||
|
run_params.interp_size,
|
||||||
|
run_params.interp_min_entries,
|
||||||
|
run_params.use_ai_surrogate};
|
||||||
|
|
||||||
|
chemistry.masterEnableSurrogates(surr_setup);
|
||||||
|
|
||||||
|
/* broadcast control cell ids before simulation starts */
|
||||||
|
getControlCellIds(run_params.ctrl_cell_ids, 0, MPI_COMM_WORLD);
|
||||||
|
chemistry.SetControlCellIds(run_params.ctrl_cell_ids);
|
||||||
|
|
||||||
|
const ControlModule::ControlSetup ctrl_setup = {
|
||||||
|
run_params.out_dir, // added
|
||||||
|
run_params.checkpoint_interval,
|
||||||
|
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
||||||
|
run_params.mape_threshold};
|
||||||
|
|
||||||
|
control.enableControlLogic(ctrl_setup);
|
||||||
|
|
||||||
|
if (MY_RANK > 0) {
|
||||||
|
chemistry.WorkerLoop();
|
||||||
|
} else {
|
||||||
|
// R.parseEvalQ("mysetup <- setup");
|
||||||
|
// // if (MY_RANK == 0) { // get timestep vector from
|
||||||
|
// // grid_init function ... //
|
||||||
|
|
||||||
|
*global_rt_setup = master_init_R(*global_rt_setup, run_params.out_dir,
|
||||||
|
init_list.getInitialGrid().asSEXP());
|
||||||
|
|
||||||
|
// MDL: store all parameters
|
||||||
|
// MSG("Calling R Function to store calling parameters");
|
||||||
|
// R.parseEvalQ("StoreSetup(setup=mysetup)");
|
||||||
|
R["out_ext"] = run_params.out_ext;
|
||||||
|
R["out_dir"] = run_params.out_dir;
|
||||||
|
|
||||||
|
if (run_params.use_ai_surrogate) {
|
||||||
|
/* Incorporate ai surrogate from R */
|
||||||
|
R.parseEvalQ(ai_surrogate_r_library);
|
||||||
|
/* Use dht species for model input and output */
|
||||||
|
R["ai_surrogate_species"] =
|
||||||
|
init_list.getChemistryInit().dht_species.getNames();
|
||||||
|
|
||||||
|
const std::string ai_surrogate_input_script =
|
||||||
|
init_list.getChemistryInit().ai_surrogate_input_script;
|
||||||
|
|
||||||
|
MSG("AI: sourcing user-provided script");
|
||||||
|
R.parseEvalQ(ai_surrogate_input_script);
|
||||||
|
|
||||||
|
MSG("AI: initialize AI model");
|
||||||
|
R.parseEval("model <- initiate_model()");
|
||||||
|
R.parseEval("gpu_info()");
|
||||||
|
}
|
||||||
|
|
||||||
|
MSG("Init done on process with rank " + std::to_string(MY_RANK));
|
||||||
|
|
||||||
|
// MPI_Barrier(MPI_COMM_WORLD);
|
||||||
|
|
||||||
|
DiffusionModule diffusion(init_list.getDiffusionInit(),
|
||||||
|
init_list.getInitialGrid());
|
||||||
|
|
||||||
|
chemistry.masterSetField(init_list.getInitialGrid());
|
||||||
|
|
||||||
|
Rcpp::List profiling =
|
||||||
|
RunMasterLoop(R, run_params, diffusion, chemistry, control);
|
||||||
|
|
||||||
|
MSG("finished simulation loop");
|
||||||
|
|
||||||
|
R["profiling"] = profiling;
|
||||||
|
R["setup"] = *global_rt_setup;
|
||||||
|
R["setup$out_ext"] = run_params.out_ext;
|
||||||
|
|
||||||
|
std::string r_vis_code;
|
||||||
|
r_vis_code = "SaveRObj(x = profiling, path = paste0(out_dir, "
|
||||||
|
"'/timings.', setup$out_ext));";
|
||||||
|
R.parseEval(r_vis_code);
|
||||||
|
|
||||||
|
MSG("Done! Results are stored as R objects into <" + run_params.out_dir +
|
||||||
|
"/timings." + run_params.out_ext);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MSG("finished, cleanup of process " + std::to_string(MY_RANK));
|
||||||
|
|
||||||
|
MPI_Finalize();
|
||||||
|
|
||||||
|
if (MY_RANK == 0) {
|
||||||
|
MSG("done, bye!");
|
||||||
|
}
|
||||||
|
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
|||||||
@ -54,7 +54,7 @@ struct RuntimeParameters {
|
|||||||
std::uint32_t checkpoint_interval = 0;
|
std::uint32_t checkpoint_interval = 0;
|
||||||
std::uint32_t control_interval = 0;
|
std::uint32_t control_interval = 0;
|
||||||
std::vector<double> mape_threshold;
|
std::vector<double> mape_threshold;
|
||||||
std::vector<double> ctrl_cell_ids;
|
std::vector<uint32t_t> ctrl_cell_ids;
|
||||||
|
|
||||||
static constexpr std::uint32_t WORK_PACKAGE_SIZE_DEFAULT = 32;
|
static constexpr std::uint32_t WORK_PACKAGE_SIZE_DEFAULT = 32;
|
||||||
std::uint32_t work_package_size = WORK_PACKAGE_SIZE_DEFAULT;
|
std::uint32_t work_package_size = WORK_PACKAGE_SIZE_DEFAULT;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user