mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-15 12:28:22 +01:00
Compare commits
2 Commits
92e4414bfa
...
ac6bff6b97
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ac6bff6b97 | ||
|
|
fddef8d01d |
@ -1,7 +1,7 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=proto2_eps01_3_rb
|
||||
#SBATCH --output=proto2_eps01_3_rb_%j.out
|
||||
#SBATCH --error=proto2_eps01_3_rb_%j.err
|
||||
#SBATCH --job-name=proto2_eps0035_no_rb_v2
|
||||
#SBATCH --output=proto2_eps0035_no_rb_v2_%j.out
|
||||
#SBATCH --error=proto2_eps0035_no_rb_v2%j.err
|
||||
#SBATCH --partition=long
|
||||
#SBATCH --nodes=6
|
||||
#SBATCH --ntasks-per-node=24
|
||||
@ -15,5 +15,5 @@ module purge
|
||||
module load cmake gcc openmpi
|
||||
|
||||
#mpirun -n 144 ./poet dolo_fgcs_3.R dolo_fgcs_3.qs2 dolo_only_pqc
|
||||
mpirun -n 144 ./poet --interp dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 proto2_eps01_3_rb
|
||||
mpirun -n 144 ./poet --interp --rds dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 proto2_eps0035_no_rb_v2
|
||||
#mpirun -n 144 ./poet --interp barite_fgcs_4_new/barite_fgcs_4_new_rt.R barite_fgcs_4_new/barite_fgcs_4_new.qs2 barite
|
||||
Binary file not shown.
Binary file not shown.
@ -473,7 +473,7 @@ protected:
|
||||
bool control_enabled{false};
|
||||
bool stab_enabled{false};
|
||||
std::unordered_set<uint32_t> ctrl_cell_ids;
|
||||
std::vector<std::vector<double>> control_batch;
|
||||
std::vector<std::vector<double>> ctrl_batch;
|
||||
};
|
||||
} // namespace poet
|
||||
|
||||
|
||||
@ -368,7 +368,7 @@ inline void poet::ChemistryModule::MasterRecvPkgs(worker_list_t &w_list,
|
||||
std::vector<double> cell_output(
|
||||
recv_buffer.begin() + this->prop_count * i,
|
||||
recv_buffer.begin() + this->prop_count * (i + 1));
|
||||
this->control_batch.push_back(std::move(cell_output));
|
||||
this->ctrl_batch.push_back(std::move(cell_output));
|
||||
}
|
||||
break;
|
||||
}
|
||||
@ -443,13 +443,11 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
||||
MPI_INT);
|
||||
}
|
||||
|
||||
// if (control->shouldBcastFlags()) {
|
||||
ftype = CHEM_CTRL_FLAGS;
|
||||
PropagateFunctionType(ftype);
|
||||
uint32_t ctrl_flags = buildCtrlFlags(this->dht_enabled, this->interp_enabled,
|
||||
this->stab_enabled);
|
||||
ChemBCast(&ctrl_flags, 1, MPI_INT);
|
||||
//}
|
||||
|
||||
ftype = CHEM_WORK_LOOP;
|
||||
PropagateFunctionType(ftype);
|
||||
@ -467,8 +465,6 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
||||
shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count,
|
||||
wp_sizes_vector.size());
|
||||
|
||||
// this->mpi_surr_buffer.resize(mpi_buffer.size());
|
||||
|
||||
/* setup local variables */
|
||||
pkg_to_send = wp_sizes_vector.size();
|
||||
pkg_to_recv = wp_sizes_vector.size();
|
||||
@ -523,39 +519,37 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
||||
|
||||
/* do master stuff */
|
||||
|
||||
if (!this->control_batch.empty()) {
|
||||
std::cout << "[Master] Processing " << this->control_batch.size()
|
||||
if (!this->ctrl_batch.empty()) {
|
||||
std::cout << "[Master] Processing " << this->ctrl_batch.size()
|
||||
<< " control cells for comparison." << std::endl;
|
||||
|
||||
|
||||
std::vector<std::vector<double>> sur_batch;
|
||||
sur_batch.reserve(this->ctrl_batch.size());
|
||||
|
||||
for (const auto &element : this->ctrl_batch) {
|
||||
|
||||
/* using mpi-buffer because we need cell-major layout*/
|
||||
std::vector<std::vector<double>> surrogate_batch;
|
||||
surrogate_batch.reserve(this->control_batch.size());
|
||||
|
||||
for (const auto &element : this->control_batch) {
|
||||
|
||||
for (size_t i = 0; i < this->n_cells; i++) {
|
||||
uint32_t curr_cell_id = mpi_buffer[this->prop_count * i];
|
||||
|
||||
if (curr_cell_id == element[0]) {
|
||||
std::vector<double> surrogate_output(
|
||||
std::vector<double> sur_output(
|
||||
mpi_buffer.begin() + this->prop_count * i,
|
||||
mpi_buffer.begin() + this->prop_count * (i + 1));
|
||||
surrogate_batch.push_back(surrogate_output);
|
||||
sur_batch.push_back(sur_output);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
metrics_a = MPI_Wtime();
|
||||
control->computeErrorMetrics(this->control_batch, surrogate_batch,
|
||||
control->computeErrorMetrics(this->ctrl_batch, sur_batch,
|
||||
prop_names, n_cells);
|
||||
control->writeErrorMetrics(ctrl_file_out_dir, prop_names);
|
||||
|
||||
metrics_b = MPI_Wtime();
|
||||
this->metrics_t += metrics_b - metrics_a;
|
||||
|
||||
// Clear for next control iteration
|
||||
this->control_batch.clear();
|
||||
this->ctrl_batch.clear();
|
||||
}
|
||||
|
||||
/* start time measurement of master chemistry */
|
||||
|
||||
@ -7,23 +7,21 @@
|
||||
poet::ControlModule::ControlModule(const ControlConfig &config_)
|
||||
: config(config_) {}
|
||||
|
||||
void poet::ControlModule::beginIteration(ChemistryModule &chem,
|
||||
const uint32_t &iter,
|
||||
void poet::ControlModule::beginIteration(ChemistryModule &chem, uint32_t &iter,
|
||||
const bool &dht_enabled,
|
||||
const bool &interp_enabled) {
|
||||
|
||||
/* dht_enabled and inter_enabled are user settings set before startig the
|
||||
* simulation*/
|
||||
global_iteration = iter;
|
||||
double prep_a, prep_b;
|
||||
|
||||
prep_a = MPI_Wtime();
|
||||
|
||||
global_iteration = iter;
|
||||
updateStabilizationPhase(chem, dht_enabled, interp_enabled);
|
||||
prep_b = MPI_Wtime();
|
||||
this->prep_t += prep_b - prep_a;
|
||||
}
|
||||
|
||||
/* Disables dht and/or interp during stabilzation phase */
|
||||
void poet::ControlModule::updateStabilizationPhase(ChemistryModule &chem,
|
||||
bool dht_enabled,
|
||||
bool interp_enabled) {
|
||||
@ -31,16 +29,15 @@ void poet::ControlModule::updateStabilizationPhase(ChemistryModule &chem,
|
||||
if (disable_surr_counter > 0) {
|
||||
--disable_surr_counter;
|
||||
flush_request = false;
|
||||
MSG("Rollback counter: " + std::to_string(disable_surr_counter));
|
||||
std::cout << "Rollback counter: " << std::to_string(disable_surr_counter)
|
||||
<< std::endl;
|
||||
} else {
|
||||
rollback_enabled = false;
|
||||
MSG("Rollback stabilization complete, re-enabling surrogates");
|
||||
std::cout << "Rollback stabilization complete, re-enabling surrogate."
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
bool prev_stab_state = chem.GetStabEnabled();
|
||||
|
||||
// user requested DHT/INTEP? keep them disabled but enable warmup-phase so
|
||||
if (global_iteration <= config.stab_interval || rollback_enabled) {
|
||||
chem.SetStabEnabled(true);
|
||||
chem.SetDhtEnabled(false);
|
||||
@ -50,28 +47,20 @@ void poet::ControlModule::updateStabilizationPhase(ChemistryModule &chem,
|
||||
chem.SetDhtEnabled(dht_enabled);
|
||||
chem.SetInterpEnabled(interp_enabled);
|
||||
}
|
||||
|
||||
// Mark that we need to broadcast flags if stab state changed
|
||||
if (prev_stab_state != chem.GetStabEnabled()) {
|
||||
stab_phase_ended = true;
|
||||
}
|
||||
}
|
||||
|
||||
void poet::ControlModule::writeCheckpoint(DiffusionModule &diffusion,
|
||||
uint32_t &iter,
|
||||
const std::string &out_dir) {
|
||||
double w_check_a, w_check_b;
|
||||
|
||||
w_check_a = MPI_Wtime();
|
||||
if (global_iteration % config.checkpoint_interval == 0) {
|
||||
double w_check_a = MPI_Wtime();
|
||||
write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5",
|
||||
{.field = diffusion.getField(), .iteration = iter});
|
||||
}
|
||||
w_check_b = MPI_Wtime();
|
||||
double w_check_b = MPI_Wtime();
|
||||
this->w_check_t += w_check_b - w_check_a;
|
||||
|
||||
last_checkpoint_written = iter;
|
||||
}
|
||||
}
|
||||
|
||||
void poet::ControlModule::readCheckpoint(DiffusionModule &diffusion,
|
||||
uint32_t ¤t_iter,
|
||||
@ -90,18 +79,15 @@ void poet::ControlModule::readCheckpoint(DiffusionModule &diffusion,
|
||||
}
|
||||
|
||||
void poet::ControlModule::writeErrorMetrics(
|
||||
const std::string &out_dir, const std::vector<std::string> &species) {
|
||||
double stats_a, stats_b;
|
||||
uint32_t &iter, const std::string &out_dir,
|
||||
const std::vector<std::string> &species) {
|
||||
|
||||
stats_a = MPI_Wtime();
|
||||
double stats_a = MPI_Wtime();
|
||||
writeSpeciesStatsToCSV(metrics_history, species, out_dir,
|
||||
"species_overview.csv");
|
||||
// writeCellStatsToCSV(cell_metrics_history, species, out_dir,
|
||||
// "cell_overview.csv");
|
||||
write_metrics(cell_metrics_history, species, out_dir,
|
||||
"metrics_overview.hdf5");
|
||||
stats_b = MPI_Wtime();
|
||||
|
||||
double stats_b = MPI_Wtime();
|
||||
this->stats_t += stats_b - stats_a;
|
||||
}
|
||||
|
||||
@ -109,19 +95,6 @@ uint32_t poet::ControlModule::getRollbackIter() {
|
||||
|
||||
uint32_t last_iter = ((global_iteration - 1) / config.checkpoint_interval) *
|
||||
config.checkpoint_interval;
|
||||
|
||||
/*
|
||||
uint32_t rollback_iter = (last_iter <= last_checkpoint_written)
|
||||
? last_iter
|
||||
: last_checkpoint_written;
|
||||
|
||||
MSG("getRollbackIter: global_iteration=" + std::to_string(global_iteration) +
|
||||
", checkpoint_interval=" + std::to_string(config.checkpoint_interval) +
|
||||
", last_checkpoint_written=" + std::to_string(last_checkpoint_written) +
|
||||
", returning=" + std::to_string(last_checkpoint_written));
|
||||
|
||||
return last_checkpoint_written;
|
||||
*/
|
||||
return last_iter;
|
||||
}
|
||||
|
||||
@ -129,55 +102,47 @@ std::optional<uint32_t> poet::ControlModule::getRollbackTarget(
|
||||
const std::vector<std::string> &species) {
|
||||
double r_check_a, r_check_b;
|
||||
|
||||
if (metrics_history.empty()) {
|
||||
MSG("No error history yet, skipping rollback check.");
|
||||
rollback_enabled = false;
|
||||
return std::nullopt;
|
||||
}
|
||||
// Skip threshold checking if already in rollback/stabilization phase
|
||||
if (rollback_enabled) {
|
||||
/* Skip threshold checking if already in stabilization phase*/
|
||||
if (metrics_history.empty() || rollback_enabled) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
const auto &s_mape = metrics_history.back().mape;
|
||||
const auto &s_hist = metrics_history.back();
|
||||
|
||||
/* skipping cell_id and id */
|
||||
for (size_t sp_i = 2; sp_i < species.size(); sp_i++) {
|
||||
|
||||
// skip charge
|
||||
if (s_mape[sp_i] == 0 || sp_i == 4) {
|
||||
/* check bounds of threshold vector*/
|
||||
if (sp_i >= config.mape_threshold.size()) {
|
||||
std::cerr << "Warning: No threshold defined for species " << species[sp_i]
|
||||
<< " at index " << std::to_string(sp_i) << std::endl;
|
||||
continue;
|
||||
}
|
||||
if (s_mape[sp_i] > config.mape_threshold[sp_i]) {
|
||||
if (last_checkpoint_written == 0) {
|
||||
MSG(" Threshold exceeded but no checkpoint exists yet.");
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
const auto &c_mape = cell_metrics_history.back().mape;
|
||||
const auto &c_id = cell_metrics_history.back().id;
|
||||
if (s_hist.mape[sp_i] > config.mape_threshold[sp_i]) {
|
||||
|
||||
const auto &c_hist = cell_metrics_history.back();
|
||||
auto max_it = std::max_element(
|
||||
c_mape.begin(), c_mape.end(),
|
||||
c_hist.mape.begin(), c_hist.mape.end(),
|
||||
[sp_i](const auto &a, const auto &b) { return a[sp_i] < b[sp_i]; });
|
||||
|
||||
size_t max_idx = std::distance(c_mape.begin(), max_it);
|
||||
uint32_t cell_id = c_id[max_idx];
|
||||
size_t max_idx = std::distance(c_hist.mape.begin(), max_it);
|
||||
uint32_t cell_id = c_hist.id[max_idx];
|
||||
double cell_mape = (*max_it)[sp_i];
|
||||
|
||||
rollback_enabled = true;
|
||||
flush_request = true;
|
||||
|
||||
MSG("Threshold exceeded for " + species[sp_i] +
|
||||
" with species-level MAPE = " + std::to_string(s_mape[sp_i]) +
|
||||
" exceeding threshold = " +
|
||||
std::to_string(config.mape_threshold[sp_i]) + ". Worst cell: ID=" +
|
||||
std::to_string(cell_id) + " with MAPE=" + std::to_string(cell_mape));
|
||||
std::cout << "Threshold exceeded for " << species[sp_i]
|
||||
<< " with species-level MAPE = " << std::to_string(s_hist.mape[sp_i])
|
||||
<< " exceeding threshold = "
|
||||
<< std::to_string(config.mape_threshold[sp_i])
|
||||
<< ". Worst cell: ID=" << std::to_string(cell_id)
|
||||
<< ", species= " << species[sp_i]
|
||||
<< " with MAPE=" << std::to_string(cell_mape) << std::endl;
|
||||
|
||||
return getRollbackIter();
|
||||
}
|
||||
}
|
||||
rollback_enabled = false;
|
||||
flush_request = false;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
@ -186,11 +151,6 @@ void poet::ControlModule::computeErrorMetrics(
|
||||
std::vector<std::vector<double>> &surrogate_values,
|
||||
const std::vector<std::string> &species, const uint32_t size_per_prop) {
|
||||
|
||||
// Skip metric computation if already in rollback/stabilization phase
|
||||
if (rollback_enabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t n_cells = reference_values.size();
|
||||
const uint32_t n_species = species.size();
|
||||
const double ZERO_ABS = config.zero_abs;
|
||||
@ -256,7 +216,11 @@ void poet::ControlModule::processCheckpoint(
|
||||
DiffusionModule &diffusion, uint32_t ¤t_iter,
|
||||
const std::string &out_dir, const std::vector<std::string> &species) {
|
||||
|
||||
if (flush_request && rollback_count < 3) {
|
||||
// Use max_rollbacks from config, default to 3 if not set
|
||||
// uint32_t max_rollbacks =
|
||||
// (config.max_rollbacks > 0) ? config.max_rollbacks : 3;
|
||||
|
||||
if (flush_request /* && rollback_count < 3 */) {
|
||||
uint32_t target = getRollbackIter();
|
||||
readCheckpoint(diffusion, current_iter, target, out_dir);
|
||||
|
||||
@ -264,21 +228,10 @@ void poet::ControlModule::processCheckpoint(
|
||||
rollback_count++;
|
||||
disable_surr_counter = config.stab_interval;
|
||||
|
||||
MSG("Restored checkpoint " + std::to_string(target) +
|
||||
", surrogates disabled for " + std::to_string(config.stab_interval));
|
||||
std::cout << "Restored checkpoint " << std::to_string(target)
|
||||
<< ", surrogate disabled for "
|
||||
<< std::to_string(config.stab_interval) << std::endl;
|
||||
} else {
|
||||
writeCheckpoint(diffusion, global_iteration, out_dir);
|
||||
}
|
||||
}
|
||||
|
||||
bool poet::ControlModule::shouldBcastFlags() {
|
||||
if (global_iteration == 1) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (stab_phase_ended) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
@ -19,7 +19,8 @@ class DiffusionModule;
|
||||
|
||||
struct ControlConfig {
|
||||
uint32_t stab_interval = 0;
|
||||
uint32_t checkpoint_interval = 0;
|
||||
uint32_t checkpoint_interval = 0; // How often to write metrics files
|
||||
//uint32_t max_rb = 0; // Maximum number of rollbacks allowed
|
||||
double zero_abs = 0.0;
|
||||
std::vector<double> mape_threshold;
|
||||
};
|
||||
@ -54,10 +55,10 @@ class ControlModule {
|
||||
public:
|
||||
explicit ControlModule(const ControlConfig &config);
|
||||
|
||||
void beginIteration(ChemistryModule &chem, const uint32_t &iter,
|
||||
void beginIteration(ChemistryModule &chem, uint32_t &iter,
|
||||
const bool &dht_enabled, const bool &interp_enaled);
|
||||
|
||||
void writeErrorMetrics(const std::string &out_dir,
|
||||
void writeErrorMetrics(uint32_t &iter, const std::string &out_dir,
|
||||
const std::vector<std::string> &species);
|
||||
|
||||
std::optional<uint32_t> getRollbackTarget();
|
||||
|
||||
23
src/poet.cpp
23
src/poet.cpp
@ -254,8 +254,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) {
|
||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval"));
|
||||
params.stab_interval =
|
||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("stab_interval"));
|
||||
params.zero_abs =
|
||||
Rcpp::as<double>(global_rt_setup->operator[]("zero_abs"));
|
||||
params.zero_abs = Rcpp::as<double>(global_rt_setup->operator[]("zero_abs"));
|
||||
params.mape_threshold = Rcpp::as<std::vector<double>>(
|
||||
global_rt_setup->operator[]("mape_threshold"));
|
||||
params.ctrl_cell_ids = Rcpp::as<std::vector<uint32_t>>(
|
||||
@ -415,9 +414,9 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" +
|
||||
std::to_string(maxiter));
|
||||
|
||||
control.writeErrorMetrics(iter, params.out_dir, chem.getField().GetProps());
|
||||
control.processCheckpoint(diffusion, iter, params.out_dir,
|
||||
chem.getField().GetProps());
|
||||
|
||||
// MSG();
|
||||
} // END SIMULATION LOOP
|
||||
|
||||
@ -434,13 +433,22 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||
Rcpp::List diffusion_profiling;
|
||||
diffusion_profiling["simtime"] = diffusion.getTransportTime();
|
||||
|
||||
if (params.use_dht) {
|
||||
Rcpp::List ctrl_profiling;
|
||||
ctrl_profiling["compute_metrics_master"] = chem.GetMasterCtrlMetricsTime();
|
||||
ctrl_profiling["unshuffle_field_master"] = chem.GetMasterUnshuffleTime();
|
||||
ctrl_profiling["w_checkpoint_master"] = control.getWriteCheckpointTime();
|
||||
ctrl_profiling["r_checkpoint_master"] = control.getReadCheckpointTime();
|
||||
ctrl_profiling["write_stats"] = control.getWriteMetricsTime();
|
||||
ctrl_profiling["ctrl_logic_master"] = control.getUpdateCtrlLogicTime();
|
||||
ctrl_profiling["recv_data_master"] = chem.GetMasterRecvCtrlDataTime();
|
||||
ctrl_profiling["worker"] = Rcpp::wrap(chem.GetWorkerControlTimings());
|
||||
|
||||
// if (params.use_dht) {
|
||||
chem_profiling["dht_hits"] = Rcpp::wrap(chem.GetWorkerDHTHits());
|
||||
chem_profiling["dht_evictions"] = Rcpp::wrap(chem.GetWorkerDHTEvictions());
|
||||
chem_profiling["dht_get_time"] = Rcpp::wrap(chem.GetWorkerDHTGetTimings());
|
||||
chem_profiling["dht_fill_time"] =
|
||||
Rcpp::wrap(chem.GetWorkerDHTFillTimings());
|
||||
}
|
||||
chem_profiling["dht_fill_time"] = Rcpp::wrap(chem.GetWorkerDHTFillTimings());
|
||||
//}
|
||||
|
||||
if (params.use_interp) {
|
||||
chem_profiling["interp_w"] =
|
||||
@ -460,6 +468,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||
profiling["simtime"] = dSimTime;
|
||||
profiling["chemistry"] = chem_profiling;
|
||||
profiling["diffusion"] = diffusion_profiling;
|
||||
profiling["control"] = ctrl_profiling;
|
||||
|
||||
chem.MasterLoopBreak();
|
||||
|
||||
|
||||
@ -52,6 +52,7 @@ struct RuntimeParameters {
|
||||
bool print_progress = false;
|
||||
std::uint32_t stab_interval = 0;
|
||||
std::uint32_t checkpoint_interval = 0;
|
||||
std::uint32_t max_rb = 0;
|
||||
double zero_abs = 0.0;
|
||||
std::vector<double> mape_threshold;
|
||||
std::vector<uint32_t> ctrl_cell_ids;
|
||||
|
||||
452
util/data_evaluation/RFun_Eval_qs2.R
Normal file
452
util/data_evaluation/RFun_Eval_qs2.R
Normal file
@ -0,0 +1,452 @@
|
||||
## Simple library of functions to assess and visualize the results of the coupled simulations
|
||||
## Modified to work with .qs2 files (qs2 package format)
|
||||
|
||||
## Time-stamp: "Last modified 2025-11-25"
|
||||
|
||||
require(qs2) ## for reading .qs2 files
|
||||
require(stringr)
|
||||
|
||||
# Note: RedModRphree, Rmufits, and Rcpp functions for DHT/PHT reading are kept
|
||||
# but you'll need those packages only if you use ReadAllDHT/ReadAllPHT functions
|
||||
|
||||
curdir <- dirname(sys.frame(1)$ofile) ##path.expand(".")
|
||||
|
||||
print(paste("RFun_Eval_qs2.R is in ", curdir))
|
||||
|
||||
## ============================================================================
|
||||
## NEW: Functions for reading .qs2 simulation outputs
|
||||
## ============================================================================
|
||||
|
||||
## function which reads all simulation results in a given directory (.qs2 format)
|
||||
ReadRTSims_qs2 <- function(dir) {
|
||||
pattern <- "^iter_.*\\.qs2$"
|
||||
files_full <- list.files(dir, pattern = pattern, full.names = TRUE)
|
||||
files_name <- list.files(dir, pattern = pattern, full.names = FALSE)
|
||||
|
||||
if (length(files_full) == 0) {
|
||||
warning(paste("No .qs2 files found in", dir, "with pattern", pattern))
|
||||
return(NULL)
|
||||
}
|
||||
|
||||
res <- lapply(files_full, qs2::qs_read)
|
||||
names(res) <- gsub(".qs2", "", files_name, perl = TRUE)
|
||||
|
||||
return(res[str_sort(names(res), numeric = TRUE)])
|
||||
}
|
||||
## Read a single .qs2 file
|
||||
ReadQS2 <- function(file) {
|
||||
if (!file.exists(file)) {
|
||||
stop(paste("File not found:", file))
|
||||
}
|
||||
qs2::qs_read(file)
|
||||
}
|
||||
|
||||
## Extract chemistry field data from .qs2 iteration file
|
||||
## Assumes structure similar to old .rds format with $C (chemistry) and $T (transport)
|
||||
ExtractChemistry <- function(qs2_data) {
|
||||
if ("C" %in% names(qs2_data)) {
|
||||
return(qs2_data$C)
|
||||
} else if (is.data.frame(qs2_data)) {
|
||||
return(qs2_data)
|
||||
} else {
|
||||
warning("Could not find chemistry data in expected format")
|
||||
return(qs2_data)
|
||||
}
|
||||
}
|
||||
|
||||
## Extract transport field data from .qs2 iteration file
|
||||
ExtractTransport <- function(qs2_data) {
|
||||
if ("T" %in% names(qs2_data)) {
|
||||
return(qs2_data$T)
|
||||
} else {
|
||||
warning("Could not find transport data in expected format")
|
||||
return(NULL)
|
||||
}
|
||||
}
|
||||
|
||||
## ============================================================================
|
||||
## ORIGINAL: DHT/PHT reading functions (kept for surrogate analysis)
|
||||
## ============================================================================
|
||||
|
||||
# Only load these if needed (requires Rcpp compilation)
|
||||
if (requireNamespace("Rcpp", quietly = TRUE) && file.exists(paste0(curdir, "/interpret_keys.cpp"))) {
|
||||
library(Rcpp)
|
||||
sourceCpp(file = paste0(curdir, "/interpret_keys.cpp"))
|
||||
|
||||
# Wrapper around previous sourced Rcpp function
|
||||
ConvertDHTKey <- function(value) {
|
||||
rcpp_key_convert(value)
|
||||
}
|
||||
|
||||
ConvertToUInt64 <- function(double_data) {
|
||||
rcpp_uint64_convert(double_data)
|
||||
}
|
||||
} else {
|
||||
if (!requireNamespace("Rcpp", quietly = TRUE)) {
|
||||
message("Note: Rcpp not available. DHT/PHT reading functions will not work.")
|
||||
}
|
||||
# Create dummy functions so the rest of the script doesn't break
|
||||
ConvertDHTKey <- function(value) {
|
||||
stop("Rcpp not available. Cannot convert DHT keys.")
|
||||
}
|
||||
ConvertToUInt64 <- function(double_data) {
|
||||
stop("Rcpp not available. Cannot convert to UInt64.")
|
||||
}
|
||||
}
|
||||
|
||||
## function which reads all successive DHT stored in a given directory
|
||||
ReadAllDHT <- function(dir, new_scheme = TRUE) {
|
||||
files_full <- list.files(dir, pattern="iter.*\\.dht$", full.names=TRUE)
|
||||
files_name <- list.files(dir, pattern="iter.*\\.dht$", full.names=FALSE)
|
||||
|
||||
if (length(files_full) == 0) {
|
||||
warning(paste("No .dht files found in", dir))
|
||||
return(NULL)
|
||||
}
|
||||
|
||||
res <- lapply(files_full, ReadDHT, new_scheme = new_scheme)
|
||||
names(res) <- gsub("\\.dht$","",files_name)
|
||||
return(res)
|
||||
}
|
||||
|
||||
## function which reads one .dht file and gives a matrix
|
||||
ReadDHT <- function(file, new_scheme = TRUE) {
|
||||
conn <- file(file, "rb") ## open for reading in binary mode
|
||||
if (!isSeekable(conn))
|
||||
stop("Connection not seekable")
|
||||
|
||||
## we first reposition ourselves to the end of the file...
|
||||
tmp <- seek(conn, where=0, origin = "end")
|
||||
## ... and then back to the origin so to store the length in bytes
|
||||
flen <- seek(conn, where=0, origin = "start")
|
||||
|
||||
## we read the first 2 integers (4 bytes each) containing dimensions in bytes
|
||||
dims <- readBin(conn, what="integer", n=2)
|
||||
|
||||
## compute dimensions of the data
|
||||
tots <- sum(dims)
|
||||
ncol <- tots/8
|
||||
nrow <- (flen - 8)/tots ## 8 here is 2*sizeof("int")
|
||||
buff <- readBin(conn, what="double", n=ncol*nrow)
|
||||
## close connection
|
||||
close(conn)
|
||||
res <- matrix(buff, nrow=nrow, ncol=ncol, byrow=TRUE)
|
||||
|
||||
if (new_scheme) {
|
||||
nkeys <- dims[1] / 8
|
||||
keys <- res[, 1:nkeys]
|
||||
|
||||
conv <- apply(keys, 2, ConvertDHTKey)
|
||||
res[, 1:nkeys] <- conv
|
||||
}
|
||||
|
||||
return(res)
|
||||
}
|
||||
|
||||
## function which reads all successive PHT stored in a given directory
|
||||
ReadAllPHT <- function(dir, with_info = FALSE) {
|
||||
files_full <- list.files(dir, pattern="iter.*\\.pht$", full.names=TRUE)
|
||||
files_name <- list.files(dir, pattern="iter.*\\.pht$", full.names=FALSE)
|
||||
|
||||
if (length(files_full) == 0) {
|
||||
warning(paste("No .pht files found in", dir))
|
||||
return(NULL)
|
||||
}
|
||||
|
||||
res <- lapply(files_full, ReadPHT, with_info = with_info)
|
||||
names(res) <- gsub("\\.pht$","",files_name)
|
||||
return(res)
|
||||
}
|
||||
|
||||
## function which reads one .pht file and gives a matrix
|
||||
ReadPHT <- function(file, with_info = FALSE) {
|
||||
conn <- file(file, "rb") ## open for reading in binary mode
|
||||
if (!isSeekable(conn))
|
||||
stop("Connection not seekable")
|
||||
|
||||
## we first reposition ourselves to the end of the file...
|
||||
tmp <- seek(conn, where=0, origin = "end")
|
||||
## ... and then back to the origin so to store the length in bytes
|
||||
flen <- seek(conn, where=0, origin = "start")
|
||||
|
||||
## we read the first 2 integers (4 bytes each) containing dimensions in bytes
|
||||
dims <- readBin(conn, what="integer", n=2)
|
||||
|
||||
## compute dimensions of the data
|
||||
tots <- sum(dims)
|
||||
ncol <- tots/8
|
||||
nrow <- (flen - 8)/tots ## 8 here is 2*sizeof("int")
|
||||
buff <- readBin(conn, what="double", n=ncol*nrow)
|
||||
## close connection
|
||||
close(conn)
|
||||
res <- matrix(buff, nrow=nrow, ncol=ncol, byrow=TRUE)
|
||||
|
||||
nkeys <- dims[1] / 8
|
||||
keys <- res[, 1:nkeys]
|
||||
|
||||
timesteps <- res[, nkeys + 1]
|
||||
conv <- apply(keys, 2, ConvertDHTKey)
|
||||
|
||||
ndata <- dims[2] / 8
|
||||
fill_rate <- ConvertToUInt64(res[, nkeys + 2])
|
||||
|
||||
buff <- c(conv, timesteps, fill_rate)
|
||||
|
||||
if (with_info) {
|
||||
ndata <- dims[2]/8
|
||||
visit_count <- ConvertToUInt64(res[, nkeys + ndata])
|
||||
buff <- c(buff, visit_count)
|
||||
}
|
||||
|
||||
res <- matrix(buff, nrow = nrow, byrow = FALSE)
|
||||
|
||||
return(res)
|
||||
}
|
||||
|
||||
## ============================================================================
|
||||
## PLOTTING and ANALYSIS functions (work with both .rds and .qs2 data)
|
||||
## ============================================================================
|
||||
|
||||
## Scatter plots of each variable in the iteration
|
||||
PlotScatter <- function(sam1, sam2, which=NULL, labs=c("NO DHT", "DHT"), pch=".", cols=3, ...) {
|
||||
if ((!is.data.frame(sam1)) & ("T" %in% names(sam1)))
|
||||
sam1 <- sam1$C
|
||||
if ((!is.data.frame(sam2)) & ("T" %in% names(sam2)))
|
||||
sam2 <- sam2$C
|
||||
if (is.numeric(which))
|
||||
inds <- which
|
||||
else if (is.character(which))
|
||||
inds <- match(which, colnames(sam1))
|
||||
else if (is.null(which))
|
||||
inds <- seq_along(colnames(sam1))
|
||||
|
||||
rows <- ceiling(length(inds) / cols)
|
||||
par(mfrow=c(rows, cols))
|
||||
a <- lapply(inds, function(x) {
|
||||
plot(sam1[,x], sam2[,x], main=colnames(sam1)[x], xlab=labs[1], ylab=labs[2], pch=pch, col="red", ...)
|
||||
abline(0,1, col="grey", lwd=1.5)
|
||||
})
|
||||
invisible()
|
||||
}
|
||||
|
||||
##### Some metrics for relative comparison
|
||||
|
||||
## Root Mean Square Error
|
||||
RMSE <- function(pred, obs)
|
||||
sqrt(mean((pred - obs)^2, na.rm = TRUE))
|
||||
|
||||
## Using range as norm
|
||||
RranRMSE <- function(pred, obs)
|
||||
sqrt(mean((pred - obs)^2, na.rm = TRUE))/abs(max(pred, na.rm = TRUE) - min(pred, na.rm = TRUE))
|
||||
|
||||
## Using max val as norm
|
||||
RmaxRMSE <- function(pred, obs)
|
||||
sqrt(mean((pred - obs)^2, na.rm = TRUE))/abs(max(pred, na.rm = TRUE))
|
||||
|
||||
## Using sd as norm
|
||||
RsdRMSE <- function(pred, obs)
|
||||
sqrt(mean((pred - obs)^2, na.rm = TRUE))/sd(pred, na.rm = TRUE)
|
||||
|
||||
## Using mean as norm
|
||||
RmeanRMSE <- function(pred, obs)
|
||||
sqrt(mean((pred - obs)^2, na.rm = TRUE))/mean(pred, na.rm = TRUE)
|
||||
|
||||
## Using mean as norm
|
||||
RAEmax <- function(pred, obs)
|
||||
mean(abs(pred - obs), na.rm = TRUE)/max(pred, na.rm = TRUE)
|
||||
|
||||
## Max absolute error
|
||||
MAE <- function(pred, obs)
|
||||
max(abs(pred - obs), na.rm = TRUE)
|
||||
|
||||
## Mean Absolute Percentage Error
|
||||
MAPE <- function(pred, obs)
|
||||
mean(abs((obs - pred) / obs) * 100, na.rm = TRUE)
|
||||
|
||||
## workhorse function for ComputeErrors and its use with mapply
|
||||
AppliedFun <- function(a, b, .fun) {
|
||||
# Extract chemistry data if needed
|
||||
if (!is.data.frame(a) && "C" %in% names(a)) a <- a$C
|
||||
if (!is.data.frame(b) && "C" %in% names(b)) b <- b$C
|
||||
|
||||
mapply(.fun, as.list(a), as.list(b))
|
||||
}
|
||||
|
||||
## Compute the diffs between two simulation, iter by iter,
|
||||
## with a given metric (passed in form of function name to this function)
|
||||
ComputeErrors <- function(sim1, sim2, FUN=RMSE) {
|
||||
if (length(sim1)!= length(sim2)) {
|
||||
cat("The simulations do not have the same length, subsetting to the shortest\n")
|
||||
a <- min(length(sim1), length(sim2))
|
||||
sim1 <- sim1[1:a]
|
||||
sim2 <- sim2[1:a]
|
||||
}
|
||||
if (!is.function(match.fun(FUN))) {
|
||||
stop("Invalid function\n")
|
||||
}
|
||||
|
||||
t(mapply(AppliedFun, sim1, sim2, MoreArgs=list(.fun=FUN)))
|
||||
}
|
||||
|
||||
## Function to display the error progress between 2 simulations
|
||||
ErrorProgress <- function(mat, ignore, colors, metric, ...) {
|
||||
if (is.null(mat)) {
|
||||
stop("Cannot plot: matrix is NULL")
|
||||
}
|
||||
|
||||
# Convert to matrix if it's a vector or data frame
|
||||
if (is.vector(mat)) {
|
||||
stop("Cannot plot: input is a vector (need at least 2 columns). Check that your data has multiple columns.")
|
||||
}
|
||||
if (is.data.frame(mat)) {
|
||||
mat <- as.matrix(mat)
|
||||
}
|
||||
|
||||
if (nrow(mat) == 0 || ncol(mat) == 0) {
|
||||
stop("Cannot plot: matrix is empty")
|
||||
}
|
||||
|
||||
if (missing(colors))
|
||||
colors <- sample(rainbow(ncol(mat)))
|
||||
|
||||
if (missing(metric))
|
||||
metric <- "Metric"
|
||||
|
||||
## if the optional argument "ignore" (a character vector) is
|
||||
## passed, we remove the matching column names
|
||||
if (!missing(ignore)) {
|
||||
to_remove <- match(ignore, colnames(mat))
|
||||
to_remove <- to_remove[!is.na(to_remove)] # Remove NAs
|
||||
if (length(to_remove) > 0) {
|
||||
mat <- mat[, -to_remove, drop = FALSE]
|
||||
colors <- colors[-to_remove]
|
||||
}
|
||||
}
|
||||
|
||||
yc <- mat[nrow(mat),]
|
||||
par(mar=c(5,4,2,8))
|
||||
matplot(mat, type="l", lty=1, lwd=2, col=colors, xlab="iteration", ylab=metric, ...)
|
||||
mtext(colnames(mat), side = 4, line = 0.5, outer = FALSE, at = yc, adj = 0, col = colors, las=2, cex=0.7)
|
||||
}
|
||||
|
||||
## Function which exports all simulations to ParaView's .vtu
|
||||
## Requires package RcppVTK
|
||||
ExportToParaview <- function(vtu, nameout, results) {
|
||||
if (!requireNamespace("RcppVTK", quietly = TRUE)) {
|
||||
stop("Package RcppVTK is required for this function")
|
||||
}
|
||||
require(RcppVTK)
|
||||
|
||||
n <- length(results)
|
||||
vars <- colnames(results[[1]])
|
||||
## strip eventually present ".vtu" from nameout
|
||||
nameout <- sub(".vtu", "", nameout, fixed=TRUE)
|
||||
namesteps <- paste0(nameout, ".", sprintf("%04d",seq(1,n)), ".vtu")
|
||||
for (step in seq_along(results)) {
|
||||
file.copy(from=vtu, to=namesteps[step], overwrite = TRUE)
|
||||
cat(paste("Saving step ", step, " in file ", namesteps[step], "\n"))
|
||||
ret <- ExportMatrixToVTU (fin=vtu, fout=namesteps[step], names=colnames(results[[step]]), mat=results[[step]])
|
||||
}
|
||||
invisible(ret)
|
||||
}
|
||||
|
||||
|
||||
## Version of Rmufits::PlotCartCellData with the ability to fix the
|
||||
## "breaks" for color coding of 2D simulations
|
||||
Plot2DCellData <- function (data, grid, nx, ny, contour = TRUE,
|
||||
nlevels = 12, breaks, palette = "heat.colors",
|
||||
rev.palette = TRUE, scale = TRUE, plot.axes=TRUE, ...) {
|
||||
if (!missing(grid)) {
|
||||
xc <- unique(sort(grid$cell$XCOORD))
|
||||
yc <- unique(sort(grid$cell$YCOORD))
|
||||
nx <- length(xc)
|
||||
ny <- length(yc)
|
||||
if (!length(data) == nx * ny)
|
||||
stop("Wrong nx, ny or grid")
|
||||
} else {
|
||||
xc <- seq(1, nx)
|
||||
yc <- seq(1, ny)
|
||||
}
|
||||
z <- matrix(round(data, 6), ncol = nx, nrow = ny, byrow = TRUE)
|
||||
pp <- t(z[rev(seq(1, nrow(z))), ])
|
||||
|
||||
if (missing(breaks)) {
|
||||
breaks <- pretty(data, n = nlevels)
|
||||
}
|
||||
|
||||
breakslen <- length(breaks)
|
||||
colors <- do.call(palette, list(n = breakslen - 1))
|
||||
if (rev.palette)
|
||||
colors <- rev(colors)
|
||||
if (scale) {
|
||||
par(mfrow = c(1, 2))
|
||||
nf <- layout(matrix(c(1, 2), 1, 2, byrow = TRUE), widths = c(4,
|
||||
1))
|
||||
}
|
||||
par(las = 1, mar = c(5, 5, 3, 1))
|
||||
image(xc, yc, pp, xlab = "X [m]", ylab = "Y[m]", las = 1, asp = 1,
|
||||
breaks = breaks, col = colors, axes = FALSE, ann=plot.axes,
|
||||
...)
|
||||
|
||||
if (plot.axes) {
|
||||
axis(1)
|
||||
axis(2)
|
||||
}
|
||||
if (contour)
|
||||
contour(unique(sort(xc)), unique(sort(yc)), pp, breaks = breaks,
|
||||
add = TRUE)
|
||||
if (scale) {
|
||||
par(las = 1, mar = c(5, 1, 5, 5))
|
||||
if (requireNamespace("Rmufits", quietly = TRUE)) {
|
||||
Rmufits::PlotImageScale(data, breaks = breaks, add.axis = FALSE,
|
||||
axis.pos = 4, col = colors)
|
||||
}
|
||||
axis(4, at = breaks)
|
||||
}
|
||||
invisible(pp)
|
||||
}
|
||||
|
||||
PlotAsMP4 <- function(data, nx, ny, to_plot, out_dir, name,
|
||||
contour = FALSE, scale = FALSE, framerate = 30) {
|
||||
sort_data <- data[str_sort(names(data), numeric = TRUE)]
|
||||
plot_data <- lapply(sort_data, function(x) {
|
||||
if (!is.data.frame(x) && "C" %in% names(x)) {
|
||||
return(x$C[[to_plot]])
|
||||
} else {
|
||||
return(x[[to_plot]])
|
||||
}
|
||||
})
|
||||
pad_size <- ceiling(log10(length(plot_data)))
|
||||
|
||||
dir.create(out_dir, showWarnings = FALSE)
|
||||
output_files <- paste0(out_dir, "/", name, "_%0", pad_size, "d.png")
|
||||
output_mp4 <- paste0(out_dir, "/", name, ".mp4")
|
||||
|
||||
png(output_files,
|
||||
width = 297, height = 210, units = "mm",
|
||||
res = 100
|
||||
)
|
||||
|
||||
for (i in 1:length(plot_data)) {
|
||||
if (requireNamespace("Rmufits", quietly = TRUE)) {
|
||||
Rmufits::PlotCartCellData(plot_data[[i]], nx = nx, ny = ny, contour = contour, scale = scale)
|
||||
} else {
|
||||
Plot2DCellData(plot_data[[i]], nx = nx, ny = ny, contour = contour, scale = scale)
|
||||
}
|
||||
}
|
||||
dev.off()
|
||||
|
||||
ffmpeg_command <- paste(
|
||||
"ffmpeg -y -framerate", framerate, "-i", output_files,
|
||||
"-c:v libx264 -crf 22", output_mp4
|
||||
)
|
||||
|
||||
system(ffmpeg_command)
|
||||
message(paste("Created video:", output_mp4))
|
||||
}
|
||||
|
||||
cat("\n=== RFun_Eval_qs2.R loaded successfully ===\n")
|
||||
cat("New functions for .qs2 files:\n")
|
||||
cat(" - ReadRTSims_qs2(dir) : Read all iteration .qs2 files\n")
|
||||
cat(" - ReadQS2(file) : Read single .qs2 file\n")
|
||||
cat("All other functions work as before!\n\n")
|
||||
50
util/data_evaluation/run_vis.R
Normal file
50
util/data_evaluation/run_vis.R
Normal file
@ -0,0 +1,50 @@
|
||||
# Load the new functions
|
||||
source("/mnt/beegfs/home/rastogi/poet/util/data_evaluation/RFun_Eval.R")
|
||||
|
||||
# Set base path
|
||||
base_dir <- "/mnt/beegfs/home/rastogi/poet/bin"
|
||||
|
||||
sim1 <- ReadRTSims(file.path(base_dir, "proto2_eps01_no_rb_v2"))
|
||||
sim2 <- ReadRTSims(file.path(base_dir, "proto2_eps0035_no_rb_v2"))
|
||||
|
||||
|
||||
# ========================================
|
||||
# Compare two simulations
|
||||
# ========================================
|
||||
rmse_errors <- ComputeErrors(sim1, sim2, FUN = RMSE)
|
||||
mape_errors <- ComputeErrors(sim1, sim2, FUN = MAPE)
|
||||
|
||||
# Print summary
|
||||
cat("RMSE errors computed for", ncol(rmse_errors), "variables\n")
|
||||
cat("MAPE errors computed for", ncol(mape_errors), "variables\n")
|
||||
cat("Number of iterations compared:", nrow(rmse_errors), "\n\n")
|
||||
|
||||
# Set output path explicitly
|
||||
output_pdf <- file.path(base_dir, "comparison_plots.pdf")
|
||||
cat("Saving plots to:", output_pdf, "\n")
|
||||
|
||||
# Save plots to PDF
|
||||
pdf(output_pdf, width = 10, height = 6)
|
||||
|
||||
# Plot error progression
|
||||
cat("Creating ErrorProgress plot...\n")
|
||||
ErrorProgress(rmse_errors, ignore = c("Charge"), metric = "RMSE")
|
||||
|
||||
# Scatter plot for specific iteration (if that iteration exists)
|
||||
if (length(sim1) >= 10 && length(sim2) >= 10) {
|
||||
cat("Creating scatter plot for iteration 10...\n")
|
||||
PlotScatter(sim1[[10]], sim2[[10]],
|
||||
labs = c("Proto2 Eps=0.01%", "Proto2 Eps=0.0035%"),
|
||||
which = c("Ca", "Mg", "Cl", "C"))
|
||||
} else {
|
||||
cat("Not enough iterations for scatter plot\n")
|
||||
}
|
||||
|
||||
dev.off()
|
||||
cat("Plots saved successfully to:", output_pdf, "\n")
|
||||
|
||||
# ========================================
|
||||
# DHT/PHT analysis (if you have snapshot files)
|
||||
# ========================================
|
||||
#dht_snaps <- ReadAllDHT("poet/bin/proto1_only_interp")
|
||||
#pht_snaps <- ReadAllPHT("poet/bin/proto1_only_interp")
|
||||
Loading…
x
Reference in New Issue
Block a user