mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-16 12:54:50 +01:00
Refractor control module for per-cell error tracking and stabilization phase logic
This commit is contained in:
parent
168cd64770
commit
39458561ff
135
bin/dolo_fgcs_3.R
Normal file
135
bin/dolo_fgcs_3.R
Normal file
@ -0,0 +1,135 @@
|
||||
rows <- 400
|
||||
cols <- 400
|
||||
|
||||
grid_def <- matrix(2, nrow = rows, ncol = cols)
|
||||
|
||||
# Define grid configuration for POET model
|
||||
grid_setup <- list(
|
||||
pqc_in_file = "./dolo_fgcs.pqi",
|
||||
pqc_db_file = "./phreeqc_kin.dat", # Path to the database file for Phreeqc
|
||||
grid_def = grid_def, # Definition of the grid, containing IDs according to the Phreeqc input script
|
||||
grid_size = c(5, 5), # Size of the grid in meters
|
||||
constant_cells = c() # IDs of cells with constant concentration
|
||||
)
|
||||
|
||||
bound_def_we <- list(
|
||||
"type" = rep("constant", rows),
|
||||
"sol_id" = rep(1, rows),
|
||||
"cell" = seq(1, rows)
|
||||
)
|
||||
|
||||
bound_def_ns <- list(
|
||||
"type" = rep("constant", cols),
|
||||
"sol_id" = rep(1, cols),
|
||||
"cell" = seq(1, cols)
|
||||
)
|
||||
|
||||
diffusion_setup <- list(
|
||||
boundaries = list(
|
||||
"W" = bound_def_we,
|
||||
"E" = bound_def_we,
|
||||
"N" = bound_def_ns,
|
||||
"S" = bound_def_ns
|
||||
),
|
||||
inner_boundaries = list(
|
||||
"row" = floor(rows / 2),
|
||||
"col" = floor(cols / 2),
|
||||
"sol_id" = c(3)
|
||||
),
|
||||
alpha_x = 1e-6,
|
||||
alpha_y = 1e-6
|
||||
)
|
||||
|
||||
check_sign_cal_dol_dht <- function(old, new) {
|
||||
if ((old["Calcite"] == 0) != (new["Calcite"] == 0)) {
|
||||
return(TRUE)
|
||||
}
|
||||
if ((old["Dolomite"] == 0) != (new["Dolomite"] == 0)) {
|
||||
return(TRUE)
|
||||
}
|
||||
return(FALSE)
|
||||
}
|
||||
|
||||
check_sign_cal_dol_interp <- function(to_interp, data_set) {
|
||||
dht_species <- c(
|
||||
"H" = 3,
|
||||
"O" = 3,
|
||||
"C" = 6,
|
||||
"Ca" = 6,
|
||||
"Cl" = 3,
|
||||
"Mg" = 5,
|
||||
"Calcite" = 4,
|
||||
"Dolomite" = 4
|
||||
)
|
||||
data_set <- as.data.frame(do.call(rbind, data_set), check.names = FALSE, optional = TRUE)
|
||||
names(data_set) <- names(dht_species)
|
||||
cal <- (data_set$Calcite == 0) == (to_interp["Calcite"] == 0)
|
||||
dol <- (data_set$Dolomite == 0) == (to_interp["Dolomite"] == 0)
|
||||
|
||||
cal_dol_same_sig <- cal == dol
|
||||
return(rev(which(!cal_dol_same_sig)))
|
||||
}
|
||||
|
||||
check_neg_cal_dol <- function(result) {
|
||||
neg_sign <- (result["Calcite"] < 0) || (result["Dolomite"] < 0)
|
||||
return(neg_sign)
|
||||
}
|
||||
|
||||
# Optional when using Interpolation (example with less key species and custom
|
||||
# significant digits)
|
||||
|
||||
pht_species <- c(
|
||||
"C" = 3,
|
||||
"Ca" = 3,
|
||||
"Mg" = 3,
|
||||
"Cl" = 3,
|
||||
"Calcite" = 3,
|
||||
"Dolomite" = 3
|
||||
)
|
||||
|
||||
|
||||
dht_species <- c(
|
||||
"H" = 3,
|
||||
"O" = 3,
|
||||
"C" = 6,
|
||||
"Ca" = 6,
|
||||
"Cl" = 3,
|
||||
"Mg" = 5,
|
||||
"Calcite" = 4,
|
||||
"Dolomite" = 4)
|
||||
|
||||
chemistry_setup <- list(
|
||||
dht_species = dht_species,
|
||||
pht_species = pht_species,
|
||||
hooks = list(
|
||||
dht_fill = check_sign_cal_dol_dht,
|
||||
interp_pre = check_sign_cal_dol_interp,
|
||||
interp_post = check_neg_cal_dol
|
||||
)
|
||||
)
|
||||
|
||||
# Define a setup list for simulation configuration
|
||||
setup <- list(
|
||||
Grid = grid_setup, # Parameters related to the grid structure
|
||||
Diffusion = diffusion_setup, # Parameters related to the diffusion process
|
||||
Chemistry = chemistry_setup # Parameters related to the chemistry process
|
||||
)
|
||||
|
||||
iterations <- 15000
|
||||
dt <- 200
|
||||
checkpoint_interval <- 100
|
||||
control_interval <- 100
|
||||
mape_threshold <- rep(0.1, 13)
|
||||
mape_threshold[5] <- 1 #Charge
|
||||
out_save <- seq(1000, iterations, by = 1000)
|
||||
#out_save = c(seq(1, 10), seq(10, 100, by= 10), seq(200, iterations, by=100))
|
||||
|
||||
|
||||
list(
|
||||
timesteps = rep(dt, iterations),
|
||||
store_result = TRUE,
|
||||
out_save = out_save,
|
||||
checkpoint_interval = checkpoint_interval,
|
||||
control_interval = control_interval,
|
||||
mape_threshold = mape_threshold
|
||||
)
|
||||
BIN
bin/dolo_fgcs_3.qs2
Normal file
BIN
bin/dolo_fgcs_3.qs2
Normal file
Binary file not shown.
1307
bin/phreeqc_kin.dat
Normal file
1307
bin/phreeqc_kin.dat
Normal file
File diff suppressed because it is too large
Load Diff
19
bin/run_poet.sh
Normal file
19
bin/run_poet.sh
Normal file
@ -0,0 +1,19 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=dolo_proto1_eps01_no_zeroabs
|
||||
#SBATCH --output=dolo_proto1_eps01_no_zeroabs_%j.out
|
||||
#SBATCH --error=dolo_proto1_eps01_no_zeroabs_%j.err
|
||||
#SBATCH --partition=long
|
||||
#SBATCH --nodes=6
|
||||
#SBATCH --ntasks-per-node=24
|
||||
#SBATCH --ntasks=144
|
||||
#SBATCH --exclusive
|
||||
#SBATCH --time=3-00:00:00
|
||||
|
||||
|
||||
source /etc/profile.d/modules.sh
|
||||
module purge
|
||||
module load cmake gcc openmpi
|
||||
|
||||
#mpirun -n 144 ./poet dolo_fgcs_3.R dolo_fgcs_3.qs2 dolo_only_pqc
|
||||
mpirun -n 144 ./poet --interp dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 dolo_proto1_eps01_no_zeroabs
|
||||
#mpirun -n 144 ./poet --interp barite_fgcs_4_new/barite_fgcs_4_new_rt.R barite_fgcs_4_new/barite_fgcs_4_new.qs2 barite
|
||||
Binary file not shown.
Binary file not shown.
@ -102,6 +102,7 @@ public:
|
||||
this->ai_surrogate_enabled = setup.ai_surrogate_enabled;
|
||||
|
||||
this->base_totals = setup.base_totals;
|
||||
this->ctrl_file_out_dir = setup.dht_out_dir;
|
||||
|
||||
if (this->dht_enabled || this->interp_enabled) {
|
||||
this->initializeDHT(setup.dht_size_mb, this->params.dht_species,
|
||||
@ -258,7 +259,7 @@ public:
|
||||
|
||||
std::vector<int> ai_surrogate_validity_vector;
|
||||
|
||||
void SetControlModule(poet::ControlModule *ctrl) { control_module = ctrl; }
|
||||
void SetControlModule(poet::ControlModule *ctrl) { control = ctrl; }
|
||||
|
||||
void SetDhtEnabled(bool enabled) { this->dht_enabled = enabled; }
|
||||
bool GetDhtEnabled() const { return this->dht_enabled; }
|
||||
@ -266,7 +267,8 @@ public:
|
||||
void SetInterpEnabled(bool enabled) { this->interp_enabled = enabled; }
|
||||
bool GetInterpEnabled() const { return interp_enabled; }
|
||||
|
||||
void SetWarmupEnabled(bool enabled) { this->warmup_enabled = enabled; }
|
||||
void SetStabEnabled(bool enabled) { this->stab_enabled = enabled; }
|
||||
bool GetStabEnabled() const { return stab_enabled; }
|
||||
|
||||
void SetControlCellIds(const std::vector<uint32_t> &ids) {
|
||||
this->ctrl_cell_ids = std::unordered_set<uint32_t>(ids.begin(), ids.end());
|
||||
@ -285,13 +287,13 @@ protected:
|
||||
|
||||
enum {
|
||||
CHEM_FIELD_INIT,
|
||||
// CHEM_DHT_ENABLE,
|
||||
CHEM_DHT_ENABLE,
|
||||
CHEM_IP_ENABLE,
|
||||
CHEM_CTRL_ENABLE,
|
||||
CHEM_CTRL_FLAGS,
|
||||
CHEM_DHT_SIGNIF_VEC,
|
||||
CHEM_DHT_SNAPS,
|
||||
CHEM_DHT_READ_FILE,
|
||||
// CHEM_WARMUP_PHASE, // Control flag
|
||||
// CHEM_CTRL_ENABLE, // Control flag
|
||||
// CHEM_IP_ENABLE,
|
||||
CHEM_IP_MIN_ENTRIES,
|
||||
CHEM_IP_SIGNIF_VEC,
|
||||
CHEM_WORK_LOOP,
|
||||
@ -300,6 +302,9 @@ protected:
|
||||
CHEM_AI_BCAST_VALIDITY
|
||||
};
|
||||
|
||||
/* broadcasted only every control iteration */
|
||||
enum { DHT_ENABLE = 1u << 0, IP_ENABLE = 1u << 1, STAB_ENABLE = 1u << 2 };
|
||||
|
||||
enum { LOOP_WORK, LOOP_END, LOOP_CTRL };
|
||||
|
||||
enum {
|
||||
@ -358,7 +363,7 @@ protected:
|
||||
|
||||
void WorkerDoWork(MPI_Status &probe_status, int double_count,
|
||||
struct worker_s &timings);
|
||||
void WorkerPostIter(MPI_Status &prope_status, uint32_t iteration);
|
||||
void WorkerPostIter(MPI_Status &probe_status, uint32_t iteration);
|
||||
void WorkerPostSim(uint32_t iteration);
|
||||
|
||||
void WorkerWriteDHTDump(uint32_t iteration);
|
||||
@ -370,10 +375,6 @@ protected:
|
||||
void WorkerRunWorkPackage(WorkPackage &work_package, double dSimTime,
|
||||
double dTimestep);
|
||||
|
||||
void ProcessControlWorkPackage(std::vector<std::vector<double>> &input,
|
||||
double current_sim_time, double dt,
|
||||
struct worker_s &timings);
|
||||
|
||||
std::vector<uint32_t> CalculateWPSizesVector(uint32_t n_cells,
|
||||
uint32_t wp_size) const;
|
||||
std::vector<double> shuffleField(const std::vector<double> &in_field,
|
||||
@ -388,9 +389,25 @@ protected:
|
||||
|
||||
void BCastStringVec(std::vector<std::string> &io);
|
||||
|
||||
int packResultsIntoBuffer(std::vector<double> &mpi_buffer, int base_count,
|
||||
const WorkPackage &wp,
|
||||
const WorkPackage &wp_control);
|
||||
void processCtrlPkgs(std::vector<std::vector<double>> &input,
|
||||
double current_sim_time, double dt,
|
||||
struct worker_s &timings);
|
||||
|
||||
void copyPkgs(const WorkPackage &wp, std::vector<double> &mpi_buffer);
|
||||
|
||||
inline int buildCtrlFlags(bool dht, bool interp, bool stab) {
|
||||
int flags = 0;
|
||||
|
||||
if (dht)
|
||||
flags |= DHT_ENABLE;
|
||||
if (interp)
|
||||
flags |= IP_ENABLE;
|
||||
if (stab)
|
||||
flags |= STAB_ENABLE;
|
||||
return flags;
|
||||
}
|
||||
|
||||
inline bool hasFlag(int flags, int type) { return (flags & type) != 0; }
|
||||
|
||||
int comm_size, comm_rank;
|
||||
MPI_Comm group_comm;
|
||||
@ -410,7 +427,7 @@ protected:
|
||||
|
||||
bool ai_surrogate_enabled{false};
|
||||
|
||||
static constexpr uint32_t BUFFER_OFFSET = 6;
|
||||
static constexpr uint32_t BUFFER_OFFSET = 5;
|
||||
|
||||
inline void ChemBCast(void *buf, int count, MPI_Datatype datatype) const {
|
||||
MPI_Bcast(buf, count, datatype, 0, this->group_comm);
|
||||
@ -447,13 +464,14 @@ protected:
|
||||
|
||||
std::unique_ptr<PhreeqcRunner> pqc_runner;
|
||||
|
||||
poet::ControlModule *control_module = nullptr;
|
||||
std::string ctrl_file_out_dir;
|
||||
|
||||
poet::ControlModule *control = nullptr;
|
||||
|
||||
std::vector<double> mpi_surr_buffer;
|
||||
|
||||
bool control_enabled{false};
|
||||
bool warmup_enabled{false};
|
||||
|
||||
bool stab_enabled{false};
|
||||
std::unordered_set<uint32_t> ctrl_cell_ids;
|
||||
std::vector<std::vector<double>> control_batch;
|
||||
};
|
||||
|
||||
@ -280,11 +280,13 @@ inline void poet::ChemistryModule::MasterSendPkgs(
|
||||
// current work package start location in field
|
||||
send_buffer[end_of_wp + 4] = wp_start_index;
|
||||
// control flags (bitmask)
|
||||
|
||||
/*
|
||||
int flags = (this->interp_enabled ? 1 : 0) | (this->dht_enabled ? 2 : 0) |
|
||||
(this->warmup_enabled ? 4 : 0) |
|
||||
(this->control_enabled ? 8 : 0);
|
||||
send_buffer[end_of_wp + 5] = static_cast<double>(flags);
|
||||
|
||||
*/
|
||||
/* ATTENTION Worker p has rank p+1 */
|
||||
// MPI_Send(send_buffer, end_of_wp + BUFFER_OFFSET, MPI_DOUBLE, p + 1,
|
||||
// LOOP_WORK, this->group_comm);
|
||||
@ -441,6 +443,14 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
||||
MPI_INT);
|
||||
}
|
||||
|
||||
if (control->shouldBcastFlags()) {
|
||||
int ftype = CHEM_CTRL_FLAGS;
|
||||
PropagateFunctionType(ftype);
|
||||
uint32_t ctrl_flags = buildCtrlFlags(
|
||||
this->dht_enabled, this->interp_enabled, this->stab_enabled);
|
||||
ChemBCast(&ctrl_flags, 1, MPI_INT);
|
||||
}
|
||||
|
||||
ftype = CHEM_WORK_LOOP;
|
||||
PropagateFunctionType(ftype);
|
||||
|
||||
@ -512,6 +522,9 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
||||
chem_field = out_vec;
|
||||
|
||||
/* do master stuff */
|
||||
std::cout << "[DEBUG] control_batch.size() = "
|
||||
<< this->control_batch.size() << std::endl;
|
||||
|
||||
if (!this->control_batch.empty()) {
|
||||
std::cout << "[Master] Processing " << this->control_batch.size()
|
||||
<< " control cells for comparison." << std::endl;
|
||||
@ -536,8 +549,11 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
||||
}
|
||||
|
||||
metrics_a = MPI_Wtime();
|
||||
control_module->computeSpeciesErrorMetrics(this->control_batch,
|
||||
surrogate_batch);
|
||||
control->computeErrorMetrics(this->control_batch, surrogate_batch,
|
||||
prop_names);
|
||||
|
||||
control->writeErrorMetrics(ctrl_file_out_dir, prop_names);
|
||||
|
||||
metrics_b = MPI_Wtime();
|
||||
this->metrics_t += metrics_b - metrics_a;
|
||||
|
||||
@ -553,12 +569,20 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
||||
this->seq_t += seq_d - seq_c;
|
||||
|
||||
/* end time measurement of whole chemistry simulation */
|
||||
std::optional<uint32_t> target = control->getRollbackTarget(prop_names);
|
||||
int flush = target.has_value() ? 1 : 0;
|
||||
|
||||
/* advise workers to end chemistry iteration */
|
||||
for (int i = 1; i < this->comm_size; i++) {
|
||||
MPI_Send(NULL, 0, MPI_DOUBLE, i, LOOP_END, this->group_comm);
|
||||
MPI_Send(&flush, 1, MPI_INT, i, LOOP_END, this->group_comm);
|
||||
}
|
||||
|
||||
/*
|
||||
if (flush) {
|
||||
control->clearFlushRequest();
|
||||
}
|
||||
*/
|
||||
|
||||
this->simtime += dt;
|
||||
iteration++;
|
||||
}
|
||||
|
||||
@ -59,6 +59,15 @@ void poet::ChemistryModule::WorkerLoop() {
|
||||
MPI_INT, 0, this->group_comm);
|
||||
break;
|
||||
}
|
||||
case CHEM_CTRL_FLAGS: {
|
||||
int flags = 0;
|
||||
ChemBCast(&flags, 1, MPI_INT);
|
||||
this->dht_enabled = hasFlag(flags, DHT_ENABLE);
|
||||
this->interp_enabled = hasFlag(flags, IP_ENABLE);
|
||||
this->stab_enabled = hasFlag(flags, STAB_ENABLE);
|
||||
break;
|
||||
}
|
||||
|
||||
case CHEM_WORK_LOOP: {
|
||||
WorkerProcessPkgs(timings, iteration);
|
||||
break;
|
||||
@ -116,7 +125,16 @@ void poet::ChemistryModule::WorkerProcessPkgs(struct worker_s &timings,
|
||||
}
|
||||
}
|
||||
|
||||
void poet::ChemistryModule::ProcessControlWorkPackage(
|
||||
void poet::ChemistryModule::copyPkgs(const WorkPackage &wp,
|
||||
std::vector<double> &mpi_buffer) {
|
||||
|
||||
for (std::size_t wp_i = 0; wp_i < wp.size; wp_i++) {
|
||||
std::copy(wp.output[wp_i].begin(), wp.output[wp_i].end(),
|
||||
mpi_buffer.begin() + this->prop_count * wp_i);
|
||||
}
|
||||
}
|
||||
|
||||
void poet::ChemistryModule::processCtrlPkgs(
|
||||
std::vector<std::vector<double>> &input, double current_sim_time, double dt,
|
||||
struct worker_s &timings) {
|
||||
|
||||
@ -137,10 +155,7 @@ void poet::ChemistryModule::ProcessControlWorkPackage(
|
||||
|
||||
timings.ctrl_phreeqc_t += phreeqc_end - phreeqc_start;
|
||||
|
||||
for (std::size_t wp_i = 0; wp_i < control_wp.size; wp_i++) {
|
||||
std::copy(control_wp.output[wp_i].begin(), control_wp.output[wp_i].end(),
|
||||
mpi_buffer.begin() + this->prop_count * wp_i);
|
||||
}
|
||||
copyPkgs(control_wp, mpi_buffer);
|
||||
|
||||
MPI_Request send_req;
|
||||
MPI_Isend(mpi_buffer.data(), mpi_buffer.size(), MPI_DOUBLE, 0, LOOP_CTRL,
|
||||
@ -194,13 +209,16 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||
wp_start_index = mpi_buffer[count + 4];
|
||||
|
||||
// read packed control flags
|
||||
/*
|
||||
flags = static_cast<int>(mpi_buffer[count + 5]);
|
||||
this->interp_enabled = (flags & 1) != 0;
|
||||
this->dht_enabled = (flags & 2) != 0;
|
||||
this->warmup_enabled = (flags & 4) != 0;
|
||||
this->control_enabled = (flags & 8) != 0;
|
||||
*/
|
||||
|
||||
/*std::cout << "warmup_enabled is " << warmup_enabled << ", control_enabled is "
|
||||
/*std::cout << "warmup_enabled is " << warmup_enabled << ", control_enabled is
|
||||
"
|
||||
<< control_enabled << ", dht_enabled is "
|
||||
<< dht_enabled << ", interp_enabled is " << interp_enabled
|
||||
<< std::endl;*/
|
||||
@ -212,7 +230,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||
}
|
||||
|
||||
// std::cout << this->comm_rank << ":" << counter++ << std::endl;
|
||||
if (dht_enabled || interp_enabled || warmup_enabled) {
|
||||
if (dht_enabled || interp_enabled || stab_enabled) {
|
||||
dht->prepareKeys(s_curr_wp.input, dt);
|
||||
}
|
||||
|
||||
@ -242,17 +260,18 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
|
||||
uint32_t cell_id = s_curr_wp.input[wp_i][0];
|
||||
|
||||
bool is_control_cell = this->ctrl_cell_ids.find(cell_id) != this->ctrl_cell_ids.end();
|
||||
bool used_surrogate = s_curr_wp.mapping[wp_i] != CHEM_PQC;
|
||||
bool is_ctrl_cell =
|
||||
this->ctrl_cell_ids.find(cell_id) != this->ctrl_cell_ids.end();
|
||||
bool used_sur = s_curr_wp.mapping[wp_i] != CHEM_PQC;
|
||||
|
||||
if (is_control_cell && used_surrogate) {
|
||||
if (is_ctrl_cell && used_sur) {
|
||||
|
||||
control_batch.push_back(s_curr_wp.input[wp_i]);
|
||||
control_cells_processed++;
|
||||
|
||||
if (control_batch.size() == s_curr_wp.size ||
|
||||
control_cells_processed == this->ctrl_cell_ids.size()) {
|
||||
ProcessControlWorkPackage(control_batch, current_sim_time, dt, timings);
|
||||
processCtrlPkgs(control_batch, current_sim_time, dt, timings);
|
||||
control_batch.clear();
|
||||
control_cells_processed = 0;
|
||||
}
|
||||
@ -265,23 +284,20 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||
|
||||
phreeqc_time_end = MPI_Wtime();
|
||||
|
||||
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
|
||||
std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(),
|
||||
mpi_buffer.begin() + this->prop_count * wp_i);
|
||||
}
|
||||
copyPkgs(s_curr_wp, mpi_buffer);
|
||||
|
||||
/* send results to master */
|
||||
MPI_Request send_req;
|
||||
MPI_Isend(mpi_buffer.data(), count, MPI_DOUBLE, 0, LOOP_WORK, MPI_COMM_WORLD,
|
||||
&send_req);
|
||||
|
||||
if (dht_enabled || interp_enabled || warmup_enabled) {
|
||||
if (dht_enabled || interp_enabled || stab_enabled) {
|
||||
/* write results to DHT */
|
||||
dht_fill_start = MPI_Wtime();
|
||||
dht->fillDHT(s_curr_wp);
|
||||
dht_fill_end = MPI_Wtime();
|
||||
|
||||
if (interp_enabled || warmup_enabled) {
|
||||
if (interp_enabled || stab_enabled) {
|
||||
interp->writePairs();
|
||||
}
|
||||
timings.dht_fill += dht_fill_end - dht_fill_start;
|
||||
@ -291,10 +307,18 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||
MPI_Wait(&send_req, MPI_STATUS_IGNORE);
|
||||
}
|
||||
|
||||
void poet::ChemistryModule::WorkerPostIter(MPI_Status &prope_status,
|
||||
void poet::ChemistryModule::WorkerPostIter(MPI_Status &probe_status,
|
||||
uint32_t iteration) {
|
||||
MPI_Recv(NULL, 0, MPI_DOUBLE, 0, LOOP_END, this->group_comm,
|
||||
MPI_STATUS_IGNORE);
|
||||
int size, flush_request = 0;
|
||||
MPI_Get_count(&probe_status, MPI_INT, &size);
|
||||
|
||||
if (size == 1) {
|
||||
MPI_Recv(&flush_request, size, MPI_INT, probe_status.MPI_SOURCE, LOOP_END,
|
||||
this->group_comm, MPI_STATUS_IGNORE);
|
||||
} else {
|
||||
MPI_Recv(NULL, 0, MPI_INT, probe_status.MPI_SOURCE, LOOP_END,
|
||||
this->group_comm, MPI_STATUS_IGNORE);
|
||||
}
|
||||
|
||||
if (this->dht_enabled) {
|
||||
dht_hits.push_back(dht->getHits());
|
||||
@ -320,7 +344,7 @@ void poet::ChemistryModule::WorkerPostIter(MPI_Status &prope_status,
|
||||
const auto max_mean_idx =
|
||||
DHT_get_used_idx_factor(this->interp->getDHTObject(), 1);
|
||||
|
||||
if (max_mean_idx >= 2) {
|
||||
if (max_mean_idx >= 2 || flush_request) {
|
||||
DHT_flush(this->interp->getDHTObject());
|
||||
DHT_flush(this->dht->getDHT());
|
||||
if (this->comm_rank == 2) {
|
||||
|
||||
@ -4,9 +4,13 @@
|
||||
#include "IO/StatsIO.hpp"
|
||||
#include <cmath>
|
||||
|
||||
void poet::ControlModule::updateControlIteration(const uint32_t &iter,
|
||||
const bool &dht_enabled,
|
||||
const bool &interp_enabled) {
|
||||
poet::ControlModule::ControlModule(const ControlConfig &config_)
|
||||
: config(config_) {}
|
||||
|
||||
void poet::ControlModule::beginIteration(ChemistryModule &chem,
|
||||
const uint32_t &iter,
|
||||
const bool &dht_enabled,
|
||||
const bool &interp_enabled) {
|
||||
|
||||
/* dht_enabled and inter_enabled are user settings set before startig the
|
||||
* simulation*/
|
||||
@ -14,81 +18,49 @@ void poet::ControlModule::updateControlIteration(const uint32_t &iter,
|
||||
|
||||
prep_a = MPI_Wtime();
|
||||
|
||||
/*
|
||||
if (control_interval == 0) {
|
||||
control_interval_enabled = false;
|
||||
return;
|
||||
}
|
||||
*/
|
||||
global_iteration = iter;
|
||||
initiateWarmupPhase(dht_enabled, interp_enabled);
|
||||
|
||||
/*
|
||||
control_interval_enabled =
|
||||
(control_interval > 0 && iter % control_interval == 0);
|
||||
|
||||
|
||||
if (control_interval_enabled) {
|
||||
MSG("[Control] Control interval enabled at iteration " +
|
||||
std::to_string(iter));
|
||||
}
|
||||
*/
|
||||
updateStabilizationPhase(chem, dht_enabled, interp_enabled);
|
||||
prep_b = MPI_Wtime();
|
||||
this->prep_t += prep_b - prep_a;
|
||||
}
|
||||
|
||||
void poet::ControlModule::initiateWarmupPhase(bool dht_enabled,
|
||||
bool interp_enabled) {
|
||||
|
||||
// user requested DHT/INTEP? keep them disabled but enable warmup-phase so
|
||||
if (global_iteration < stabilization_interval || rollback_enabled) {
|
||||
chem->SetWarmupEnabled(true);
|
||||
chem->SetDhtEnabled(false);
|
||||
chem->SetInterpEnabled(false);
|
||||
|
||||
MSG("Stabilization enabled until next control interval at iteration " +
|
||||
std::to_string(stabilization_interval) + ".");
|
||||
|
||||
if (sur_disabled_counter > 0) {
|
||||
--sur_disabled_counter;
|
||||
MSG("Rollback counter: " + std::to_string(sur_disabled_counter));
|
||||
void poet::ControlModule::updateStabilizationPhase(ChemistryModule &chem,
|
||||
bool dht_enabled,
|
||||
bool interp_enabled) {
|
||||
if (rollback_enabled) {
|
||||
if (disable_surr_counter > 0) {
|
||||
--disable_surr_counter;
|
||||
flush_request = false;
|
||||
MSG("Rollback counter: " + std::to_string(disable_surr_counter));
|
||||
} else {
|
||||
rollback_enabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
bool prev_stab_state = chem.GetStabEnabled();
|
||||
|
||||
// user requested DHT/INTEP? keep them disabled but enable warmup-phase so
|
||||
if (global_iteration <= config.stab_interval || rollback_enabled) {
|
||||
chem.SetStabEnabled(true);
|
||||
chem.SetDhtEnabled(false);
|
||||
chem.SetInterpEnabled(false);
|
||||
return;
|
||||
}
|
||||
|
||||
chem->SetWarmupEnabled(false);
|
||||
chem->SetDhtEnabled(dht_enabled);
|
||||
chem->SetInterpEnabled(interp_enabled);
|
||||
}
|
||||
chem.SetStabEnabled(false);
|
||||
chem.SetDhtEnabled(dht_enabled);
|
||||
chem.SetInterpEnabled(interp_enabled);
|
||||
|
||||
void poet::ControlModule::applyControlLogic(DiffusionModule &diffusion,
|
||||
uint32_t &iter) {
|
||||
|
||||
/*
|
||||
if (!control_interval_enabled) {
|
||||
return;
|
||||
}
|
||||
*/
|
||||
writeCheckpointAndMetrics(diffusion, iter);
|
||||
|
||||
if (checkAndRollback(diffusion, iter)) {
|
||||
rollback_enabled = true;
|
||||
rollback_count++;
|
||||
sur_disabled_counter = stabilization_interval;
|
||||
|
||||
MSG("Interpolation disabled for the next " +
|
||||
std::to_string(stabilization_interval) + ".");
|
||||
// Mark that we need to broadcast flags if stab phase just ended
|
||||
if (prev_stab_state && !chem.GetStabEnabled()) {
|
||||
stab_phase_ended = true;
|
||||
}
|
||||
}
|
||||
|
||||
void poet::ControlModule::writeCheckpointAndMetrics(DiffusionModule &diffusion,
|
||||
uint32_t iter) {
|
||||
|
||||
double w_check_a, w_check_b, stats_a, stats_b;
|
||||
MSG("Writing checkpoint of iteration " + std::to_string(iter));
|
||||
void poet::ControlModule::writeCheckpoint(DiffusionModule &diffusion,
|
||||
uint32_t &iter,
|
||||
const std::string &out_dir) {
|
||||
double w_check_a, w_check_b;
|
||||
|
||||
w_check_a = MPI_Wtime();
|
||||
write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5",
|
||||
@ -96,88 +68,107 @@ void poet::ControlModule::writeCheckpointAndMetrics(DiffusionModule &diffusion,
|
||||
w_check_b = MPI_Wtime();
|
||||
this->w_check_t += w_check_b - w_check_a;
|
||||
|
||||
last_checkpoint_written = iter;
|
||||
}
|
||||
|
||||
void poet::ControlModule::readCheckpoint(DiffusionModule &diffusion,
|
||||
uint32_t ¤t_iter,
|
||||
uint32_t rollback_iter,
|
||||
const std::string &out_dir) {
|
||||
double r_check_a, r_check_b;
|
||||
|
||||
r_check_a = MPI_Wtime();
|
||||
Checkpoint_s checkpoint_read{.field = diffusion.getField()};
|
||||
read_checkpoint(out_dir,
|
||||
"checkpoint" + std::to_string(rollback_iter) + ".hdf5",
|
||||
checkpoint_read);
|
||||
current_iter = checkpoint_read.iteration;
|
||||
r_check_b = MPI_Wtime();
|
||||
r_check_t += r_check_b - r_check_a;
|
||||
}
|
||||
|
||||
void poet::ControlModule::writeErrorMetrics(
|
||||
const std::string &out_dir, const std::vector<std::string> &species) {
|
||||
double stats_a, stats_b;
|
||||
|
||||
stats_a = MPI_Wtime();
|
||||
writeStatsToCSV(metricsHistory, species_names, out_dir, "stats_overview");
|
||||
writeStatsToCSV(metrics_history, species, out_dir, "metrics_overview");
|
||||
stats_b = MPI_Wtime();
|
||||
|
||||
this->stats_t += stats_b - stats_a;
|
||||
}
|
||||
|
||||
bool poet::ControlModule::checkAndRollback(DiffusionModule &diffusion,
|
||||
uint32_t &iter) {
|
||||
uint32_t poet::ControlModule::getRollbackIter() {
|
||||
|
||||
uint32_t last_iter = ((global_iteration - 1) / config.checkpoint_interval) *
|
||||
config.checkpoint_interval;
|
||||
|
||||
uint32_t rollback_iter = (last_iter <= last_checkpoint_written)
|
||||
? last_iter
|
||||
: last_checkpoint_written;
|
||||
return rollback_iter;
|
||||
}
|
||||
|
||||
std::optional<uint32_t> poet::ControlModule::getRollbackTarget(
|
||||
const std::vector<std::string> &species) {
|
||||
double r_check_a, r_check_b;
|
||||
|
||||
if (global_iteration < stabilization_interval) {
|
||||
return false;
|
||||
if (metrics_history.empty()) {
|
||||
MSG("No error history yet, skipping rollback check.");
|
||||
rollback_enabled = false;
|
||||
// flush_request = false;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (metricsHistory.empty()) {
|
||||
MSG("No error history yet; skipping rollback check.");
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto &mape = metricsHistory.back().mape;
|
||||
|
||||
const auto &mape = metrics_history.back().mape;
|
||||
for (size_t row = 0; row < mape.size(); row++) {
|
||||
for (size_t col = 0; col < species_names.size() && col < mape[row].size(); col++) {
|
||||
for (size_t col = 0; col < species.size() && col < mape[row].size();
|
||||
col++) {
|
||||
|
||||
if (mape[row][col] == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (mape[row][col] > mape_threshold[col]) {
|
||||
uint32_t rollback_iter =
|
||||
((iter - 1) / checkpoint_interval) * checkpoint_interval;
|
||||
if (mape[row][col] > config.mape_threshold[col]) {
|
||||
|
||||
MSG("[THRESHOLD EXCEEDED] " + species_names[col] +
|
||||
" has MAPE = " + std::to_string(mape[row][col]) +
|
||||
" exceeding threshold = " + std::to_string(mape_threshold[col]) +
|
||||
", rolling back to iteration " + std::to_string(rollback_iter));
|
||||
if (last_checkpoint_written == 0) {
|
||||
MSG(" Threshold exceeded but no checkpoint exists yet.");
|
||||
return std::nullopt;
|
||||
}
|
||||
rollback_enabled = true;
|
||||
flush_request = true;
|
||||
|
||||
r_check_a = MPI_Wtime();
|
||||
Checkpoint_s checkpoint_read{.field = diffusion.getField()};
|
||||
read_checkpoint(out_dir,
|
||||
"checkpoint" + std::to_string(rollback_iter) + ".hdf5",
|
||||
checkpoint_read);
|
||||
iter = checkpoint_read.iteration;
|
||||
r_check_b = MPI_Wtime();
|
||||
r_check_t += r_check_b - r_check_a;
|
||||
return true;
|
||||
MSG("Threshold exceeded " + species[col] + " has MAPE = " +
|
||||
std::to_string(mape[row][col]) + " exceeding threshold = " +
|
||||
std::to_string(config.mape_threshold[col]));
|
||||
|
||||
return getRollbackIter();
|
||||
}
|
||||
}
|
||||
}
|
||||
MSG("All species are within their MAPE thresholds.");
|
||||
|
||||
return false;
|
||||
rollback_enabled = false;
|
||||
// flush_request = false;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
void poet::ControlModule::computeSpeciesErrorMetrics(
|
||||
void poet::ControlModule::computeErrorMetrics(
|
||||
std::vector<std::vector<double>> &reference_values,
|
||||
std::vector<std::vector<double>> &surrogate_values) {
|
||||
std::vector<std::vector<double>> &surrogate_values,
|
||||
const std::vector<std::string> &species) {
|
||||
|
||||
const uint32_t num_cells = reference_values.size();
|
||||
const uint32_t species_count = this->species_names.size();
|
||||
const uint32_t n_cells = reference_values.size();
|
||||
|
||||
std::cout << "[DEBUG] computeSpeciesErrorMetrics: num_cells=" << num_cells
|
||||
<< ", species_count=" << species_count << std::endl;
|
||||
|
||||
SpeciesErrorMetrics metrics(num_cells, species_count, global_iteration,
|
||||
SpeciesErrorMetrics metrics(n_cells, species.size(), global_iteration,
|
||||
rollback_count);
|
||||
|
||||
if (reference_values.size() != surrogate_values.size()) {
|
||||
MSG(" Reference and surrogate vectors differ in size: " +
|
||||
std::to_string(reference_values.size()) + " vs " +
|
||||
std::to_string(surrogate_values.size()));
|
||||
return;
|
||||
}
|
||||
|
||||
for (size_t cell_i = 0; cell_i < num_cells; cell_i++) {
|
||||
for (size_t cell_i = 0; cell_i < n_cells; cell_i++) {
|
||||
|
||||
metrics.id.push_back(reference_values[cell_i][0]);
|
||||
|
||||
for (size_t sp_i = 0; sp_i < reference_values[cell_i].size(); sp_i++) {
|
||||
for (size_t sp_i = 0; sp_i < species.size(); sp_i++) {
|
||||
const double ref_value = reference_values[cell_i][sp_i];
|
||||
const double sur_value = surrogate_values[cell_i][sp_i];
|
||||
const double ZERO_ABS = 1e-13;
|
||||
const double ZERO_ABS = config.zero_abs;
|
||||
|
||||
if (std::isnan(ref_value) || std::isnan(sur_value)) {
|
||||
metrics.mape[cell_i][sp_i] = 0.0;
|
||||
@ -202,6 +193,43 @@ void poet::ControlModule::computeSpeciesErrorMetrics(
|
||||
}
|
||||
|
||||
std::cout << "[DEBUG] metrics.id.size()=" << metrics.id.size() << std::endl;
|
||||
metricsHistory.push_back(metrics);
|
||||
std::cout << "[DEBUG] metricsHistory.size()=" << metricsHistory.size() << std::endl;
|
||||
metrics_history.push_back(metrics);
|
||||
std::cout << "[DEBUG] metricsHistory.size()=" << metrics_history.size()
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
void poet::ControlModule::processCheckpoint(
|
||||
DiffusionModule &diffusion, uint32_t ¤t_iter,
|
||||
const std::string &out_dir, const std::vector<std::string> &species) {
|
||||
|
||||
if (flush_request) {
|
||||
uint32_t target = getRollbackIter();
|
||||
readCheckpoint(diffusion, current_iter, target, out_dir);
|
||||
|
||||
rollback_enabled = true;
|
||||
rollback_count++;
|
||||
disable_surr_counter = config.stab_interval;
|
||||
|
||||
MSG("Restored checkpoint " + std::to_string(target) +
|
||||
", surrogates disabled for " + std::to_string(config.stab_interval));
|
||||
} else {
|
||||
writeCheckpoint(diffusion, global_iteration, out_dir);
|
||||
}
|
||||
}
|
||||
|
||||
bool poet::ControlModule::shouldBcastFlags() {
|
||||
if (global_iteration == 1) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (stab_phase_ended) {
|
||||
stab_phase_ended = false;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (flush_request) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
@ -7,6 +7,7 @@
|
||||
#include "poet.hpp"
|
||||
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@ -15,107 +16,105 @@ namespace poet {
|
||||
class ChemistryModule;
|
||||
class DiffusionModule;
|
||||
|
||||
struct ControlConfig {
|
||||
uint32_t stab_interval = 0;
|
||||
uint32_t checkpoint_interval = 0;
|
||||
double zero_abs = 0.0;
|
||||
std::vector<double> mape_threshold;
|
||||
};
|
||||
|
||||
struct SpeciesErrorMetrics {
|
||||
std::vector<std::uint32_t> id;
|
||||
std::vector<std::vector<double>> mape;
|
||||
std::vector<std::vector<double>> rrmse;
|
||||
uint32_t iteration = 0;
|
||||
uint32_t rollback_count = 0;
|
||||
|
||||
SpeciesErrorMetrics(uint32_t n_cells, uint32_t n_species, uint32_t iter,
|
||||
uint32_t rb_count)
|
||||
: mape(n_cells, std::vector<double>(n_species, 0.0)),
|
||||
rrmse(n_cells, std::vector<double>(n_species, 0.0)), iteration(iter),
|
||||
rollback_count(rb_count) {}
|
||||
};
|
||||
|
||||
class ControlModule {
|
||||
|
||||
public:
|
||||
/* Control configuration*/
|
||||
explicit ControlModule(const ControlConfig &config);
|
||||
|
||||
// std::uint32_t global_iter = 0;
|
||||
// std::uint32_t sur_disabled_counter = 0;
|
||||
// std::uint32_t rollback_counter = 0;
|
||||
void beginIteration(ChemistryModule &chem, const uint32_t &iter,
|
||||
const bool &dht_enabled, const bool &interp_enaled);
|
||||
|
||||
void updateControlIteration(const uint32_t &iter, const bool &dht_enabled,
|
||||
const bool &interp_enaled);
|
||||
void writeErrorMetrics(const std::string &out_dir,
|
||||
const std::vector<std::string> &species);
|
||||
|
||||
void initiateWarmupPhase(bool dht_enabled, bool interp_enabled);
|
||||
std::optional<uint32_t> getRollbackTarget();
|
||||
|
||||
bool checkAndRollback(DiffusionModule &diffusion, uint32_t &iter);
|
||||
void computeErrorMetrics(std::vector<std::vector<double>> &reference_values,
|
||||
std::vector<std::vector<double>> &surrogate_values,
|
||||
const std::vector<std::string> &species);
|
||||
|
||||
struct SpeciesErrorMetrics {
|
||||
std::vector<std::uint32_t> id;
|
||||
std::vector<std::vector<double>> mape;
|
||||
std::vector<std::vector<double>> rrmse;
|
||||
uint32_t iteration; // iterations in simulation after rollbacks
|
||||
uint32_t rollback_count;
|
||||
void processCheckpoint(DiffusionModule &diffusion, uint32_t ¤t_iter,
|
||||
const std::string &out_dir,
|
||||
const std::vector<std::string> &species);
|
||||
|
||||
SpeciesErrorMetrics(uint32_t num_cells, uint32_t species_count,
|
||||
uint32_t iter, uint32_t counter)
|
||||
: mape(num_cells, std::vector<double>(species_count, 0.0)),
|
||||
rrmse(num_cells, std::vector<double>(species_count, 0.0)),
|
||||
iteration(iter), rollback_count(counter) {}
|
||||
};
|
||||
std::optional<uint32_t>
|
||||
getRollbackTarget(const std::vector<std::string> &species);
|
||||
|
||||
void computeSpeciesErrorMetrics(
|
||||
std::vector<std::vector<double>> &reference_values,
|
||||
std::vector<std::vector<double>> &surrogate_values);
|
||||
|
||||
std::vector<SpeciesErrorMetrics> metricsHistory;
|
||||
bool shouldBcastFlags();
|
||||
|
||||
struct ControlSetup {
|
||||
std::string out_dir;
|
||||
std::uint32_t checkpoint_interval;
|
||||
std::uint32_t penalty_interval;
|
||||
std::uint32_t stabilization_interval;
|
||||
std::vector<std::string> species_names;
|
||||
std::vector<double> mape_threshold;
|
||||
std::vector<uint32_t> ctrl_cell_ids;
|
||||
};
|
||||
|
||||
void enableControlLogic(const ControlSetup &setup) {
|
||||
this->out_dir = setup.out_dir;
|
||||
this->checkpoint_interval = setup.checkpoint_interval;
|
||||
this->stabilization_interval = setup.stabilization_interval;
|
||||
this->species_names = setup.species_names;
|
||||
this->mape_threshold = setup.mape_threshold;
|
||||
this->ctrl_cell_ids = setup.ctrl_cell_ids;
|
||||
}
|
||||
|
||||
void applyControlLogic(DiffusionModule &diffusion, uint32_t &iter);
|
||||
|
||||
void writeCheckpointAndMetrics(DiffusionModule &diffusion, uint32_t iter);
|
||||
bool getFlushRequest() const { return flush_request; }
|
||||
void clearFlushRequest() { flush_request = false; }
|
||||
|
||||
auto getGlobalIteration() const noexcept { return global_iteration; }
|
||||
|
||||
void setChemistryModule(poet::ChemistryModule *c) { chem = c; }
|
||||
// void setChemistryModule(poet::ChemistryModule *c) { chem = c; }
|
||||
|
||||
std::vector<double> getMapeThreshold() const { return this->mape_threshold; }
|
||||
std::vector<double> getMapeThreshold() const {
|
||||
return this->config.mape_threshold;
|
||||
}
|
||||
|
||||
std::vector<uint32_t> getCtrlCellIds() const { return this->ctrl_cell_ids; }
|
||||
|
||||
/* Profiling getters */
|
||||
|
||||
auto getUpdateCtrlLogicTime() const { return this->prep_t; }
|
||||
|
||||
auto getWriteCheckpointTime() const { return this->w_check_t; }
|
||||
|
||||
auto getReadCheckpointTime() const { return this->r_check_t; }
|
||||
|
||||
auto getWriteMetricsTime() const { return this->stats_t; }
|
||||
auto getUpdateCtrlLogicTime() const { return prep_t; }
|
||||
auto getWriteCheckpointTime() const { return w_check_t; }
|
||||
auto getReadCheckpointTime() const { return r_check_t; }
|
||||
auto getWriteMetricsTime() const { return stats_t; }
|
||||
|
||||
private:
|
||||
bool rollback_enabled = false;
|
||||
void updateStabilizationPhase(ChemistryModule &chem, bool dht_enabled,
|
||||
bool interp_enabled);
|
||||
|
||||
poet::ChemistryModule *chem = nullptr;
|
||||
void readCheckpoint(DiffusionModule &diffusion, uint32_t ¤t_iter,
|
||||
uint32_t rollback_iter, const std::string &out_dir);
|
||||
void writeCheckpoint(DiffusionModule &diffusion, uint32_t &iter,
|
||||
const std::string &out_dir);
|
||||
|
||||
uint32_t getRollbackIter();
|
||||
|
||||
ControlConfig config;
|
||||
|
||||
std::uint32_t stabilization_interval = 0;
|
||||
std::uint32_t penalty_interval = 0;
|
||||
std::uint32_t checkpoint_interval = 0;
|
||||
std::uint32_t global_iteration = 0;
|
||||
std::uint32_t rollback_count = 0;
|
||||
std::uint32_t sur_disabled_counter = 0;
|
||||
std::vector<double> mape_threshold;
|
||||
std::uint32_t disable_surr_counter = 0;
|
||||
std::vector<uint32_t> ctrl_cell_ids;
|
||||
std::uint32_t last_checkpoint_written = 0;
|
||||
std::uint32_t penalty_interval = 0;
|
||||
|
||||
std::vector<std::string> species_names;
|
||||
std::string out_dir;
|
||||
bool rollback_enabled = false;
|
||||
bool flush_request = false;
|
||||
bool stab_phase_ended = false;
|
||||
|
||||
bool bcast_flags = false;
|
||||
|
||||
std::vector<SpeciesErrorMetrics> metrics_history;
|
||||
|
||||
double prep_t = 0.;
|
||||
double r_check_t = 0.;
|
||||
double w_check_t = 0.;
|
||||
double stats_t = 0.;
|
||||
|
||||
/* Buffer for shuffled surrogate data */
|
||||
std::vector<double> sur_shuffled;
|
||||
};
|
||||
|
||||
} // namespace poet
|
||||
|
||||
@ -7,7 +7,7 @@
|
||||
|
||||
namespace poet
|
||||
{
|
||||
void writeStatsToCSV(const std::vector<ControlModule::SpeciesErrorMetrics> &all_stats,
|
||||
void writeStatsToCSV(const std::vector<SpeciesErrorMetrics> &all_stats,
|
||||
const std::vector<std::string> &species_names,
|
||||
const std::string &out_dir,
|
||||
const std::string &filename)
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
|
||||
namespace poet
|
||||
{
|
||||
void writeStatsToCSV(const std::vector<ControlModule::SpeciesErrorMetrics> &all_stats,
|
||||
void writeStatsToCSV(const std::vector<SpeciesErrorMetrics> &all_stats,
|
||||
const std::vector<std::string> &species_names,
|
||||
const std::string &out_dir,
|
||||
const std::string &filename);
|
||||
|
||||
34
src/poet.cpp
34
src/poet.cpp
@ -252,10 +252,10 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) {
|
||||
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("timesteps"));
|
||||
params.checkpoint_interval =
|
||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval"));
|
||||
params.stabilization_interval =
|
||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("stabilization_interval"));
|
||||
params.penalty_interval =
|
||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("penalty_interval"));
|
||||
params.stab_interval =
|
||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("stab_interval"));
|
||||
params.zero_abs =
|
||||
Rcpp::as<double>(global_rt_setup->operator[]("zero_abs"));
|
||||
params.mape_threshold = Rcpp::as<std::vector<double>>(
|
||||
global_rt_setup->operator[]("mape_threshold"));
|
||||
params.ctrl_cell_ids = Rcpp::as<std::vector<uint32_t>>(
|
||||
@ -305,7 +305,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||
|
||||
double dSimTime{0};
|
||||
for (uint32_t iter = 1; iter < maxiter + 1; iter++) {
|
||||
control.updateControlIteration(iter, params.use_dht, params.use_interp);
|
||||
control.beginIteration(chem, iter, params.use_dht, params.use_interp);
|
||||
|
||||
double start_t = MPI_Wtime();
|
||||
|
||||
@ -415,7 +415,9 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" +
|
||||
std::to_string(maxiter));
|
||||
|
||||
control.applyControlLogic(diffusion, iter);
|
||||
control.processCheckpoint(diffusion, iter, params.out_dir,
|
||||
chem.getField().GetProps());
|
||||
|
||||
// MSG();
|
||||
} // END SIMULATION LOOP
|
||||
|
||||
@ -622,9 +624,6 @@ int main(int argc, char *argv[]) {
|
||||
|
||||
ChemistryModule chemistry(run_params.work_package_size,
|
||||
init_list.getChemistryInit(), MPI_COMM_WORLD);
|
||||
ControlModule control;
|
||||
chemistry.SetControlModule(&control);
|
||||
control.setChemistryModule(&chemistry);
|
||||
|
||||
const ChemistryModule::SurrogateSetup surr_setup = {
|
||||
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
||||
@ -646,16 +645,6 @@ int main(int argc, char *argv[]) {
|
||||
getControlCellIds(run_params.ctrl_cell_ids, 0, MPI_COMM_WORLD);
|
||||
chemistry.SetControlCellIds(run_params.ctrl_cell_ids);
|
||||
|
||||
const ControlModule::ControlSetup ctrl_setup = {
|
||||
run_params.out_dir, // added
|
||||
run_params.checkpoint_interval,
|
||||
run_params.penalty_interval,
|
||||
run_params.stabilization_interval,
|
||||
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
||||
run_params.mape_threshold};
|
||||
|
||||
control.enableControlLogic(ctrl_setup);
|
||||
|
||||
if (MY_RANK > 0) {
|
||||
chemistry.WorkerLoop();
|
||||
} else {
|
||||
@ -697,7 +686,14 @@ int main(int argc, char *argv[]) {
|
||||
DiffusionModule diffusion(init_list.getDiffusionInit(),
|
||||
init_list.getInitialGrid());
|
||||
|
||||
ControlConfig config(run_params.stab_interval,
|
||||
run_params.checkpoint_interval, run_params.zero_abs,
|
||||
run_params.mape_threshold);
|
||||
|
||||
ControlModule control(config);
|
||||
|
||||
chemistry.masterSetField(init_list.getInitialGrid());
|
||||
chemistry.SetControlModule(&control);
|
||||
|
||||
Rcpp::List profiling =
|
||||
RunMasterLoop(R, run_params, diffusion, chemistry, control);
|
||||
|
||||
@ -50,12 +50,11 @@ struct RuntimeParameters {
|
||||
std::string out_ext;
|
||||
|
||||
bool print_progress = false;
|
||||
std::uint32_t penalty_interval = 0;
|
||||
std::uint32_t stabilization_interval = 0;
|
||||
std::uint32_t stab_interval = 0;
|
||||
std::uint32_t checkpoint_interval = 0;
|
||||
std::uint32_t control_interval = 0;
|
||||
double zero_abs = 0.0;
|
||||
std::vector<double> mape_threshold;
|
||||
std::vector<uint32t_t> ctrl_cell_ids;
|
||||
std::vector<uint32_t> ctrl_cell_ids;
|
||||
|
||||
static constexpr std::uint32_t WORK_PACKAGE_SIZE_DEFAULT = 32;
|
||||
std::uint32_t work_package_size = WORK_PACKAGE_SIZE_DEFAULT;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user