Compare commits

...

2 Commits

Author SHA1 Message Date
rastogi
36b6f8d859 Added Cl threshold in WorkerDoWork() 2025-12-09 16:54:04 +01:00
rastogi
be5b42392a feat: interval limit between rollbacks 2025-12-09 12:48:48 +01:00
9 changed files with 166 additions and 172 deletions

View File

@ -4,6 +4,8 @@ SOLUTION 1
temperature 25
pH 7
pe 4
Mg 1e-12
Cl 2e-12
PURE 1
Calcite 0.0 1
END

View File

@ -6,6 +6,7 @@ mape_threshold <- rep(0.0035, 13)
mape_threshold[5] <- 1 #Charge
zero_abs <- 1e-13
rb_limit <- 3
rb_interval_limit <- 200
#ctrl_cell_ids <- seq(0, (400*400)/2 - 1, by = 401)
#out_save <- seq(500, iterations, by = 500)
@ -20,5 +21,6 @@ list(
ctrl_interval = ctrl_interval,
mape_threshold = mape_threshold,
zero_abs = zero_abs,
rb_limit = rb_limit
rb_limit = rb_limit,
rb_interval_limit = rb_interval_limit
)

View File

@ -1,7 +1,7 @@
#!/bin/bash
#SBATCH --job-name=p1_eps0035_v2
#SBATCH --output=p1_eps0035_v2_%j.out
#SBATCH --error=p1_eps0035_v2_%j.err
#SBATCH --job-name=p1_eps0035_200
#SBATCH --output=p1_eps0035_200_%j.out
#SBATCH --error=p1_eps0035_200_%j.err
#SBATCH --partition=long
#SBATCH --nodes=6
#SBATCH --ntasks-per-node=24
@ -15,5 +15,5 @@ module purge
module load cmake gcc openmpi
#mpirun -n 144 ./poet dolo_fgcs_3.R dolo_fgcs_3.qs2 dolo_only_pqc
mpirun -n 144 ./poet --interp dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 p1_eps0035_v2
mpirun -n 144 ./poet --interp dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 p1_eps0035_200
#mpirun -n 144 ./poet --interp barite_fgcs_4_new/barite_fgcs_4_new_rt.R barite_fgcs_4_new/barite_fgcs_4_new.qs2 barite

View File

