mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-13 03:18:23 +01:00
Compare commits
2 Commits
97076cb7cd
...
36b6f8d859
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
36b6f8d859 | ||
|
|
be5b42392a |
@ -4,6 +4,8 @@ SOLUTION 1
|
||||
temperature 25
|
||||
pH 7
|
||||
pe 4
|
||||
Mg 1e-12
|
||||
Cl 2e-12
|
||||
PURE 1
|
||||
Calcite 0.0 1
|
||||
END
|
||||
|
||||
@ -6,6 +6,7 @@ mape_threshold <- rep(0.0035, 13)
|
||||
mape_threshold[5] <- 1 #Charge
|
||||
zero_abs <- 1e-13
|
||||
rb_limit <- 3
|
||||
rb_interval_limit <- 200
|
||||
|
||||
#ctrl_cell_ids <- seq(0, (400*400)/2 - 1, by = 401)
|
||||
#out_save <- seq(500, iterations, by = 500)
|
||||
@ -20,5 +21,6 @@ list(
|
||||
ctrl_interval = ctrl_interval,
|
||||
mape_threshold = mape_threshold,
|
||||
zero_abs = zero_abs,
|
||||
rb_limit = rb_limit
|
||||
rb_limit = rb_limit,
|
||||
rb_interval_limit = rb_interval_limit
|
||||
)
|
||||
@ -1,7 +1,7 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=p1_eps0035_v2
|
||||
#SBATCH --output=p1_eps0035_v2_%j.out
|
||||
#SBATCH --error=p1_eps0035_v2_%j.err
|
||||
#SBATCH --job-name=p1_eps0035_200
|
||||
#SBATCH --output=p1_eps0035_200_%j.out
|
||||
#SBATCH --error=p1_eps0035_200_%j.err
|
||||
#SBATCH --partition=long
|
||||
#SBATCH --nodes=6
|
||||
#SBATCH --ntasks-per-node=24
|
||||
@ -15,5 +15,5 @@ module purge
|
||||
module load cmake gcc openmpi
|
||||
|
||||
#mpirun -n 144 ./poet dolo_fgcs_3.R dolo_fgcs_3.qs2 dolo_only_pqc
|
||||
mpirun -n 144 ./poet --interp dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 p1_eps0035_v2
|
||||
mpirun -n 144 ./poet --interp dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 p1_eps0035_200
|
||||
#mpirun -n 144 ./poet --interp barite_fgcs_4_new/barite_fgcs_4_new_rt.R barite_fgcs_4_new/barite_fgcs_4_new.qs2 barite
|
||||
@ -6,7 +6,7 @@
|
||||
namespace poet {
|
||||
|
||||
enum DHT_PROP_TYPES { DHT_TYPE_DEFAULT, DHT_TYPE_CHARGE, DHT_TYPE_TOTAL };
|
||||
enum CHEMISTRY_OUT_SOURCE { CHEM_PQC, CHEM_DHT, CHEM_INTERP, CHEM_AISURR };
|
||||
enum CHEMISTRY_OUT_SOURCE { CHEM_PQC, CHEM_DHT, CHEM_INTERP, CHEM_AISURR, CHEM_SKIP };
|
||||
|
||||
struct WorkPackage {
|
||||
std::size_t size;
|
||||
|
||||
@ -48,15 +48,14 @@ void poet::ChemistryModule::WorkerLoop() {
|
||||
case CHEM_FIELD_INIT: {
|
||||
ChemBCast(&this->prop_count, 1, MPI_UINT32_T);
|
||||
if (this->ai_surrogate_enabled) {
|
||||
this->ai_surrogate_validity_vector.resize(
|
||||
this->n_cells); // resize statt reserve?
|
||||
this->ai_surrogate_validity_vector.resize(this->n_cells); // resize statt reserve?
|
||||
}
|
||||
break;
|
||||
}
|
||||
case CHEM_AI_BCAST_VALIDITY: {
|
||||
// Receive the index vector of valid ai surrogate predictions
|
||||
MPI_Bcast(&this->ai_surrogate_validity_vector.front(), this->n_cells,
|
||||
MPI_INT, 0, this->group_comm);
|
||||
MPI_Bcast(&this->ai_surrogate_validity_vector.front(), this->n_cells, MPI_INT, 0,
|
||||
this->group_comm);
|
||||
break;
|
||||
}
|
||||
case CHEM_CTRL_ENABLE: {
|
||||
@ -130,17 +129,16 @@ void poet::ChemistryModule::WorkerProcessPkgs(struct worker_s &timings,
|
||||
}
|
||||
}
|
||||
void poet::ChemistryModule::copyPkgs(const WorkPackage &wp,
|
||||
std::vector<double> &mpi_buffer,
|
||||
std::size_t offset) {
|
||||
std::vector<double> &mpi_buffer,
|
||||
std::size_t offset) {
|
||||
for (std::size_t wp_i = 0; wp_i < wp.size; wp_i++) {
|
||||
std::copy(wp.output[wp_i].begin(), wp.output[wp_i].end(),
|
||||
mpi_buffer.begin() + offset + this->prop_count * wp_i);
|
||||
}
|
||||
}
|
||||
void poet::ChemistryModule::copyCtrlPkgs(const WorkPackage &pqc_wp,
|
||||
const WorkPackage &surr_wp,
|
||||
std::vector<double> &mpi_buffer,
|
||||
int &count) {
|
||||
const WorkPackage &surr_wp,
|
||||
std::vector<double> &mpi_buffer, int &count) {
|
||||
std::size_t wp_offset = surr_wp.size * this->prop_count;
|
||||
mpi_buffer.resize(count + wp_offset);
|
||||
|
||||
@ -162,8 +160,7 @@ void poet::ChemistryModule::copyCtrlPkgs(const WorkPackage &pqc_wp,
|
||||
count += wp_offset;
|
||||
}
|
||||
|
||||
void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||
int double_count,
|
||||
void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, int double_count,
|
||||
struct worker_s &timings) {
|
||||
static int counter = 1;
|
||||
|
||||
@ -180,6 +177,9 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||
int flags;
|
||||
std::vector<double> mpi_buffer(count);
|
||||
|
||||
const int CL_INDEX = 7;
|
||||
const double CL_THRESHOLD = 1e-10;
|
||||
|
||||
/* receive */
|
||||
MPI_Recv(mpi_buffer.data(), count, MPI_DOUBLE, 0, LOOP_WORK, this->group_comm,
|
||||
MPI_STATUS_IGNORE);
|
||||
@ -216,6 +216,16 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||
mpi_buffer.begin() + this->prop_count * (wp_i + 1));
|
||||
}
|
||||
|
||||
/* skip simulation of cells cells where Cl concentration is below threshold */
|
||||
/*
|
||||
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
|
||||
if (s_curr_wp.input[wp_i][CL_INDEX] < CL_THRESHOLD) {
|
||||
s_curr_wp.mapping[wp_i] = CHEM_SKIP;
|
||||
s_curr_wp.output[wp_i] = s_curr_wp.input[wp_i];
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
// std::cout << this->comm_rank << ":" << counter++ << std::endl;
|
||||
if (dht_enabled || interp_enabled || stab_enabled) {
|
||||
dht->prepareKeys(s_curr_wp.input, dt);
|
||||
@ -250,8 +260,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||
if (ctrl_enabled) {
|
||||
ctrl_cp_start = MPI_Wtime();
|
||||
for (std::size_t wp_i = 0; wp_i < s_curr_wp_control.size; wp_i++) {
|
||||
s_curr_wp_control.output[wp_i] =
|
||||
std::vector<double>(this->prop_count, 0.0);
|
||||
s_curr_wp_control.output[wp_i] = std::vector<double>(this->prop_count, 0.0);
|
||||
s_curr_wp_control.mapping[wp_i] = CHEM_PQC;
|
||||
}
|
||||
ctrl_cp_end = MPI_Wtime();
|
||||
@ -260,8 +269,8 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||
|
||||
phreeqc_time_start = MPI_Wtime();
|
||||
|
||||
WorkerRunWorkPackage(ctrl_enabled ? s_curr_wp_control : s_curr_wp,
|
||||
current_sim_time, dt);
|
||||
WorkerRunWorkPackage(ctrl_enabled ? s_curr_wp_control : s_curr_wp, current_sim_time,
|
||||
dt);
|
||||
|
||||
phreeqc_time_end = MPI_Wtime();
|
||||
|
||||
@ -278,8 +287,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||
MPI_Request send_req;
|
||||
|
||||
int mpi_tag = ctrl_enabled ? LOOP_CTRL : LOOP_WORK;
|
||||
MPI_Isend(mpi_buffer.data(), count, MPI_DOUBLE, 0, mpi_tag, MPI_COMM_WORLD,
|
||||
&send_req);
|
||||
MPI_Isend(mpi_buffer.data(), count, MPI_DOUBLE, 0, mpi_tag, MPI_COMM_WORLD, &send_req);
|
||||
|
||||
if (dht_enabled || interp_enabled || stab_enabled) {
|
||||
/* write results to DHT */
|
||||
@ -297,19 +305,18 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||
MPI_Wait(&send_req, MPI_STATUS_IGNORE);
|
||||
}
|
||||
|
||||
void poet::ChemistryModule::WorkerPostIter(MPI_Status &probe_status,
|
||||
uint32_t iteration) {
|
||||
void poet::ChemistryModule::WorkerPostIter(MPI_Status &probe_status, uint32_t iteration) {
|
||||
|
||||
int size, flush = 0;
|
||||
|
||||
MPI_Get_count(&probe_status, MPI_INT, &size);
|
||||
|
||||
if (size == 1) {
|
||||
MPI_Recv(&flush, size, MPI_INT, probe_status.MPI_SOURCE, LOOP_END,
|
||||
this->group_comm, MPI_STATUS_IGNORE);
|
||||
MPI_Recv(&flush, size, MPI_INT, probe_status.MPI_SOURCE, LOOP_END, this->group_comm,
|
||||
MPI_STATUS_IGNORE);
|
||||
} else {
|
||||
MPI_Recv(NULL, 0, MPI_INT, probe_status.MPI_SOURCE, LOOP_END,
|
||||
this->group_comm, MPI_STATUS_IGNORE);
|
||||
MPI_Recv(NULL, 0, MPI_INT, probe_status.MPI_SOURCE, LOOP_END, this->group_comm,
|
||||
MPI_STATUS_IGNORE);
|
||||
}
|
||||
|
||||
if (this->dht_enabled) {
|
||||
@ -333,8 +340,7 @@ void poet::ChemistryModule::WorkerPostIter(MPI_Status &probe_status,
|
||||
interp->dumpPHTState(out.str());
|
||||
}
|
||||
|
||||
const auto max_mean_idx =
|
||||
DHT_get_used_idx_factor(this->interp->getDHTObject(), 1);
|
||||
const auto max_mean_idx = DHT_get_used_idx_factor(this->interp->getDHTObject(), 1);
|
||||
|
||||
if (max_mean_idx >= 2 || flush) {
|
||||
DHT_flush(this->interp->getDHTObject());
|
||||
@ -366,21 +372,17 @@ void poet::ChemistryModule::WorkerWriteDHTDump(uint32_t iteration) {
|
||||
<< std::setw(this->file_pad) << iteration << ".dht";
|
||||
int res = dht->tableToFile(out.str().c_str());
|
||||
if (res != DHT_SUCCESS && this->comm_rank == 2)
|
||||
std::cerr
|
||||
<< "CPP: Worker: Error in writing current state of DHT to file.\n";
|
||||
std::cerr << "CPP: Worker: Error in writing current state of DHT to file.\n";
|
||||
else if (this->comm_rank == 2)
|
||||
std::cout << "CPP: Worker: Successfully written DHT to file " << out.str()
|
||||
<< "\n";
|
||||
std::cout << "CPP: Worker: Successfully written DHT to file " << out.str() << "\n";
|
||||
}
|
||||
|
||||
void poet::ChemistryModule::WorkerReadDHTDump(
|
||||
const std::string &dht_input_file) {
|
||||
void poet::ChemistryModule::WorkerReadDHTDump(const std::string &dht_input_file) {
|
||||
int res = dht->fileToTable((char *)dht_input_file.c_str());
|
||||
if (res != DHT_SUCCESS) {
|
||||
if (res == DHT_WRONG_FILE) {
|
||||
if (this->comm_rank == 1)
|
||||
std::cerr
|
||||
<< "CPP: Worker: Wrong file layout! Continue with empty DHT ...\n";
|
||||
std::cerr << "CPP: Worker: Wrong file layout! Continue with empty DHT ...\n";
|
||||
} else {
|
||||
if (this->comm_rank == 1)
|
||||
std::cerr << "CPP: Worker: Error in loading current state of DHT from "
|
||||
@ -394,8 +396,7 @@ void poet::ChemistryModule::WorkerReadDHTDump(
|
||||
}
|
||||
|
||||
void poet::ChemistryModule::WorkerRunWorkPackage(WorkPackage &work_package,
|
||||
double dSimTime,
|
||||
double dTimestep) {
|
||||
double dSimTime, double dTimestep) {
|
||||
|
||||
std::vector<std::vector<double>> inout_chem = work_package.input;
|
||||
std::vector<std::size_t> to_ignore;
|
||||
@ -406,8 +407,7 @@ void poet::ChemistryModule::WorkerRunWorkPackage(WorkPackage &work_package,
|
||||
}
|
||||
|
||||
// HACK: remove the first element (cell_id) before sending to phreeqc
|
||||
inout_chem[wp_id].erase(inout_chem[wp_id].begin(),
|
||||
inout_chem[wp_id].begin() + 1);
|
||||
inout_chem[wp_id].erase(inout_chem[wp_id].begin(), inout_chem[wp_id].begin() + 1);
|
||||
}
|
||||
|
||||
this->pqc_runner->run(inout_chem, dTimestep, to_ignore);
|
||||
@ -423,8 +423,7 @@ void poet::ChemistryModule::WorkerRunWorkPackage(WorkPackage &work_package,
|
||||
}
|
||||
}
|
||||
|
||||
void poet::ChemistryModule::WorkerPerfToMaster(int type,
|
||||
const struct worker_s &timings) {
|
||||
void poet::ChemistryModule::WorkerPerfToMaster(int type, const struct worker_s &timings) {
|
||||
switch (type) {
|
||||
case WORKER_PHREEQC: {
|
||||
MPI_Gather(&timings.phreeqc_t, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0,
|
||||
@ -432,13 +431,11 @@ void poet::ChemistryModule::WorkerPerfToMaster(int type,
|
||||
break;
|
||||
}
|
||||
case WORKER_CTRL_ITER: {
|
||||
MPI_Gather(&timings.ctrl_t, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0,
|
||||
this->group_comm);
|
||||
MPI_Gather(&timings.ctrl_t, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0, this->group_comm);
|
||||
break;
|
||||
}
|
||||
case WORKER_DHT_GET: {
|
||||
MPI_Gather(&timings.dht_get, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0,
|
||||
this->group_comm);
|
||||
MPI_Gather(&timings.dht_get, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0, this->group_comm);
|
||||
break;
|
||||
}
|
||||
case WORKER_DHT_FILL: {
|
||||
@ -447,8 +444,7 @@ void poet::ChemistryModule::WorkerPerfToMaster(int type,
|
||||
break;
|
||||
}
|
||||
case WORKER_IDLE: {
|
||||
MPI_Gather(&timings.idle_t, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0,
|
||||
this->group_comm);
|
||||
MPI_Gather(&timings.idle_t, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0, this->group_comm);
|
||||
break;
|
||||
}
|
||||
case WORKER_IP_WRITE: {
|
||||
@ -484,15 +480,14 @@ void poet::ChemistryModule::WorkerMetricsToMaster(int type) {
|
||||
|
||||
MPI_Comm &group_comm = this->group_comm;
|
||||
|
||||
auto reduce_and_send = [&worker_rank, &worker_comm, &group_comm](
|
||||
std::vector<std::uint32_t> &send_buffer, int tag) {
|
||||
auto reduce_and_send = [&worker_rank, &worker_comm,
|
||||
&group_comm](std::vector<std::uint32_t> &send_buffer, int tag) {
|
||||
std::vector<uint32_t> to_master(send_buffer.size());
|
||||
MPI_Reduce(send_buffer.data(), to_master.data(), send_buffer.size(),
|
||||
MPI_UINT32_T, MPI_SUM, 0, worker_comm);
|
||||
MPI_Reduce(send_buffer.data(), to_master.data(), send_buffer.size(), MPI_UINT32_T,
|
||||
MPI_SUM, 0, worker_comm);
|
||||
|
||||
if (worker_rank == 0) {
|
||||
MPI_Send(to_master.data(), to_master.size(), MPI_UINT32_T, 0, tag,
|
||||
group_comm);
|
||||
MPI_Send(to_master.data(), to_master.size(), MPI_UINT32_T, 0, tag, group_comm);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -33,7 +33,7 @@ void poet::ControlModule::beginIteration(const uint32_t &iter, const bool &dht_e
|
||||
void poet::ControlModule::updateSurrState(bool dht_enabled, bool interp_enabled) {
|
||||
|
||||
bool in_warmup = (global_iter <= config.ctrl_interval);
|
||||
bool rb_limit_reached = (rb_count >= config.rb_limit);
|
||||
bool rb_limit_reached = rbLimitReached();
|
||||
|
||||
if (rb_enabled && stab_countdown > 0 && !rb_limit_reached) {
|
||||
--stab_countdown;
|
||||
@ -54,9 +54,18 @@ void poet::ControlModule::updateSurrState(bool dht_enabled, bool interp_enabled)
|
||||
} else {
|
||||
std::cout << "In stabilization phase." << std::endl;
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if (rb_count > 0 && !rb_enabled && !in_warmup) {
|
||||
surr_active++;
|
||||
if (surr_active > config.rb_interval_limit) {
|
||||
surr_active = 0;
|
||||
rb_count -= 1;
|
||||
std::cout << "Rollback count reset to: " << rb_count << "." << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
/* enable user-requested surrogates */
|
||||
chem->SetStabEnabled(false);
|
||||
chem->SetDhtEnabled(dht_enabled);
|
||||
@ -80,7 +89,8 @@ void poet::ControlModule::readCheckpoint(uint32_t ¤t_iter, uint32_t rollba
|
||||
double r_check_a, r_check_b;
|
||||
r_check_a = MPI_Wtime();
|
||||
Checkpoint_s checkpoint_read{.field = chem->getField()};
|
||||
read_checkpoint(out_dir, "checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read);
|
||||
read_checkpoint(out_dir, "checkpoint" + std::to_string(rollback_iter) + ".hdf5",
|
||||
checkpoint_read);
|
||||
current_iter = checkpoint_read.iteration;
|
||||
r_check_b = MPI_Wtime();
|
||||
r_check_t += r_check_b - r_check_a;
|
||||
@ -102,20 +112,22 @@ void poet::ControlModule::writeMetrics(const std::string &out_dir,
|
||||
|
||||
uint32_t poet::ControlModule::calcRbIter() {
|
||||
|
||||
uint32_t last_iter = ((global_iter - 1) / config.chkpt_interval) * config.chkpt_interval;
|
||||
uint32_t last_iter =
|
||||
((global_iter - 1) / config.chkpt_interval) * config.chkpt_interval;
|
||||
|
||||
uint32_t rb_iter = (last_iter <= last_chkpt_written) ? last_iter : last_chkpt_written;
|
||||
return rb_iter;
|
||||
}
|
||||
|
||||
std::optional<uint32_t> poet::ControlModule::findRbTarget(const std::vector<std::string> &species) {
|
||||
std::optional<uint32_t>
|
||||
poet::ControlModule::findRbTarget(const std::vector<std::string> &species) {
|
||||
|
||||
if (metrics_history.empty()) {
|
||||
std::cout << "No error history yet, skipping rollback check." << std::endl;
|
||||
flush_request = false;
|
||||
return std::nullopt;
|
||||
}
|
||||
if (rb_count > config.rb_limit) {
|
||||
if (rbLimitReached()) {
|
||||
std::cout << "Rollback limit reached, skipping control logic." << std::endl;
|
||||
flush_request = false;
|
||||
return std::nullopt;
|
||||
@ -126,14 +138,19 @@ std::optional<uint32_t> poet::ControlModule::findRbTarget(const std::vector<std:
|
||||
|
||||
double r_check_a, r_check_b;
|
||||
const auto &mape = metrics_history.back().mape;
|
||||
for (uint32_t i = 0; i < species.size(); ++i) {
|
||||
for (uint32_t sp_idx = 0; sp_idx < species.size(); ++sp_idx) {
|
||||
|
||||
if (mape[i] == 0) {
|
||||
if (mape[sp_idx] == 0) {
|
||||
continue;
|
||||
}
|
||||
if (mape[i] > config.mape_threshold[i]) {
|
||||
std::cout << "Species " << species[i] << " MAPE=" << mape[i]
|
||||
<< " threshold=" << config.mape_threshold[i] << std::endl;
|
||||
/* skip Charge */
|
||||
if (sp_idx == 4) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (mape[sp_idx] > config.mape_threshold[sp_idx]) {
|
||||
std::cout << "Species " << species[sp_idx] << " MAPE=" << mape[sp_idx]
|
||||
<< " threshold=" << config.mape_threshold[sp_idx] << std::endl;
|
||||
|
||||
if (last_chkpt_written == 0) {
|
||||
std::cout << " Threshold exceeded but no checkpoint exists yet." << std::endl;
|
||||
@ -141,9 +158,10 @@ std::optional<uint32_t> poet::ControlModule::findRbTarget(const std::vector<std:
|
||||
}
|
||||
// rb_enabled = true;
|
||||
flush_request = true;
|
||||
std::cout << "Threshold exceeded " << species[i] << " has MAPE = " << std::to_string(mape[i])
|
||||
<< " exceeding threshold = " << std::to_string(config.mape_threshold[i])
|
||||
<< std::endl;
|
||||
std::cout << "Threshold exceeded " << species[sp_idx]
|
||||
<< " has MAPE = " << std::to_string(mape[sp_idx])
|
||||
<< " exceeding threshold = "
|
||||
<< std::to_string(config.mape_threshold[sp_idx]) << std::endl;
|
||||
return calcRbIter();
|
||||
}
|
||||
}
|
||||
@ -157,16 +175,16 @@ void poet::ControlModule::computeMetrics(const std::vector<double> &reference_va
|
||||
const uint32_t size_per_prop,
|
||||
const std::vector<std::string> &species) {
|
||||
|
||||
if (rb_count > config.rb_limit) {
|
||||
if (rbLimitReached()) {
|
||||
return;
|
||||
}
|
||||
|
||||
SpeciesMetrics metrics(species.size(), global_iter, rb_count);
|
||||
|
||||
for (uint32_t i = 0; i < species.size(); ++i) {
|
||||
for (uint32_t sp_idx = 0; sp_idx < species.size(); ++sp_idx) {
|
||||
double err_sum = 0.0;
|
||||
double sqr_err_sum = 0.0;
|
||||
uint32_t base_idx = i * size_per_prop;
|
||||
uint32_t base_idx = sp_idx * size_per_prop;
|
||||
|
||||
for (uint32_t j = 0; j < size_per_prop; ++j) {
|
||||
const double ref_value = reference_values[base_idx + j];
|
||||
@ -187,16 +205,17 @@ void poet::ControlModule::computeMetrics(const std::vector<double> &reference_va
|
||||
sqr_err_sum += alpha * alpha;
|
||||
}
|
||||
}
|
||||
metrics.mape[i] = 100.0 * (err_sum / size_per_prop);
|
||||
metrics.rrmse[i] = std::sqrt(sqr_err_sum / size_per_prop);
|
||||
metrics.mape[sp_idx] = 100.0 * (err_sum / size_per_prop);
|
||||
metrics.rrmse[sp_idx] = std::sqrt(sqr_err_sum / size_per_prop);
|
||||
}
|
||||
metrics_history.push_back(metrics);
|
||||
}
|
||||
|
||||
void poet::ControlModule::processCheckpoint(uint32_t ¤t_iter, const std::string &out_dir,
|
||||
void poet::ControlModule::processCheckpoint(uint32_t ¤t_iter,
|
||||
const std::string &out_dir,
|
||||
const std::vector<std::string> &species) {
|
||||
|
||||
if (!ctrl_active || rb_count > config.rb_limit) {
|
||||
if (!ctrl_active || rbLimitReached()) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -208,19 +227,20 @@ void poet::ControlModule::processCheckpoint(uint32_t ¤t_iter, const std::s
|
||||
rb_count++;
|
||||
stab_countdown = config.ctrl_interval;
|
||||
|
||||
std::cout << "Restored checkpoint " << std::to_string(target) << ", surrogates disabled for "
|
||||
<< config.ctrl_interval << std::endl;
|
||||
std::cout << "Restored checkpoint " << std::to_string(target)
|
||||
<< ", surrogates disabled for " << config.ctrl_interval << std::endl;
|
||||
} else {
|
||||
writeCheckpoint(global_iter, out_dir);
|
||||
}
|
||||
}
|
||||
|
||||
bool poet::ControlModule::needsFlagBcast() const {
|
||||
if (rb_count > config.rb_limit) {
|
||||
return false;
|
||||
}
|
||||
if (global_iter == 1 || global_iter % config.ctrl_interval == 1) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
return (config.rb_limit > 0) && !rbLimitReached();
|
||||
}
|
||||
|
||||
inline bool poet::ControlModule::rbLimitReached() const {
|
||||
/* rollback is completly disabled */
|
||||
if (config.rb_limit == 0)
|
||||
return false;
|
||||
return rb_count >= config.rb_limit;
|
||||
}
|
||||
|
||||
@ -17,6 +17,7 @@ struct ControlConfig {
|
||||
uint32_t ctrl_interval = 0;
|
||||
uint32_t chkpt_interval = 0;
|
||||
uint32_t rb_limit = 0;
|
||||
uint32_t rb_interval_limit = 0;
|
||||
double zero_abs = 0.0;
|
||||
std::vector<double> mape_threshold;
|
||||
};
|
||||
@ -28,8 +29,7 @@ struct SpeciesMetrics {
|
||||
uint32_t rb_count = 0;
|
||||
|
||||
SpeciesMetrics(uint32_t n_species, uint32_t iter, uint32_t count)
|
||||
: mape(n_species, 0.0), rrmse(n_species, 0.0), iteration(iter),
|
||||
rb_count(count) {}
|
||||
: mape(n_species, 0.0), rrmse(n_species, 0.0), iteration(iter), rb_count(count) {}
|
||||
};
|
||||
|
||||
class ControlModule {
|
||||
@ -41,27 +41,22 @@ public:
|
||||
|
||||
void writeCheckpoint(uint32_t &iter, const std::string &out_dir);
|
||||
|
||||
void writeMetrics(const std::string &out_dir,
|
||||
const std::vector<std::string> &species);
|
||||
void writeMetrics(const std::string &out_dir, const std::vector<std::string> &species);
|
||||
|
||||
std::optional<uint32_t> findRbTarget();
|
||||
|
||||
void computeMetrics(const std::vector<double> &reference_values,
|
||||
const std::vector<double> &surrogate_values,
|
||||
const uint32_t size_per_prop,
|
||||
const std::vector<std::string> &species);
|
||||
const std::vector<double> &surrogate_values,
|
||||
const uint32_t size_per_prop,
|
||||
const std::vector<std::string> &species);
|
||||
|
||||
void processCheckpoint(uint32_t ¤t_iter,
|
||||
const std::string &out_dir,
|
||||
void processCheckpoint(uint32_t ¤t_iter, const std::string &out_dir,
|
||||
const std::vector<std::string> &species);
|
||||
|
||||
std::optional<uint32_t>
|
||||
findRbTarget(const std::vector<std::string> &species);
|
||||
std::optional<uint32_t> findRbTarget(const std::vector<std::string> &species);
|
||||
|
||||
bool needsFlagBcast() const;
|
||||
bool isCtrlIntervalActive() const {
|
||||
return this->ctrl_active;
|
||||
}
|
||||
bool isCtrlIntervalActive() const { return this->ctrl_active; }
|
||||
|
||||
bool getFlushRequest() const { return flush_request; }
|
||||
void clearFlushRequest() { flush_request = false; }
|
||||
@ -76,17 +71,20 @@ public:
|
||||
private:
|
||||
void updateSurrState(bool dht_enabled, bool interp_enabled);
|
||||
|
||||
void readCheckpoint(uint32_t ¤t_iter,
|
||||
uint32_t rollback_iter, const std::string &out_dir);
|
||||
void readCheckpoint(uint32_t ¤t_iter, uint32_t rollback_iter,
|
||||
const std::string &out_dir);
|
||||
|
||||
uint32_t calcRbIter();
|
||||
|
||||
inline bool rbLimitReached() const;
|
||||
|
||||
ControlConfig config;
|
||||
ChemistryModule *chem = nullptr;
|
||||
|
||||
std::uint32_t global_iter = 0;
|
||||
std::uint32_t rb_count = 0;
|
||||
std::uint32_t stab_countdown = 0;
|
||||
std::uint32_t surr_active = 0;
|
||||
std::uint32_t last_chkpt_written = 0;
|
||||
|
||||
bool rb_enabled = false;
|
||||
|
||||
114
src/poet.cpp
114
src/poet.cpp
@ -99,9 +99,8 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) {
|
||||
"Print progress bar during chemical simulation");
|
||||
|
||||
/*Parse work package size*/
|
||||
app.add_option(
|
||||
"-w,--work-package-size", params.work_package_size,
|
||||
"Work package size to distribute to each worker for chemistry module")
|
||||
app.add_option("-w,--work-package-size", params.work_package_size,
|
||||
"Work package size to distribute to each worker for chemistry module")
|
||||
->check(CLI::PositiveNumber)
|
||||
->default_val(RuntimeParameters::WORK_PACKAGE_SIZE_DEFAULT);
|
||||
|
||||
@ -112,9 +111,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) {
|
||||
|
||||
// cout << "CPP: DHT is " << ( dht_enabled ? "ON" : "OFF" ) << '\n';
|
||||
|
||||
dht_group
|
||||
->add_option("--dht-size", params.dht_size,
|
||||
"DHT size per process in Megabyte")
|
||||
dht_group->add_option("--dht-size", params.dht_size, "DHT size per process in Megabyte")
|
||||
->check(CLI::PositiveNumber)
|
||||
->default_val(RuntimeParameters::DHT_SIZE_DEFAULT);
|
||||
// cout << "CPP: DHT size per process (Byte) = " << dht_size_per_process <<
|
||||
@ -140,9 +137,8 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) {
|
||||
->check(CLI::PositiveNumber)
|
||||
->default_val(RuntimeParameters::INTERP_MIN_ENTRIES_DEFAULT);
|
||||
interp_group
|
||||
->add_option(
|
||||
"--interp-bucket-entries", params.interp_bucket_entries,
|
||||
"Maximum number of entries in each bucket of the interpolation table")
|
||||
->add_option("--interp-bucket-entries", params.interp_bucket_entries,
|
||||
"Maximum number of entries in each bucket of the interpolation table")
|
||||
->check(CLI::PositiveNumber)
|
||||
->default_val(RuntimeParameters::INTERP_BUCKET_ENTRIES_DEFAULT);
|
||||
|
||||
@ -152,25 +148,21 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) {
|
||||
app.add_flag("--rds", params.as_rds,
|
||||
"Save output as .rds file instead of default .qs2");
|
||||
|
||||
app.add_flag("--qs", params.as_qs,
|
||||
"Save output as .qs file instead of default .qs2");
|
||||
app.add_flag("--qs", params.as_qs, "Save output as .qs file instead of default .qs2");
|
||||
|
||||
std::string init_file;
|
||||
std::string runtime_file;
|
||||
|
||||
app.add_option("runtime_file", runtime_file,
|
||||
"Runtime R script defining the simulation")
|
||||
app.add_option("runtime_file", runtime_file, "Runtime R script defining the simulation")
|
||||
->required()
|
||||
->check(CLI::ExistingFile);
|
||||
|
||||
app.add_option(
|
||||
"init_file", init_file,
|
||||
"Initial R script defining the simulation, produced by poet_init")
|
||||
app.add_option("init_file", init_file,
|
||||
"Initial R script defining the simulation, produced by poet_init")
|
||||
->required()
|
||||
->check(CLI::ExistingFile);
|
||||
|
||||
app.add_option("out_dir", params.out_dir,
|
||||
"Output directory of the simulation")
|
||||
app.add_option("out_dir", params.out_dir, "Output directory of the simulation")
|
||||
->required();
|
||||
|
||||
try {
|
||||
@ -202,8 +194,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) {
|
||||
// << simparams.dht_significant_digits);
|
||||
// MSG("DHT logarithm before rounding: "
|
||||
// << (simparams.dht_log ? "ON" : "OFF"));
|
||||
MSG("DHT size per process (Megabyte) = " +
|
||||
std::to_string(params.dht_size));
|
||||
MSG("DHT size per process (Megabyte) = " + std::to_string(params.dht_size));
|
||||
MSG("DHT save snapshots is " + BOOL_PRINT(params.dht_snaps));
|
||||
// MSG("DHT load file is " + chem_params.dht_file);
|
||||
}
|
||||
@ -212,8 +203,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) {
|
||||
MSG("PHT interpolation enabled: " + BOOL_PRINT(params.use_interp));
|
||||
MSG("PHT interp-size = " + std::to_string(params.interp_size));
|
||||
MSG("PHT interp-min = " + std::to_string(params.interp_min_entries));
|
||||
MSG("PHT interp-bucket-entries = " +
|
||||
std::to_string(params.interp_bucket_entries));
|
||||
MSG("PHT interp-bucket-entries = " + std::to_string(params.interp_bucket_entries));
|
||||
}
|
||||
}
|
||||
// chem_params.dht_outdir = out_dir;
|
||||
@ -253,10 +243,11 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) {
|
||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("ctrl_interval"));
|
||||
params.chkpt_interval =
|
||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("chkpt_interval"));
|
||||
params.rb_limit =
|
||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("rb_limit"));
|
||||
params.mape_threshold = Rcpp::as<std::vector<double>>(
|
||||
global_rt_setup->operator[]("mape_threshold"));
|
||||
params.rb_limit = Rcpp::as<uint32_t>(global_rt_setup->operator[]("rb_limit"));
|
||||
params.rb_interval_limit =
|
||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("rb_interval_limit"));
|
||||
params.mape_threshold =
|
||||
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("mape_threshold"));
|
||||
params.zero_abs = Rcpp::as<double>(global_rt_setup->operator[]("zero_abs"));
|
||||
} catch (const std::exception &e) {
|
||||
ERRMSG("Error while parsing R scripts: " + std::string(e.what()));
|
||||
@ -278,16 +269,15 @@ void call_master_iter_end(RInside &R, const Field &trans, const Field &chem) {
|
||||
R["TMP"] = Rcpp::wrap(chem.AsVector());
|
||||
R["TMP_PROPS"] = Rcpp::wrap(chem.GetProps());
|
||||
R.parseEval(std::string("state_C <- setNames(data.frame(matrix(TMP, nrow=" +
|
||||
std::to_string(chem.GetRequestedVecSize()) +
|
||||
")), TMP_PROPS)"));
|
||||
std::to_string(chem.GetRequestedVecSize()) + ")), TMP_PROPS)"));
|
||||
R["setup"] = *global_rt_setup;
|
||||
R.parseEval("setup <- master_iteration_end(setup, state_T, state_C)");
|
||||
*global_rt_setup = R["setup"];
|
||||
}
|
||||
|
||||
static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||
DiffusionModule &diffusion,
|
||||
ChemistryModule &chem, ControlModule &control) {
|
||||
DiffusionModule &diffusion, ChemistryModule &chem,
|
||||
ControlModule &control) {
|
||||
|
||||
/* Iteration Count is dynamic, retrieving value from R (is only needed by
|
||||
* master for the following loop) */
|
||||
@ -327,10 +317,9 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||
double ai_start_t = MPI_Wtime();
|
||||
// Save current values from the tug field as predictor for the ai step
|
||||
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
|
||||
R.parseEval(
|
||||
std::string("predictors <- setNames(data.frame(matrix(TMP, nrow=" +
|
||||
std::to_string(chem.getField().GetRequestedVecSize()) +
|
||||
")), TMP_PROPS)"));
|
||||
R.parseEval(std::string("predictors <- setNames(data.frame(matrix(TMP, nrow=" +
|
||||
std::to_string(chem.getField().GetRequestedVecSize()) +
|
||||
")), TMP_PROPS)"));
|
||||
R.parseEval("predictors <- predictors[ai_surrogate_species]");
|
||||
|
||||
// Apply preprocessing
|
||||
@ -339,8 +328,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||
|
||||
// Predict
|
||||
MSG("AI Prediction");
|
||||
R.parseEval(
|
||||
"aipreds_scaled <- prediction_step(model, predictors_scaled)");
|
||||
R.parseEval("aipreds_scaled <- prediction_step(model, predictors_scaled)");
|
||||
|
||||
// Apply postprocessing
|
||||
MSG("AI Postprocessing");
|
||||
@ -348,8 +336,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||
|
||||
// Validate prediction and write valid predictions to chem field
|
||||
MSG("AI Validation");
|
||||
R.parseEval(
|
||||
"validity_vector <- validate_predictions(predictors, aipreds)");
|
||||
R.parseEval("validity_vector <- validate_predictions(predictors, aipreds)");
|
||||
|
||||
MSG("AI Marking accepted");
|
||||
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
|
||||
@ -361,9 +348,8 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||
validity_vector)");
|
||||
|
||||
MSG("AI Set Field");
|
||||
Field predictions_field =
|
||||
Field(R.parseEval("nrow(predictors)"), RTempField,
|
||||
R.parseEval("colnames(predictors)"));
|
||||
Field predictions_field = Field(R.parseEval("nrow(predictors)"), RTempField,
|
||||
R.parseEval("colnames(predictors)"));
|
||||
|
||||
MSG("AI Update");
|
||||
chem.getField().update(predictions_field);
|
||||
@ -378,10 +364,9 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||
double ai_start_t = MPI_Wtime();
|
||||
|
||||
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
|
||||
R.parseEval(
|
||||
std::string("targets <- setNames(data.frame(matrix(TMP, nrow=" +
|
||||
std::to_string(chem.getField().GetRequestedVecSize()) +
|
||||
")), TMP_PROPS)"));
|
||||
R.parseEval(std::string("targets <- setNames(data.frame(matrix(TMP, nrow=" +
|
||||
std::to_string(chem.getField().GetRequestedVecSize()) +
|
||||
")), TMP_PROPS)"));
|
||||
R.parseEval("targets <- targets[ai_surrogate_species]");
|
||||
|
||||
// TODO: Check how to get the correct columns
|
||||
@ -414,8 +399,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||
std::to_string(maxiter));
|
||||
|
||||
if (control.isCtrlIntervalActive()) {
|
||||
control.processCheckpoint(iter, params.out_dir,
|
||||
chem.getField().GetProps());
|
||||
control.processCheckpoint(iter, params.out_dir, chem.getField().GetProps());
|
||||
control.writeMetrics(params.out_dir, chem.getField().GetProps());
|
||||
}
|
||||
// MSG();
|
||||
@ -452,16 +436,12 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||
//}
|
||||
|
||||
if (params.use_interp) {
|
||||
chem_profiling["interp_w"] =
|
||||
Rcpp::wrap(chem.GetWorkerInterpolationWriteTimings());
|
||||
chem_profiling["interp_r"] =
|
||||
Rcpp::wrap(chem.GetWorkerInterpolationReadTimings());
|
||||
chem_profiling["interp_g"] =
|
||||
Rcpp::wrap(chem.GetWorkerInterpolationGatherTimings());
|
||||
chem_profiling["interp_w"] = Rcpp::wrap(chem.GetWorkerInterpolationWriteTimings());
|
||||
chem_profiling["interp_r"] = Rcpp::wrap(chem.GetWorkerInterpolationReadTimings());
|
||||
chem_profiling["interp_g"] = Rcpp::wrap(chem.GetWorkerInterpolationGatherTimings());
|
||||
chem_profiling["interp_fc"] =
|
||||
Rcpp::wrap(chem.GetWorkerInterpolationFunctionCallTimings());
|
||||
chem_profiling["interp_calls"] =
|
||||
Rcpp::wrap(chem.GetWorkerInterpolationCalls());
|
||||
chem_profiling["interp_calls"] = Rcpp::wrap(chem.GetWorkerInterpolationCalls());
|
||||
chem_profiling["interp_cached"] = Rcpp::wrap(chem.GetWorkerPHTCacheHits());
|
||||
}
|
||||
|
||||
@ -476,8 +456,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||
return profiling;
|
||||
}
|
||||
|
||||
std::vector<std::string> getSpeciesNames(const Field &&field, int root,
|
||||
MPI_Comm comm) {
|
||||
std::vector<std::string> getSpeciesNames(const Field &&field, int root, MPI_Comm comm) {
|
||||
std::uint32_t n_elements;
|
||||
std::uint32_t n_string_size;
|
||||
|
||||
@ -494,8 +473,8 @@ std::vector<std::string> getSpeciesNames(const Field &&field, int root,
|
||||
for (std::uint32_t i = 0; i < n_elements; i++) {
|
||||
n_string_size = field.GetProps()[i].size();
|
||||
MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
|
||||
MPI_Bcast(const_cast<char *>(field.GetProps()[i].c_str()), n_string_size,
|
||||
MPI_CHAR, root, MPI_COMM_WORLD);
|
||||
MPI_Bcast(const_cast<char *>(field.GetProps()[i].c_str()), n_string_size, MPI_CHAR,
|
||||
root, MPI_COMM_WORLD);
|
||||
}
|
||||
|
||||
return field.GetProps();
|
||||
@ -609,8 +588,8 @@ int main(int argc, char *argv[]) {
|
||||
|
||||
MPI_Barrier(MPI_COMM_WORLD);
|
||||
|
||||
ChemistryModule chemistry(run_params.work_package_size,
|
||||
init_list.getChemistryInit(), MPI_COMM_WORLD);
|
||||
ChemistryModule chemistry(run_params.work_package_size, init_list.getChemistryInit(),
|
||||
MPI_COMM_WORLD);
|
||||
|
||||
// ControlModule control;
|
||||
// chemistry.SetControlModule(&control);
|
||||
@ -633,8 +612,8 @@ int main(int argc, char *argv[]) {
|
||||
chemistry.masterEnableSurrogates(surr_setup);
|
||||
|
||||
ControlConfig config(run_params.ctrl_interval, run_params.chkpt_interval,
|
||||
run_params.rb_limit, run_params.zero_abs,
|
||||
run_params.mape_threshold);
|
||||
run_params.rb_limit, run_params.rb_interval_limit,
|
||||
run_params.zero_abs, run_params.mape_threshold);
|
||||
|
||||
ControlModule control(config, &chemistry);
|
||||
|
||||
@ -660,8 +639,7 @@ int main(int argc, char *argv[]) {
|
||||
/* Incorporate ai surrogate from R */
|
||||
R.parseEvalQ(ai_surrogate_r_library);
|
||||
/* Use dht species for model input and output */
|
||||
R["ai_surrogate_species"] =
|
||||
init_list.getChemistryInit().dht_species.getNames();
|
||||
R["ai_surrogate_species"] = init_list.getChemistryInit().dht_species.getNames();
|
||||
|
||||
const std::string ai_surrogate_input_script =
|
||||
init_list.getChemistryInit().ai_surrogate_input_script;
|
||||
@ -678,13 +656,11 @@ int main(int argc, char *argv[]) {
|
||||
|
||||
// MPI_Barrier(MPI_COMM_WORLD);
|
||||
|
||||
DiffusionModule diffusion(init_list.getDiffusionInit(),
|
||||
init_list.getInitialGrid());
|
||||
DiffusionModule diffusion(init_list.getDiffusionInit(), init_list.getInitialGrid());
|
||||
|
||||
chemistry.masterSetField(init_list.getInitialGrid());
|
||||
|
||||
Rcpp::List profiling =
|
||||
RunMasterLoop(R, run_params, diffusion, chemistry, control);
|
||||
Rcpp::List profiling = RunMasterLoop(R, run_params, diffusion, chemistry, control);
|
||||
|
||||
MSG("finished simulation loop");
|
||||
|
||||
|
||||
@ -54,6 +54,7 @@ struct RuntimeParameters {
|
||||
std::uint32_t chkpt_interval = 0;
|
||||
std::uint32_t ctrl_interval = 0;
|
||||
std::uint32_t rb_limit = 0;
|
||||
std::uint32_t rb_interval_limit = 0;
|
||||
std::vector<double> mape_threshold;
|
||||
double zero_abs = 0.0;
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user