@ -6,7 +6,7 @@
namespace poet {
enum DHT_PROP_TYPES { DHT_TYPE_DEFAULT, DHT_TYPE_CHARGE, DHT_TYPE_TOTAL };
enum CHEMISTRY_OUT_SOURCE { CHEM_PQC, CHEM_DHT, CHEM_INTERP, CHEM_AISURR };
enum CHEMISTRY_OUT_SOURCE { CHEM_PQC, CHEM_DHT, CHEM_INTERP, CHEM_AISURR, CHEM_SKIP };
struct WorkPackage {
std::size_t size;

View File

@ -48,15 +48,14 @@ void poet::ChemistryModule::WorkerLoop() {
case CHEM_FIELD_INIT: {
ChemBCast(&this->prop_count, 1, MPI_UINT32_T);
if (this->ai_surrogate_enabled) {
this->ai_surrogate_validity_vector.resize(
this->n_cells); // resize statt reserve?
this->ai_surrogate_validity_vector.resize(this->n_cells); // resize statt reserve?
}
break;
}
case CHEM_AI_BCAST_VALIDITY: {
// Receive the index vector of valid ai surrogate predictions
MPI_Bcast(&this->ai_surrogate_validity_vector.front(), this->n_cells,
MPI_INT, 0, this->group_comm);
MPI_Bcast(&this->ai_surrogate_validity_vector.front(), this->n_cells, MPI_INT, 0,
this->group_comm);
break;
}
case CHEM_CTRL_ENABLE: {
@ -130,17 +129,16 @@ void poet::ChemistryModule::WorkerProcessPkgs(struct worker_s &timings,
}
}
void poet::ChemistryModule::copyPkgs(const WorkPackage &wp,
std::vector<double> &mpi_buffer,
std::size_t offset) {
std::vector<double> &mpi_buffer,
std::size_t offset) {
for (std::size_t wp_i = 0; wp_i < wp.size; wp_i++) {
std::copy(wp.output[wp_i].begin(), wp.output[wp_i].end(),
mpi_buffer.begin() + offset + this->prop_count * wp_i);
}
}
void poet::ChemistryModule::copyCtrlPkgs(const WorkPackage &pqc_wp,
const WorkPackage &surr_wp,
std::vector<double> &mpi_buffer,
int &count) {
const WorkPackage &surr_wp,
std::vector<double> &mpi_buffer, int &count) {
std::size_t wp_offset = surr_wp.size * this->prop_count;
mpi_buffer.resize(count + wp_offset);
@ -162,8 +160,7 @@ void poet::ChemistryModule::copyCtrlPkgs(const WorkPackage &pqc_wp,
count += wp_offset;
}
void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
int double_count,
void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, int double_count,
struct worker_s &timings) {
static int counter = 1;
@ -180,6 +177,9 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
int flags;
std::vector<double> mpi_buffer(count);
const int CL_INDEX = 7;
const double CL_THRESHOLD = 1e-10;
/* receive */
MPI_Recv(mpi_buffer.data(), count, MPI_DOUBLE, 0, LOOP_WORK, this->group_comm,
MPI_STATUS_IGNORE);
@ -216,6 +216,16 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
mpi_buffer.begin() + this->prop_count * (wp_i + 1));
}
/* skip simulation of cells where the Cl concentration is below the threshold */
/*
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
if (s_curr_wp.input[wp_i][CL_INDEX] < CL_THRESHOLD) {
s_curr_wp.mapping[wp_i] = CHEM_SKIP;
s_curr_wp.output[wp_i] = s_curr_wp.input[wp_i];
}
}
*/
// std::cout << this->comm_rank << ":" << counter++ << std::endl;
if (dht_enabled || interp_enabled || stab_enabled) {
dht->prepareKeys(s_curr_wp.input, dt);
@ -250,8 +260,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
if (ctrl_enabled) {
ctrl_cp_start = MPI_Wtime();
for (std::size_t wp_i = 0; wp_i < s_curr_wp_control.size; wp_i++) {
s_curr_wp_control.output[wp_i] =
std::vector<double>(this->prop_count, 0.0);
s_curr_wp_control.output[wp_i] = std::vector<double>(this->prop_count, 0.0);
s_curr_wp_control.mapping[wp_i] = CHEM_PQC;
}
ctrl_cp_end = MPI_Wtime();
@ -260,8 +269,8 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
phreeqc_time_start = MPI_Wtime();
WorkerRunWorkPackage(ctrl_enabled ? s_curr_wp_control : s_curr_wp,
current_sim_time, dt);
WorkerRunWorkPackage(ctrl_enabled ? s_curr_wp_control : s_curr_wp, current_sim_time,
dt);
phreeqc_time_end = MPI_Wtime();
@ -278,8 +287,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
MPI_Request send_req;
int mpi_tag = ctrl_enabled ? LOOP_CTRL : LOOP_WORK;
MPI_Isend(mpi_buffer.data(), count, MPI_DOUBLE, 0, mpi_tag, MPI_COMM_WORLD,
&send_req);
MPI_Isend(mpi_buffer.data(), count, MPI_DOUBLE, 0, mpi_tag, MPI_COMM_WORLD, &send_req);
if (dht_enabled || interp_enabled || stab_enabled) {
/* write results to DHT */
@ -297,19 +305,18 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
MPI_Wait(&send_req, MPI_STATUS_IGNORE);
}
void poet::ChemistryModule::WorkerPostIter(MPI_Status &probe_status,
uint32_t iteration) {
void poet::ChemistryModule::WorkerPostIter(MPI_Status &probe_status, uint32_t iteration) {
int size, flush = 0;
MPI_Get_count(&probe_status, MPI_INT, &size);
if (size == 1) {
MPI_Recv(&flush, size, MPI_INT, probe_status.MPI_SOURCE, LOOP_END,
this->group_comm, MPI_STATUS_IGNORE);
MPI_Recv(&flush, size, MPI_INT, probe_status.MPI_SOURCE, LOOP_END, this->group_comm,
MPI_STATUS_IGNORE);
} else {
MPI_Recv(NULL, 0, MPI_INT, probe_status.MPI_SOURCE, LOOP_END,
this->group_comm, MPI_STATUS_IGNORE);
MPI_Recv(NULL, 0, MPI_INT, probe_status.MPI_SOURCE, LOOP_END, this->group_comm,
MPI_STATUS_IGNORE);
}
if (this->dht_enabled) {
@ -333,8 +340,7 @@ void poet::ChemistryModule::WorkerPostIter(MPI_Status &probe_status,
interp->dumpPHTState(out.str());
}
const auto max_mean_idx =
DHT_get_used_idx_factor(this->interp->getDHTObject(), 1);
const auto max_mean_idx = DHT_get_used_idx_factor(this->interp->getDHTObject(), 1);
if (max_mean_idx >= 2 || flush) {
DHT_flush(this->interp->getDHTObject());
@ -366,21 +372,17 @@ void poet::ChemistryModule::WorkerWriteDHTDump(uint32_t iteration) {
<< std::setw(this->file_pad) << iteration << ".dht";
int res = dht->tableToFile(out.str().c_str());
if (res != DHT_SUCCESS && this->comm_rank == 2)
std::cerr
<< "CPP: Worker: Error in writing current state of DHT to file.\n";
std::cerr << "CPP: Worker: Error in writing current state of DHT to file.\n";
else if (this->comm_rank == 2)
std::cout << "CPP: Worker: Successfully written DHT to file " << out.str()
<< "\n";
std::cout << "CPP: Worker: Successfully written DHT to file " << out.str() << "\n";
}
void poet::ChemistryModule::WorkerReadDHTDump(
const std::string &dht_input_file) {
void poet::ChemistryModule::WorkerReadDHTDump(const std::string &dht_input_file) {
int res = dht->fileToTable((char *)dht_input_file.c_str());
if (res != DHT_SUCCESS) {
if (res == DHT_WRONG_FILE) {
if (this->comm_rank == 1)
std::cerr
<< "CPP: Worker: Wrong file layout! Continue with empty DHT ...\n";
std::cerr << "CPP: Worker: Wrong file layout! Continue with empty DHT ...\n";
} else {
if (this->comm_rank == 1)
std::cerr << "CPP: Worker: Error in loading current state of DHT from "
@ -394,8 +396,7 @@ void poet::ChemistryModule::WorkerReadDHTDump(
}
void poet::ChemistryModule::WorkerRunWorkPackage(WorkPackage &work_package,
double dSimTime,
double dTimestep) {
double dSimTime, double dTimestep) {
std::vector<std::vector<double>> inout_chem = work_package.input;
std::vector<std::size_t> to_ignore;
@ -406,8 +407,7 @@ void poet::ChemistryModule::WorkerRunWorkPackage(WorkPackage &work_package,
}
// HACK: remove the first element (cell_id) before sending to phreeqc
inout_chem[wp_id].erase(inout_chem[wp_id].begin(),
inout_chem[wp_id].begin() + 1);
inout_chem[wp_id].erase(inout_chem[wp_id].begin(), inout_chem[wp_id].begin() + 1);
}
this->pqc_runner->run(inout_chem, dTimestep, to_ignore);
@ -423,8 +423,7 @@ void poet::ChemistryModule::WorkerRunWorkPackage(WorkPackage &work_package,
}
}
void poet::ChemistryModule::WorkerPerfToMaster(int type,
const struct worker_s &timings) {
void poet::ChemistryModule::WorkerPerfToMaster(int type, const struct worker_s &timings) {
switch (type) {
case WORKER_PHREEQC: {
MPI_Gather(&timings.phreeqc_t, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0,
@ -432,13 +431,11 @@ void poet::ChemistryModule::WorkerPerfToMaster(int type,
break;
}
case WORKER_CTRL_ITER: {
MPI_Gather(&timings.ctrl_t, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0,
this->group_comm);
MPI_Gather(&timings.ctrl_t, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0, this->group_comm);
break;
}
case WORKER_DHT_GET: {
MPI_Gather(&timings.dht_get, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0,
this->group_comm);
MPI_Gather(&timings.dht_get, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0, this->group_comm);
break;
}
case WORKER_DHT_FILL: {
@ -447,8 +444,7 @@ void poet::ChemistryModule::WorkerPerfToMaster(int type,
break;
}
case WORKER_IDLE: {
MPI_Gather(&timings.idle_t, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0,
this->group_comm);
MPI_Gather(&timings.idle_t, 1, MPI_DOUBLE, NULL, 1, MPI_DOUBLE, 0, this->group_comm);
break;
}
case WORKER_IP_WRITE: {
@ -484,15 +480,14 @@ void poet::ChemistryModule::WorkerMetricsToMaster(int type) {
MPI_Comm &group_comm = this->group_comm;
auto reduce_and_send = [&worker_rank, &worker_comm, &group_comm](
std::vector<std::uint32_t> &send_buffer, int tag) {
auto reduce_and_send = [&worker_rank, &worker_comm,
&group_comm](std::vector<std::uint32_t> &send_buffer, int tag) {
std::vector<uint32_t> to_master(send_buffer.size());
MPI_Reduce(send_buffer.data(), to_master.data(), send_buffer.size(),
MPI_UINT32_T, MPI_SUM, 0, worker_comm);
MPI_Reduce(send_buffer.data(), to_master.data(), send_buffer.size(), MPI_UINT32_T,
MPI_SUM, 0, worker_comm);
if (worker_rank == 0) {
MPI_Send(to_master.data(), to_master.size(), MPI_UINT32_T, 0, tag,
group_comm);
MPI_Send(to_master.data(), to_master.size(), MPI_UINT32_T, 0, tag, group_comm);
}
};

View File

@ -33,7 +33,7 @@ void poet::ControlModule::beginIteration(const uint32_t &iter, const bool &dht_e
void poet::ControlModule::updateSurrState(bool dht_enabled, bool interp_enabled) {
bool in_warmup = (global_iter <= config.ctrl_interval);
bool rb_limit_reached = (rb_count >= config.rb_limit);
bool rb_limit_reached = rbLimitReached();
if (rb_enabled && stab_countdown > 0 && !rb_limit_reached) {
--stab_countdown;
@ -54,9 +54,18 @@ void poet::ControlModule::updateSurrState(bool dht_enabled, bool interp_enabled)
} else {
std::cout << "In stabilization phase." << std::endl;
}
return;
}
if (rb_count > 0 && !rb_enabled && !in_warmup) {
surr_active++;
if (surr_active > config.rb_interval_limit) {
surr_active = 0;
rb_count -= 1;
std::cout << "Rollback count reset to: " << rb_count << "." << std::endl;
}
}
/* enable user-requested surrogates */
chem->SetStabEnabled(false);
chem->SetDhtEnabled(dht_enabled);
@ -80,7 +89,8 @@ void poet::ControlModule::readCheckpoint(uint32_t &current_iter, uint32_t rollba
double r_check_a, r_check_b;
r_check_a = MPI_Wtime();
Checkpoint_s checkpoint_read{.field = chem->getField()};
read_checkpoint(out_dir, "checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read);
read_checkpoint(out_dir, "checkpoint" + std::to_string(rollback_iter) + ".hdf5",
checkpoint_read);
current_iter = checkpoint_read.iteration;
r_check_b = MPI_Wtime();
r_check_t += r_check_b - r_check_a;
@ -102,20 +112,22 @@ void poet::ControlModule::writeMetrics(const std::string &out_dir,
uint32_t poet::ControlModule::calcRbIter() {
uint32_t last_iter = ((global_iter - 1) / config.chkpt_interval) * config.chkpt_interval;
uint32_t last_iter =
((global_iter - 1) / config.chkpt_interval) * config.chkpt_interval;
uint32_t rb_iter = (last_iter <= last_chkpt_written) ? last_iter : last_chkpt_written;
return rb_iter;
}
std::optional<uint32_t> poet::ControlModule::findRbTarget(const std::vector<std::string> &species) {
std::optional<uint32_t>
poet::ControlModule::findRbTarget(const std::vector<std::string> &species) {
if (metrics_history.empty()) {
std::cout << "No error history yet, skipping rollback check." << std::endl;
flush_request = false;
return std::nullopt;
}
if (rb_count > config.rb_limit) {
if (rbLimitReached()) {
std::cout << "Rollback limit reached, skipping control logic." << std::endl;
flush_request = false;
return std::nullopt;
@ -126,14 +138,19 @@ std::optional<uint32_t> poet::ControlModule::findRbTarget(const std::vector<std:
double r_check_a, r_check_b;
const auto &mape = metrics_history.back().mape;
for (uint32_t i = 0; i < species.size(); ++i) {
for (uint32_t sp_idx = 0; sp_idx < species.size(); ++sp_idx) {
if (mape[i] == 0) {
if (mape[sp_idx] == 0) {
continue;
}
if (mape[i] > config.mape_threshold[i]) {
std::cout << "Species " << species[i] << " MAPE=" << mape[i]
<< " threshold=" << config.mape_threshold[i] << std::endl;
/* skip Charge */
if (sp_idx == 4) {
continue;
}
if (mape[sp_idx] > config.mape_threshold[sp_idx]) {
std::cout << "Species " << species[sp_idx] << " MAPE=" << mape[sp_idx]
<< " threshold=" << config.mape_threshold[sp_idx] << std::endl;
if (last_chkpt_written == 0) {
std::cout << " Threshold exceeded but no checkpoint exists yet." << std::endl;
@ -141,9 +158,10 @@ std::optional<uint32_t> poet::ControlModule::findRbTarget(const std::vector<std:
}
// rb_enabled = true;
flush_request = true;
std::cout << "Threshold exceeded " << species[i] << " has MAPE = " << std::to_string(mape[i])
<< " exceeding threshold = " << std::to_string(config.mape_threshold[i])
<< std::endl;
std::cout << "Threshold exceeded " << species[sp_idx]
<< " has MAPE = " << std::to_string(mape[sp_idx])
<< " exceeding threshold = "
<< std::to_string(config.mape_threshold[sp_idx]) << std::endl;
return calcRbIter();
}
}
@ -157,16 +175,16 @@ void poet::ControlModule::computeMetrics(const std::vector<double> &reference_va
const uint32_t size_per_prop,
const std::vector<std::string> &species) {
if (rb_count > config.rb_limit) {
if (rbLimitReached()) {
return;
}
SpeciesMetrics metrics(species.size(), global_iter, rb_count);
for (uint32_t i = 0; i < species.size(); ++i) {
for (uint32_t sp_idx = 0; sp_idx < species.size(); ++sp_idx) {
double err_sum = 0.0;
double sqr_err_sum = 0.0;
uint32_t base_idx = i * size_per_prop;
uint32_t base_idx = sp_idx * size_per_prop;
for (uint32_t j = 0; j < size_per_prop; ++j) {
const double ref_value = reference_values[base_idx + j];
@ -187,16 +205,17 @@ void poet::ControlModule::computeMetrics(const std::vector<double> &reference_va
sqr_err_sum += alpha * alpha;
}
}
metrics.mape[i] = 100.0 * (err_sum / size_per_prop);
metrics.rrmse[i] = std::sqrt(sqr_err_sum / size_per_prop);
metrics.mape[sp_idx] = 100.0 * (err_sum / size_per_prop);
metrics.rrmse[sp_idx] = std::sqrt(sqr_err_sum / size_per_prop);
}
metrics_history.push_back(metrics);
}
void poet::ControlModule::processCheckpoint(uint32_t &current_iter, const std::string &out_dir,
void poet::ControlModule::processCheckpoint(uint32_t &current_iter,
const std::string &out_dir,
const std::vector<std::string> &species) {
if (!ctrl_active || rb_count > config.rb_limit) {
if (!ctrl_active || rbLimitReached()) {
return;
}
@ -208,19 +227,20 @@ void poet::ControlModule::processCheckpoint(uint32_t &current_iter, const std::s
rb_count++;
stab_countdown = config.ctrl_interval;
std::cout << "Restored checkpoint " << std::to_string(target) << ", surrogates disabled for "
<< config.ctrl_interval << std::endl;
std::cout << "Restored checkpoint " << std::to_string(target)
<< ", surrogates disabled for " << config.ctrl_interval << std::endl;
} else {
writeCheckpoint(global_iter, out_dir);
}
}
bool poet::ControlModule::needsFlagBcast() const {
if (rb_count > config.rb_limit) {
return false;
}
if (global_iter == 1 || global_iter % config.ctrl_interval == 1) {
return true;
}
return false;
return (config.rb_limit > 0) && !rbLimitReached();
}
/**
 * @brief Tell whether the configured rollback limit has been hit.
 *
 * A limit of 0 (config.rb_limit == 0) is the "rollback is completely
 * disabled" case of the original comment; in that case the limit is never
 * reported as reached.
 *
 * NOTE(review): the function is marked `inline` but defined out of line, so
 * every caller has to live in the translation unit that contains this
 * definition -- confirm that no other source file calls it.
 *
 * @return true if config.rb_limit is non-zero and rb_count >= config.rb_limit,
 *         false otherwise.
 */
inline bool poet::ControlModule::rbLimitReached() const {
  const bool limit_configured = (config.rb_limit != 0);
  return limit_configured && (rb_count >= config.rb_limit);
}

View File

@ -17,6 +17,7 @@ struct ControlConfig {
uint32_t ctrl_interval = 0;
uint32_t chkpt_interval = 0;
uint32_t rb_limit = 0;
uint32_t rb_interval_limit = 0;
double zero_abs = 0.0;
std::vector<double> mape_threshold;
};
@ -28,8 +29,7 @@ struct SpeciesMetrics {
uint32_t rb_count = 0;
SpeciesMetrics(uint32_t n_species, uint32_t iter, uint32_t count)
: mape(n_species, 0.0), rrmse(n_species, 0.0), iteration(iter),
rb_count(count) {}
: mape(n_species, 0.0), rrmse(n_species, 0.0), iteration(iter), rb_count(count) {}
};
class ControlModule {
@ -41,27 +41,22 @@ public:
void writeCheckpoint(uint32_t &iter, const std::string &out_dir);
void writeMetrics(const std::string &out_dir,
const std::vector<std::string> &species);
void writeMetrics(const std::string &out_dir, const std::vector<std::string> &species);
std::optional<uint32_t> findRbTarget();
void computeMetrics(const std::vector<double> &reference_values,
const std::vector<double> &surrogate_values,
const uint32_t size_per_prop,
const std::vector<std::string> &species);
const std::vector<double> &surrogate_values,
const uint32_t size_per_prop,
const std::vector<std::string> &species);
void processCheckpoint(uint32_t &current_iter,
const std::string &out_dir,
void processCheckpoint(uint32_t &current_iter, const std::string &out_dir,
const std::vector<std::string> &species);
std::optional<uint32_t>
findRbTarget(const std::vector<std::string> &species);
std::optional<uint32_t> findRbTarget(const std::vector<std::string> &species);
bool needsFlagBcast() const;
bool isCtrlIntervalActive() const {
return this->ctrl_active;
}
bool isCtrlIntervalActive() const { return this->ctrl_active; }
bool getFlushRequest() const { return flush_request; }
void clearFlushRequest() { flush_request = false; }
@ -76,17 +71,20 @@ public:
private:
void updateSurrState(bool dht_enabled, bool interp_enabled);
void readCheckpoint(uint32_t &current_iter,
uint32_t rollback_iter, const std::string &out_dir);
void readCheckpoint(uint32_t &current_iter, uint32_t rollback_iter,
const std::string &out_dir);
uint32_t calcRbIter();
inline bool rbLimitReached() const;
ControlConfig config;
ChemistryModule *chem = nullptr;
std::uint32_t global_iter = 0;
std::uint32_t rb_count = 0;
std::uint32_t stab_countdown = 0;
std::uint32_t surr_active = 0;
std::uint32_t last_chkpt_written = 0;
bool rb_enabled = false;

View File

@ -99,9 +99,8 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
"Print progress bar during chemical simulation");
/*Parse work package size*/
app.add_option(
"-w,--work-package-size", params.work_package_size,
"Work package size to distribute to each worker for chemistry module")
app.add_option("-w,--work-package-size", params.work_package_size,
"Work package size to distribute to each worker for chemistry module")
->check(CLI::PositiveNumber)
->default_val(RuntimeParameters::WORK_PACKAGE_SIZE_DEFAULT);
@ -112,9 +111,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
// cout << "CPP: DHT is " << ( dht_enabled ? "ON" : "OFF" ) << '\n';
dht_group
->add_option("--dht-size", params.dht_size,
"DHT size per process in Megabyte")
dht_group->add_option("--dht-size", params.dht_size, "DHT size per process in Megabyte")
->check(CLI::PositiveNumber)
->default_val(RuntimeParameters::DHT_SIZE_DEFAULT);
// cout << "CPP: DHT size per process (Byte) = " << dht_size_per_process <<
@ -140,9 +137,8 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
->check(CLI::PositiveNumber)
->default_val(RuntimeParameters::INTERP_MIN_ENTRIES_DEFAULT);
interp_group
->add_option(
"--interp-bucket-entries", params.interp_bucket_entries,
"Maximum number of entries in each bucket of the interpolation table")
->add_option("--interp-bucket-entries", params.interp_bucket_entries,
"Maximum number of entries in each bucket of the interpolation table")
->check(CLI::PositiveNumber)
->default_val(RuntimeParameters::INTERP_BUCKET_ENTRIES_DEFAULT);
@ -152,25 +148,21 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
app.add_flag("--rds", params.as_rds,
"Save output as .rds file instead of default .qs2");
app.add_flag("--qs", params.as_qs,
"Save output as .qs file instead of default .qs2");
app.add_flag("--qs", params.as_qs, "Save output as .qs file instead of default .qs2");
std::string init_file;
std::string runtime_file;
app.add_option("runtime_file", runtime_file,
"Runtime R script defining the simulation")
app.add_option("runtime_file", runtime_file, "Runtime R script defining the simulation")
->required()
->check(CLI::ExistingFile);
app.add_option(
"init_file", init_file,
"Initial R script defining the simulation, produced by poet_init")
app.add_option("init_file", init_file,
"Initial R script defining the simulation, produced by poet_init")
->required()
->check(CLI::ExistingFile);
app.add_option("out_dir", params.out_dir,
"Output directory of the simulation")
app.add_option("out_dir", params.out_dir, "Output directory of the simulation")
->required();
try {
@ -202,8 +194,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
// << simparams.dht_significant_digits);
// MSG("DHT logarithm before rounding: "
// << (simparams.dht_log ? "ON" : "OFF"));
MSG("DHT size per process (Megabyte) = " +
std::to_string(params.dht_size));
MSG("DHT size per process (Megabyte) = " + std::to_string(params.dht_size));
MSG("DHT save snapshots is " + BOOL_PRINT(params.dht_snaps));
// MSG("DHT load file is " + chem_params.dht_file);
}
@ -212,8 +203,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
MSG("PHT interpolation enabled: " + BOOL_PRINT(params.use_interp));
MSG("PHT interp-size = " + std::to_string(params.interp_size));
MSG("PHT interp-min = " + std::to_string(params.interp_min_entries));
MSG("PHT interp-bucket-entries = " +
std::to_string(params.interp_bucket_entries));
MSG("PHT interp-bucket-entries = " + std::to_string(params.interp_bucket_entries));
}
}
// chem_params.dht_outdir = out_dir;
@ -253,10 +243,11 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
Rcpp::as<uint32_t>(global_rt_setup->operator[]("ctrl_interval"));
params.chkpt_interval =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("chkpt_interval"));
params.rb_limit =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("rb_limit"));
params.mape_threshold = Rcpp::as<std::vector<double>>(
global_rt_setup->operator[]("mape_threshold"));
params.rb_limit = Rcpp::as<uint32_t>(global_rt_setup->operator[]("rb_limit"));
params.rb_interval_limit =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("rb_interval_limit"));
params.mape_threshold =
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("mape_threshold"));
params.zero_abs = Rcpp::as<double>(global_rt_setup->operator[]("zero_abs"));
} catch (const std::exception &e) {
ERRMSG("Error while parsing R scripts: " + std::string(e.what()));
@ -278,16 +269,15 @@ void call_master_iter_end(RInside &R, const Field &trans, const Field &chem) {
R["TMP"] = Rcpp::wrap(chem.AsVector());
R["TMP_PROPS"] = Rcpp::wrap(chem.GetProps());
R.parseEval(std::string("state_C <- setNames(data.frame(matrix(TMP, nrow=" +
std::to_string(chem.GetRequestedVecSize()) +
")), TMP_PROPS)"));
std::to_string(chem.GetRequestedVecSize()) + ")), TMP_PROPS)"));
R["setup"] = *global_rt_setup;
R.parseEval("setup <- master_iteration_end(setup, state_T, state_C)");
*global_rt_setup = R["setup"];
}
static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
DiffusionModule &diffusion,
ChemistryModule &chem, ControlModule &control) {
DiffusionModule &diffusion, ChemistryModule &chem,
ControlModule &control) {
/* Iteration Count is dynamic, retrieving value from R (is only needed by
* master for the following loop) */
@ -327,10 +317,9 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
double ai_start_t = MPI_Wtime();
// Save current values from the tug field as predictor for the ai step
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
R.parseEval(
std::string("predictors <- setNames(data.frame(matrix(TMP, nrow=" +
std::to_string(chem.getField().GetRequestedVecSize()) +
")), TMP_PROPS)"));
R.parseEval(std::string("predictors <- setNames(data.frame(matrix(TMP, nrow=" +
std::to_string(chem.getField().GetRequestedVecSize()) +
")), TMP_PROPS)"));
R.parseEval("predictors <- predictors[ai_surrogate_species]");
// Apply preprocessing
@ -339,8 +328,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
// Predict
MSG("AI Prediction");
R.parseEval(
"aipreds_scaled <- prediction_step(model, predictors_scaled)");
R.parseEval("aipreds_scaled <- prediction_step(model, predictors_scaled)");
// Apply postprocessing
MSG("AI Postprocessing");
@ -348,8 +336,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
// Validate prediction and write valid predictions to chem field
MSG("AI Validation");
R.parseEval(
"validity_vector <- validate_predictions(predictors, aipreds)");
R.parseEval("validity_vector <- validate_predictions(predictors, aipreds)");
MSG("AI Marking accepted");
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
@ -361,9 +348,8 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
validity_vector)");
MSG("AI Set Field");
Field predictions_field =
Field(R.parseEval("nrow(predictors)"), RTempField,
R.parseEval("colnames(predictors)"));
Field predictions_field = Field(R.parseEval("nrow(predictors)"), RTempField,
R.parseEval("colnames(predictors)"));
MSG("AI Update");
chem.getField().update(predictions_field);
@ -378,10 +364,9 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
double ai_start_t = MPI_Wtime();
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
R.parseEval(
std::string("targets <- setNames(data.frame(matrix(TMP, nrow=" +
std::to_string(chem.getField().GetRequestedVecSize()) +
")), TMP_PROPS)"));
R.parseEval(std::string("targets <- setNames(data.frame(matrix(TMP, nrow=" +
std::to_string(chem.getField().GetRequestedVecSize()) +
")), TMP_PROPS)"));
R.parseEval("targets <- targets[ai_surrogate_species]");
// TODO: Check how to get the correct columns
@ -414,8 +399,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
std::to_string(maxiter));
if (control.isCtrlIntervalActive()) {
control.processCheckpoint(iter, params.out_dir,
chem.getField().GetProps());
control.processCheckpoint(iter, params.out_dir, chem.getField().GetProps());
control.writeMetrics(params.out_dir, chem.getField().GetProps());
}
// MSG();
@ -452,16 +436,12 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
//}
if (params.use_interp) {
chem_profiling["interp_w"] =
Rcpp::wrap(chem.GetWorkerInterpolationWriteTimings());
chem_profiling["interp_r"] =
Rcpp::wrap(chem.GetWorkerInterpolationReadTimings());
chem_profiling["interp_g"] =
Rcpp::wrap(chem.GetWorkerInterpolationGatherTimings());
chem_profiling["interp_w"] = Rcpp::wrap(chem.GetWorkerInterpolationWriteTimings());
chem_profiling["interp_r"] = Rcpp::wrap(chem.GetWorkerInterpolationReadTimings());
chem_profiling["interp_g"] = Rcpp::wrap(chem.GetWorkerInterpolationGatherTimings());
chem_profiling["interp_fc"] =
Rcpp::wrap(chem.GetWorkerInterpolationFunctionCallTimings());
chem_profiling["interp_calls"] =
Rcpp::wrap(chem.GetWorkerInterpolationCalls());
chem_profiling["interp_calls"] = Rcpp::wrap(chem.GetWorkerInterpolationCalls());
chem_profiling["interp_cached"] = Rcpp::wrap(chem.GetWorkerPHTCacheHits());
}
@ -476,8 +456,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
return profiling;
}
std::vector<std::string> getSpeciesNames(const Field &&field, int root,
MPI_Comm comm) {
std::vector<std::string> getSpeciesNames(const Field &&field, int root, MPI_Comm comm) {
std::uint32_t n_elements;
std::uint32_t n_string_size;
@ -494,8 +473,8 @@ std::vector<std::string> getSpeciesNames(const Field &&field, int root,
for (std::uint32_t i = 0; i < n_elements; i++) {
n_string_size = field.GetProps()[i].size();
MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
MPI_Bcast(const_cast<char *>(field.GetProps()[i].c_str()), n_string_size,
MPI_CHAR, root, MPI_COMM_WORLD);
MPI_Bcast(const_cast<char *>(field.GetProps()[i].c_str()), n_string_size, MPI_CHAR,
root, MPI_COMM_WORLD);
}
return field.GetProps();
@ -609,8 +588,8 @@ int main(int argc, char *argv[]) {
MPI_Barrier(MPI_COMM_WORLD);
ChemistryModule chemistry(run_params.work_package_size,
init_list.getChemistryInit(), MPI_COMM_WORLD);
ChemistryModule chemistry(run_params.work_package_size, init_list.getChemistryInit(),
MPI_COMM_WORLD);
// ControlModule control;
// chemistry.SetControlModule(&control);
@ -633,8 +612,8 @@ int main(int argc, char *argv[]) {
chemistry.masterEnableSurrogates(surr_setup);
ControlConfig config(run_params.ctrl_interval, run_params.chkpt_interval,
run_params.rb_limit, run_params.zero_abs,
run_params.mape_threshold);
run_params.rb_limit, run_params.rb_interval_limit,
run_params.zero_abs, run_params.mape_threshold);
ControlModule control(config, &chemistry);
@ -660,8 +639,7 @@ int main(int argc, char *argv[]) {
/* Incorporate ai surrogate from R */
R.parseEvalQ(ai_surrogate_r_library);
/* Use dht species for model input and output */
R["ai_surrogate_species"] =
init_list.getChemistryInit().dht_species.getNames();
R["ai_surrogate_species"] = init_list.getChemistryInit().dht_species.getNames();
const std::string ai_surrogate_input_script =
init_list.getChemistryInit().ai_surrogate_input_script;
@ -678,13 +656,11 @@ int main(int argc, char *argv[]) {
// MPI_Barrier(MPI_COMM_WORLD);
DiffusionModule diffusion(init_list.getDiffusionInit(),
init_list.getInitialGrid());
DiffusionModule diffusion(init_list.getDiffusionInit(), init_list.getInitialGrid());
chemistry.masterSetField(init_list.getInitialGrid());
Rcpp::List profiling =
RunMasterLoop(R, run_params, diffusion, chemistry, control);
Rcpp::List profiling = RunMasterLoop(R, run_params, diffusion, chemistry, control);
MSG("finished simulation loop");

View File

@ -54,6 +54,7 @@ struct RuntimeParameters {
std::uint32_t chkpt_interval = 0;
std::uint32_t ctrl_interval = 0;
std::uint32_t rb_limit = 0;
std::uint32_t rb_interval_limit = 0;
std::vector<double> mape_threshold;
double zero_abs = 0.0;