Compare commits

...

2 Commits

11 changed files with 586 additions and 126 deletions

View File

@ -1,7 +1,7 @@
#!/bin/bash
#SBATCH --job-name=proto2_eps01_3_rb
#SBATCH --output=proto2_eps01_3_rb_%j.out
#SBATCH --error=proto2_eps01_3_rb_%j.err
#SBATCH --job-name=proto2_eps0035_no_rb_v2
#SBATCH --output=proto2_eps0035_no_rb_v2_%j.out
#SBATCH --error=proto2_eps0035_no_rb_v2%j.err
#SBATCH --partition=long
#SBATCH --nodes=6
#SBATCH --ntasks-per-node=24
@ -15,5 +15,5 @@ module purge
module load cmake gcc openmpi
#mpirun -n 144 ./poet dolo_fgcs_3.R dolo_fgcs_3.qs2 dolo_only_pqc
mpirun -n 144 ./poet --interp dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 proto2_eps01_3_rb
mpirun -n 144 ./poet --interp --rds dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 proto2_eps0035_no_rb_v2
#mpirun -n 144 ./poet --interp barite_fgcs_4_new/barite_fgcs_4_new_rt.R barite_fgcs_4_new/barite_fgcs_4_new.qs2 barite

Binary file not shown.

View File

@ -473,7 +473,7 @@ protected:
bool control_enabled{false};
bool stab_enabled{false};
std::unordered_set<uint32_t> ctrl_cell_ids;
std::vector<std::vector<double>> control_batch;
std::vector<std::vector<double>> ctrl_batch;
};
} // namespace poet

View File

@ -368,7 +368,7 @@ inline void poet::ChemistryModule::MasterRecvPkgs(worker_list_t &w_list,
std::vector<double> cell_output(
recv_buffer.begin() + this->prop_count * i,
recv_buffer.begin() + this->prop_count * (i + 1));
this->control_batch.push_back(std::move(cell_output));
this->ctrl_batch.push_back(std::move(cell_output));
}
break;
}
@ -443,13 +443,11 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
MPI_INT);
}
// if (control->shouldBcastFlags()) {
ftype = CHEM_CTRL_FLAGS;
PropagateFunctionType(ftype);
uint32_t ctrl_flags = buildCtrlFlags(this->dht_enabled, this->interp_enabled,
this->stab_enabled);
ChemBCast(&ctrl_flags, 1, MPI_INT);
//}
ftype = CHEM_WORK_LOOP;
PropagateFunctionType(ftype);
@ -467,8 +465,6 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count,
wp_sizes_vector.size());
// this->mpi_surr_buffer.resize(mpi_buffer.size());
/* setup local variables */
pkg_to_send = wp_sizes_vector.size();
pkg_to_recv = wp_sizes_vector.size();
@ -523,39 +519,37 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
/* do master stuff */
if (!this->control_batch.empty()) {
std::cout << "[Master] Processing " << this->control_batch.size()
if (!this->ctrl_batch.empty()) {
std::cout << "[Master] Processing " << this->ctrl_batch.size()
<< " control cells for comparison." << std::endl;
/* using mpi-buffer because we need cell-major layout*/
std::vector<std::vector<double>> surrogate_batch;
surrogate_batch.reserve(this->control_batch.size());
std::vector<std::vector<double>> sur_batch;
sur_batch.reserve(this->ctrl_batch.size());
for (const auto &element : this->control_batch) {
for (const auto &element : this->ctrl_batch) {
/* using mpi-buffer because we need cell-major layout*/
for (size_t i = 0; i < this->n_cells; i++) {
uint32_t curr_cell_id = mpi_buffer[this->prop_count * i];
if (curr_cell_id == element[0]) {
std::vector<double> surrogate_output(
std::vector<double> sur_output(
mpi_buffer.begin() + this->prop_count * i,
mpi_buffer.begin() + this->prop_count * (i + 1));
surrogate_batch.push_back(surrogate_output);
sur_batch.push_back(sur_output);
break;
}
}
}
metrics_a = MPI_Wtime();
control->computeErrorMetrics(this->control_batch, surrogate_batch,
control->computeErrorMetrics(this->ctrl_batch, sur_batch,
prop_names, n_cells);
control->writeErrorMetrics(ctrl_file_out_dir, prop_names);
metrics_b = MPI_Wtime();
this->metrics_t += metrics_b - metrics_a;
// Clear for next control iteration
this->control_batch.clear();
this->ctrl_batch.clear();
}
/* start time measurement of master chemistry */

View File

@ -7,23 +7,21 @@
poet::ControlModule::ControlModule(const ControlConfig &config_)
: config(config_) {}
void poet::ControlModule::beginIteration(ChemistryModule &chem,
const uint32_t &iter,
void poet::ControlModule::beginIteration(ChemistryModule &chem, uint32_t &iter,
const bool &dht_enabled,
const bool &interp_enabled) {
/* dht_enabled and interp_enabled are user settings set before starting the
* simulation */
global_iteration = iter;
double prep_a, prep_b;
prep_a = MPI_Wtime();
global_iteration = iter;
updateStabilizationPhase(chem, dht_enabled, interp_enabled);
prep_b = MPI_Wtime();
this->prep_t += prep_b - prep_a;
}
/* Disables dht and/or interp during stabilization phase */
void poet::ControlModule::updateStabilizationPhase(ChemistryModule &chem,
bool dht_enabled,
bool interp_enabled) {
@ -31,16 +29,15 @@ void poet::ControlModule::updateStabilizationPhase(ChemistryModule &chem,
if (disable_surr_counter > 0) {
--disable_surr_counter;
flush_request = false;
MSG("Rollback counter: " + std::to_string(disable_surr_counter));
std::cout << "Rollback counter: " << std::to_string(disable_surr_counter)
<< std::endl;
} else {
rollback_enabled = false;
MSG("Rollback stabilization complete, re-enabling surrogates");
std::cout << "Rollback stabilization complete, re-enabling surrogate."
<< std::endl;
}
}
bool prev_stab_state = chem.GetStabEnabled();
// user requested DHT/INTEP? keep them disabled but enable warmup-phase so
if (global_iteration <= config.stab_interval || rollback_enabled) {
chem.SetStabEnabled(true);
chem.SetDhtEnabled(false);
@ -50,27 +47,19 @@ void poet::ControlModule::updateStabilizationPhase(ChemistryModule &chem,
chem.SetDhtEnabled(dht_enabled);
chem.SetInterpEnabled(interp_enabled);
}
// Mark that we need to broadcast flags if stab state changed
if (prev_stab_state != chem.GetStabEnabled()) {
stab_phase_ended = true;
}
}
void poet::ControlModule::writeCheckpoint(DiffusionModule &diffusion,
uint32_t &iter,
const std::string &out_dir) {
double w_check_a, w_check_b;
w_check_a = MPI_Wtime();
if (global_iteration % config.checkpoint_interval == 0) {
double w_check_a = MPI_Wtime();
write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5",
{.field = diffusion.getField(), .iteration = iter});
double w_check_b = MPI_Wtime();
this->w_check_t += w_check_b - w_check_a;
last_checkpoint_written = iter;
}
w_check_b = MPI_Wtime();
this->w_check_t += w_check_b - w_check_a;
last_checkpoint_written = iter;
}
void poet::ControlModule::readCheckpoint(DiffusionModule &diffusion,
@ -90,18 +79,15 @@ void poet::ControlModule::readCheckpoint(DiffusionModule &diffusion,
}
void poet::ControlModule::writeErrorMetrics(
const std::string &out_dir, const std::vector<std::string> &species) {
double stats_a, stats_b;
uint32_t &iter, const std::string &out_dir,
const std::vector<std::string> &species) {
stats_a = MPI_Wtime();
double stats_a = MPI_Wtime();
writeSpeciesStatsToCSV(metrics_history, species, out_dir,
"species_overview.csv");
// writeCellStatsToCSV(cell_metrics_history, species, out_dir,
// "cell_overview.csv");
write_metrics(cell_metrics_history, species, out_dir,
"metrics_overview.hdf5");
stats_b = MPI_Wtime();
double stats_b = MPI_Wtime();
this->stats_t += stats_b - stats_a;
}
@ -109,19 +95,6 @@ uint32_t poet::ControlModule::getRollbackIter() {
uint32_t last_iter = ((global_iteration - 1) / config.checkpoint_interval) *
config.checkpoint_interval;
/*
uint32_t rollback_iter = (last_iter <= last_checkpoint_written)
? last_iter
: last_checkpoint_written;
MSG("getRollbackIter: global_iteration=" + std::to_string(global_iteration) +
", checkpoint_interval=" + std::to_string(config.checkpoint_interval) +
", last_checkpoint_written=" + std::to_string(last_checkpoint_written) +
", returning=" + std::to_string(last_checkpoint_written));
return last_checkpoint_written;
*/
return last_iter;
}
@ -129,55 +102,47 @@ std::optional<uint32_t> poet::ControlModule::getRollbackTarget(
const std::vector<std::string> &species) {
double r_check_a, r_check_b;
if (metrics_history.empty()) {
MSG("No error history yet, skipping rollback check.");
rollback_enabled = false;
return std::nullopt;
}
// Skip threshold checking if already in rollback/stabilization phase
if (rollback_enabled) {
/* Skip threshold checking if already in stabilization phase*/
if (metrics_history.empty() || rollback_enabled) {
return std::nullopt;
}
const auto &s_mape = metrics_history.back().mape;
const auto &s_hist = metrics_history.back();
/* skipping cell_id and id */
for (size_t sp_i = 2; sp_i < species.size(); sp_i++) {
// skip charge
if (s_mape[sp_i] == 0 || sp_i == 4) {
/* check bounds of threshold vector*/
if (sp_i >= config.mape_threshold.size()) {
std::cerr << "Warning: No threshold defined for species " << species[sp_i]
<< " at index " << std::to_string(sp_i) << std::endl;
continue;
}
if (s_mape[sp_i] > config.mape_threshold[sp_i]) {
if (last_checkpoint_written == 0) {
MSG(" Threshold exceeded but no checkpoint exists yet.");
return std::nullopt;
}
const auto &c_mape = cell_metrics_history.back().mape;
const auto &c_id = cell_metrics_history.back().id;
if (s_hist.mape[sp_i] > config.mape_threshold[sp_i]) {
const auto &c_hist = cell_metrics_history.back();
auto max_it = std::max_element(
c_mape.begin(), c_mape.end(),
c_hist.mape.begin(), c_hist.mape.end(),
[sp_i](const auto &a, const auto &b) { return a[sp_i] < b[sp_i]; });
size_t max_idx = std::distance(c_mape.begin(), max_it);
uint32_t cell_id = c_id[max_idx];
size_t max_idx = std::distance(c_hist.mape.begin(), max_it);
uint32_t cell_id = c_hist.id[max_idx];
double cell_mape = (*max_it)[sp_i];
rollback_enabled = true;
flush_request = true;
MSG("Threshold exceeded for " + species[sp_i] +
" with species-level MAPE = " + std::to_string(s_mape[sp_i]) +
" exceeding threshold = " +
std::to_string(config.mape_threshold[sp_i]) + ". Worst cell: ID=" +
std::to_string(cell_id) + " with MAPE=" + std::to_string(cell_mape));
std::cout << "Threshold exceeded for " << species[sp_i]
<< " with species-level MAPE = " << std::to_string(s_hist.mape[sp_i])
<< " exceeding threshold = "
<< std::to_string(config.mape_threshold[sp_i])
<< ". Worst cell: ID=" << std::to_string(cell_id)
<< ", species= " << species[sp_i]
<< " with MAPE=" << std::to_string(cell_mape) << std::endl;
return getRollbackIter();
}
}
rollback_enabled = false;
flush_request = false;
return std::nullopt;
}
@ -186,11 +151,6 @@ void poet::ControlModule::computeErrorMetrics(
std::vector<std::vector<double>> &surrogate_values,
const std::vector<std::string> &species, const uint32_t size_per_prop) {
// Skip metric computation if already in rollback/stabilization phase
if (rollback_enabled) {
return;
}
const uint32_t n_cells = reference_values.size();
const uint32_t n_species = species.size();
const double ZERO_ABS = config.zero_abs;
@ -256,7 +216,11 @@ void poet::ControlModule::processCheckpoint(
DiffusionModule &diffusion, uint32_t &current_iter,
const std::string &out_dir, const std::vector<std::string> &species) {
if (flush_request && rollback_count < 3) {
// Use max_rollbacks from config, default to 3 if not set
// uint32_t max_rollbacks =
// (config.max_rollbacks > 0) ? config.max_rollbacks : 3;
if (flush_request /* && rollback_count < 3 */) {
uint32_t target = getRollbackIter();
readCheckpoint(diffusion, current_iter, target, out_dir);
@ -264,21 +228,10 @@ void poet::ControlModule::processCheckpoint(
rollback_count++;
disable_surr_counter = config.stab_interval;
MSG("Restored checkpoint " + std::to_string(target) +
", surrogates disabled for " + std::to_string(config.stab_interval));
std::cout << "Restored checkpoint " << std::to_string(target)
<< ", surrogate disabled for "
<< std::to_string(config.stab_interval) << std::endl;
} else {
writeCheckpoint(diffusion, global_iteration, out_dir);
}
}
bool poet::ControlModule::shouldBcastFlags() {
if (global_iteration == 1) {
return true;
}
if (stab_phase_ended) {
return true;
}
return false;
}

View File

@ -19,7 +19,8 @@ class DiffusionModule;
struct ControlConfig {
uint32_t stab_interval = 0;
uint32_t checkpoint_interval = 0;
uint32_t checkpoint_interval = 0; // How often to write metrics files
//uint32_t max_rb = 0; // Maximum number of rollbacks allowed
double zero_abs = 0.0;
std::vector<double> mape_threshold;
};
@ -54,10 +55,10 @@ class ControlModule {
public:
explicit ControlModule(const ControlConfig &config);
void beginIteration(ChemistryModule &chem, const uint32_t &iter,
void beginIteration(ChemistryModule &chem, uint32_t &iter,
const bool &dht_enabled, const bool &interp_enaled);
void writeErrorMetrics(const std::string &out_dir,
void writeErrorMetrics(uint32_t &iter, const std::string &out_dir,
const std::vector<std::string> &species);
std::optional<uint32_t> getRollbackTarget();

View File

@ -254,8 +254,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
Rcpp::as<uint32_t>(global_rt_setup->operator[]("checkpoint_interval"));
params.stab_interval =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("stab_interval"));
params.zero_abs =
Rcpp::as<double>(global_rt_setup->operator[]("zero_abs"));
params.zero_abs = Rcpp::as<double>(global_rt_setup->operator[]("zero_abs"));
params.mape_threshold = Rcpp::as<std::vector<double>>(
global_rt_setup->operator[]("mape_threshold"));
params.ctrl_cell_ids = Rcpp::as<std::vector<uint32_t>>(
@ -415,9 +414,9 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" +
std::to_string(maxiter));
control.writeErrorMetrics(iter, params.out_dir, chem.getField().GetProps());
control.processCheckpoint(diffusion, iter, params.out_dir,
chem.getField().GetProps());
// MSG();
} // END SIMULATION LOOP
@ -434,13 +433,22 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
Rcpp::List diffusion_profiling;
diffusion_profiling["simtime"] = diffusion.getTransportTime();
if (params.use_dht) {
chem_profiling["dht_hits"] = Rcpp::wrap(chem.GetWorkerDHTHits());
chem_profiling["dht_evictions"] = Rcpp::wrap(chem.GetWorkerDHTEvictions());
chem_profiling["dht_get_time"] = Rcpp::wrap(chem.GetWorkerDHTGetTimings());
chem_profiling["dht_fill_time"] =
Rcpp::wrap(chem.GetWorkerDHTFillTimings());
}
Rcpp::List ctrl_profiling;
ctrl_profiling["compute_metrics_master"] = chem.GetMasterCtrlMetricsTime();
ctrl_profiling["unshuffle_field_master"] = chem.GetMasterUnshuffleTime();
ctrl_profiling["w_checkpoint_master"] = control.getWriteCheckpointTime();
ctrl_profiling["r_checkpoint_master"] = control.getReadCheckpointTime();
ctrl_profiling["write_stats"] = control.getWriteMetricsTime();
ctrl_profiling["ctrl_logic_master"] = control.getUpdateCtrlLogicTime();
ctrl_profiling["recv_data_master"] = chem.GetMasterRecvCtrlDataTime();
ctrl_profiling["worker"] = Rcpp::wrap(chem.GetWorkerControlTimings());
// if (params.use_dht) {
chem_profiling["dht_hits"] = Rcpp::wrap(chem.GetWorkerDHTHits());
chem_profiling["dht_evictions"] = Rcpp::wrap(chem.GetWorkerDHTEvictions());
chem_profiling["dht_get_time"] = Rcpp::wrap(chem.GetWorkerDHTGetTimings());
chem_profiling["dht_fill_time"] = Rcpp::wrap(chem.GetWorkerDHTFillTimings());
//}
if (params.use_interp) {
chem_profiling["interp_w"] =
@ -460,6 +468,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
profiling["simtime"] = dSimTime;
profiling["chemistry"] = chem_profiling;
profiling["diffusion"] = diffusion_profiling;
profiling["control"] = ctrl_profiling;
chem.MasterLoopBreak();

View File

@ -52,6 +52,7 @@ struct RuntimeParameters {
bool print_progress = false;
std::uint32_t stab_interval = 0;
std::uint32_t checkpoint_interval = 0;
std::uint32_t max_rb = 0;
double zero_abs = 0.0;
std::vector<double> mape_threshold;
std::vector<uint32_t> ctrl_cell_ids;

View File

@ -0,0 +1,452 @@
## Simple library of functions to assess and visualize the results of the coupled simulations
## Modified to work with .qs2 files (qs2 package format)
## Time-stamp: "Last modified 2025-11-25"
require(qs2) ## for reading .qs2 files
require(stringr)
# Note: RedModRphree, Rmufits, and Rcpp functions for DHT/PHT reading are kept
# but you'll need those packages only if you use ReadAllDHT/ReadAllPHT functions
## Directory containing this script; used further down to locate
## interpret_keys.cpp. `sys.frame(1)$ofile` is only available while the file
## is being source()d: otherwise the lookup either fails or yields NULL (and
## dirname(NULL) is an error), so fall back to the current working directory
## instead of aborting the whole load.
curdir <- tryCatch(dirname(sys.frame(1)$ofile),
                   error = function(e) path.expand("."))
print(paste("RFun_Eval_qs2.R is in ", curdir))
## ============================================================================
## NEW: Functions for reading .qs2 simulation outputs
## ============================================================================
## function which reads all simulation results in a given directory (.qs2 format)
ReadRTSims_qs2 <- function(dir) {
  ## Reads every file in `dir` whose name matches "iter_*.qs2" with
  ## qs2::qs_read and returns them as a named list, ordered by natural
  ## (numeric-aware) sorting of the names. Names are the file names without
  ## the ".qs2" extension. Returns NULL (with a warning) if nothing matches.
  pattern <- "^iter_.*\\.qs2$"
  files_full <- list.files(dir, pattern = pattern, full.names = TRUE)

  if (length(files_full) == 0) {
    warning(paste("No .qs2 files found in", dir, "with pattern", pattern))
    return(NULL)
  }

  res <- lapply(files_full, qs2::qs_read)
  ## Strip only the literal trailing extension: the previous unanchored,
  ## unescaped ".qs2" regex also removed any "<char>qs2" inside the name.
  names(res) <- sub("\\.qs2$", "", basename(files_full))
  return(res[str_sort(names(res), numeric = TRUE)])
}
## Read a single .qs2 file
ReadQS2 <- function(file) {
  ## Thin wrapper around qs2::qs_read that fails early, with the offending
  ## path in the message, when `file` does not exist.
  if (file.exists(file)) {
    return(qs2::qs_read(file))
  }
  stop(paste("File not found:", file))
}
## Extract chemistry field data from .qs2 iteration file
## Assumes structure similar to old .rds format with $C (chemistry) and $T (transport)
ExtractChemistry <- function(qs2_data) {
  ## Returns the `C` element when `qs2_data` has one. Otherwise the input is
  ## handed back unchanged: silently if it is a data frame, with a warning
  ## for anything else.
  has_chem <- "C" %in% names(qs2_data)
  if (has_chem) {
    return(qs2_data$C)
  }
  if (!is.data.frame(qs2_data)) {
    warning("Could not find chemistry data in expected format")
  }
  qs2_data
}
## Extract transport field data from .qs2 iteration file
ExtractTransport <- function(qs2_data) {
  ## Returns the `T` element of `qs2_data`, or NULL (with a warning) when the
  ## input has no element of that name.
  if (!("T" %in% names(qs2_data))) {
    warning("Could not find transport data in expected format")
    return(NULL)
  }
  qs2_data$T
}
## ============================================================================
## ORIGINAL: DHT/PHT reading functions (kept for surrogate analysis)
## ============================================================================
# Only load these if needed (requires Rcpp compilation)
## The key converters are compiled from interpret_keys.cpp, which must sit in
## the same directory as this script (`curdir`). When either Rcpp or that
## source file is missing, stubs are installed instead so that sourcing this
## file still succeeds; the stubs fail only when actually called.
if (requireNamespace("Rcpp", quietly = TRUE) && file.exists(paste0(curdir, "/interpret_keys.cpp"))) {
  library(Rcpp)
  sourceCpp(file = paste0(curdir, "/interpret_keys.cpp"))

  # Wrapper around previous sourced Rcpp function
  ConvertDHTKey <- function(value) {
    rcpp_key_convert(value)
  }

  ConvertToUInt64 <- function(double_data) {
    rcpp_uint64_convert(double_data)
  }
} else {
  ## Report the cause that actually applies; previously a missing
  ## interpret_keys.cpp was silent here and later blamed on Rcpp.
  if (!requireNamespace("Rcpp", quietly = TRUE)) {
    message("Note: Rcpp not available. DHT/PHT reading functions will not work.")
  } else {
    message("Note: interpret_keys.cpp not found in ", curdir,
            ". DHT/PHT reading functions will not work.")
  }

  # Create dummy functions so the rest of the script doesn't break
  ConvertDHTKey <- function(value) {
    stop("Cannot convert DHT keys: requires the Rcpp package and interpret_keys.cpp next to this script.")
  }

  ConvertToUInt64 <- function(double_data) {
    stop("Cannot convert to UInt64: requires the Rcpp package and interpret_keys.cpp next to this script.")
  }
}
## function which reads all successive DHT stored in a given directory
ReadAllDHT <- function(dir, new_scheme = TRUE) {
  ## Reads every "iter*.dht" file in `dir` through ReadDHT (forwarding
  ## `new_scheme`) and returns the results as a list named after the files,
  ## extension removed. Returns NULL with a warning when no file matches.
  dht_paths <- list.files(dir, pattern = "iter.*\\.dht$", full.names = TRUE)

  if (!length(dht_paths)) {
    warning(paste("No .dht files found in", dir))
    return(NULL)
  }

  tables <- lapply(dht_paths, ReadDHT, new_scheme = new_scheme)
  names(tables) <- sub("\\.dht$", "", basename(dht_paths))
  tables
}
## function which reads one .dht file and gives a matrix
ReadDHT <- function(file, new_scheme = TRUE) {
  ## Reads one binary .dht file and returns its content as a numeric matrix,
  ## one table entry per row.
  ##
  ## Layout as read here: two 4-byte integers giving the byte sizes of the
  ## key part and the data part of an entry, followed by the entries stored
  ## as 8-byte doubles. When `new_scheme` is TRUE the key columns are passed
  ## column-wise through ConvertDHTKey.
  conn <- file(file, "rb") ## open for reading in binary mode
  ## make sure the connection is released on every exit path, including the
  ## stop() calls below
  on.exit(close(conn), add = TRUE)

  if (!isSeekable(conn))
    stop("Connection not seekable")

  ## seek() returns the position held *before* the move: jump to the end,
  ## then back to the start, and the second call reports the file length
  seek(conn, where = 0, origin = "end")
  flen <- seek(conn, where = 0, origin = "start")

  ## we read the first 2 integers (4 bytes each) containing dimensions in bytes
  dims <- readBin(conn, what = "integer", n = 2)
  if (length(dims) != 2 || anyNA(dims) || sum(dims) <= 0)
    stop(paste("Malformed header in", file))

  ## compute dimensions of the data
  tots <- sum(dims)
  ncol <- tots / 8
  nrow <- (flen - 8) / tots ## 8 here is 2*sizeof("int")
  buff <- readBin(conn, what = "double", n = ncol * nrow)

  res <- matrix(buff, nrow = nrow, ncol = ncol, byrow = TRUE)

  if (new_scheme) {
    nkeys <- dims[1] / 8
    key_cols <- seq_len(nkeys)
    ## drop = FALSE keeps a matrix even for a single key column or a single
    ## row; otherwise apply() fails on the resulting plain vector
    keys <- res[, key_cols, drop = FALSE]
    res[, key_cols] <- apply(keys, 2, ConvertDHTKey)
  }

  return(res)
}
## function which reads all successive PHT stored in a given directory
ReadAllPHT <- function(dir, with_info = FALSE) {
  ## Reads every "iter*.pht" file in `dir` through ReadPHT (forwarding
  ## `with_info`) and returns the results as a list named after the files,
  ## extension removed. Returns NULL with a warning when no file matches.
  pht_paths <- list.files(dir, pattern = "iter.*\\.pht$", full.names = TRUE)

  if (!length(pht_paths)) {
    warning(paste("No .pht files found in", dir))
    return(NULL)
  }

  tables <- lapply(pht_paths, ReadPHT, with_info = with_info)
  names(tables) <- sub("\\.pht$", "", basename(pht_paths))
  tables
}
## function which reads one .pht file and gives a matrix
ReadPHT <- function(file, with_info = FALSE) {
  ## Reads one binary .pht file and returns a numeric matrix with one table
  ## entry per row. Columns are: the converted keys, the column directly
  ## after the keys (named `timesteps` here), the next column converted with
  ## ConvertToUInt64 (named `fill_rate` here) and, when `with_info` is TRUE,
  ## the last data column converted the same way (named `visit_count`).
  ##
  ## Layout as read here: two 4-byte integers giving the byte sizes of the
  ## key part and the data part of an entry, followed by the entries stored
  ## as 8-byte doubles.
  conn <- file(file, "rb") ## open for reading in binary mode
  ## make sure the connection is released on every exit path, including the
  ## stop() calls below
  on.exit(close(conn), add = TRUE)

  if (!isSeekable(conn))
    stop("Connection not seekable")

  ## seek() returns the position held *before* the move: jump to the end,
  ## then back to the start, and the second call reports the file length
  seek(conn, where = 0, origin = "end")
  flen <- seek(conn, where = 0, origin = "start")

  ## we read the first 2 integers (4 bytes each) containing dimensions in bytes
  dims <- readBin(conn, what = "integer", n = 2)
  if (length(dims) != 2 || anyNA(dims) || sum(dims) <= 0)
    stop(paste("Malformed header in", file))

  ## compute dimensions of the data
  tots <- sum(dims)
  ncol <- tots / 8
  nrow <- (flen - 8) / tots ## 8 here is 2*sizeof("int")
  buff <- readBin(conn, what = "double", n = ncol * nrow)

  res <- matrix(buff, nrow = nrow, ncol = ncol, byrow = TRUE)

  nkeys <- dims[1] / 8
  ndata <- dims[2] / 8

  ## drop = FALSE keeps a matrix even for a single key column or a single
  ## row; otherwise apply() fails on the resulting plain vector
  keys <- res[, seq_len(nkeys), drop = FALSE]
  timesteps <- res[, nkeys + 1]
  conv <- apply(keys, 2, ConvertDHTKey)

  fill_rate <- ConvertToUInt64(res[, nkeys + 2])

  buff <- c(conv, timesteps, fill_rate)

  if (with_info) {
    visit_count <- ConvertToUInt64(res[, nkeys + ndata])
    buff <- c(buff, visit_count)
  }

  ## the pieces were concatenated column after column, so refill column-wise
  res <- matrix(buff, nrow = nrow, byrow = FALSE)
  return(res)
}
## ============================================================================
## PLOTTING and ANALYSIS functions (work with both .rds and .qs2 data)
## ============================================================================
## Scatter plots of each variable in the iteration
PlotScatter <- function(sam1, sam2, which=NULL, labs=c("NO DHT", "DHT"), pch=".", cols=3, ...) {
  ## Draws one scatter plot per selected variable, `sam1` on the x axis against
  ## `sam2` on the y axis, with a grey 1:1 reference line, arranged in a grid
  ## `cols` panels wide. `which` selects variables by index or by column name;
  ## NULL selects all columns of `sam1`. Extra arguments go to plot().
  ## Returns NULL invisibly.

  ## An iteration object (non data frame carrying a "T" element) is reduced
  ## to its `C` element
  if (!is.data.frame(sam1) && "T" %in% names(sam1)) {
    sam1 <- sam1$C
  }
  if (!is.data.frame(sam2) && "T" %in% names(sam2)) {
    sam2 <- sam2$C
  }

  var_names <- colnames(sam1)
  if (is.numeric(which)) {
    inds <- which
  } else if (is.character(which)) {
    inds <- match(which, var_names)
  } else if (is.null(which)) {
    inds <- seq_along(var_names)
  }

  n_rows <- ceiling(length(inds) / cols)
  par(mfrow = c(n_rows, cols))

  for (idx in inds) {
    plot(sam1[, idx], sam2[, idx], main = var_names[idx], xlab = labs[1], ylab = labs[2], pch = pch, col = "red", ...)
    abline(0, 1, col = "grey", lwd = 1.5)
  }

  invisible()
}
##### Some metrics for relative comparison
## Root Mean Square Error
RMSE <- function(pred, obs) {
  ## Square root of the mean squared difference between `pred` and `obs`;
  ## NA differences are ignored.
  sq_err <- (pred - obs)^2
  sqrt(mean(sq_err, na.rm = TRUE))
}
## Using range as norm
RranRMSE <- function(pred, obs) {
  ## RMSE divided by the absolute range (max - min) of `pred`; NAs ignored.
  rmse <- sqrt(mean((pred - obs)^2, na.rm = TRUE))
  pred_span <- max(pred, na.rm = TRUE) - min(pred, na.rm = TRUE)
  rmse / abs(pred_span)
}
## Using max val as norm
RmaxRMSE <- function(pred, obs) {
  ## RMSE divided by the absolute value of the largest `pred`; NAs ignored.
  rmse <- sqrt(mean((pred - obs)^2, na.rm = TRUE))
  pred_max <- max(pred, na.rm = TRUE)
  rmse / abs(pred_max)
}
## Using sd as norm
RsdRMSE <- function(pred, obs) {
  ## RMSE divided by the standard deviation of `pred`; NAs ignored.
  rmse <- sqrt(mean((pred - obs)^2, na.rm = TRUE))
  pred_sd <- sd(pred, na.rm = TRUE)
  rmse / pred_sd
}
## Using mean as norm
RmeanRMSE <- function(pred, obs) {
  ## RMSE divided by the mean of `pred`; NAs ignored.
  rmse <- sqrt(mean((pred - obs)^2, na.rm = TRUE))
  pred_mean <- mean(pred, na.rm = TRUE)
  rmse / pred_mean
}
## Mean absolute error, using max val of pred as norm
RAEmax <- function(pred, obs) {
  ## Mean absolute difference between `pred` and `obs`, divided by the
  ## largest `pred`; NAs ignored.
  mean_abs_err <- mean(abs(pred - obs), na.rm = TRUE)
  pred_max <- max(pred, na.rm = TRUE)
  mean_abs_err / pred_max
}
## Max absolute error
MAE <- function(pred, obs) {
  ## Largest absolute difference between `pred` and `obs`; NAs ignored.
  ## Note: despite the name this is the maximum, not the mean, absolute error.
  abs_err <- abs(pred - obs)
  max(abs_err, na.rm = TRUE)
}
## Mean Absolute Percentage Error
MAPE <- function(pred, obs) {
  ## Mean of |obs - pred| / |obs| expressed in percent; NA terms are ignored.
  pct_err <- abs((obs - pred) / obs) * 100
  mean(pct_err, na.rm = TRUE)
}
## workhorse function for ComputeErrors and its use with mapply
AppliedFun <- function(a, b, .fun) {
  ## Applies `.fun` pairwise to the matching columns/elements of `a` and `b`
  ## via mapply. An argument that is not a data frame but carries a `C`
  ## element is replaced by that element first.
  unwrap <- function(x) {
    if (!is.data.frame(x) && "C" %in% names(x)) x$C else x
  }
  lhs <- as.list(unwrap(a))
  rhs <- as.list(unwrap(b))
  mapply(.fun, lhs, rhs)
}
## Compute the diffs between two simulation, iter by iter,
## with a given metric (passed in form of function name to this function)
ComputeErrors <- function(sim1, sim2, FUN=RMSE) {
  ## Compares two simulations iteration by iteration with the metric `FUN`
  ## (a function, or the name of one, taking `pred` and `obs`).
  ##
  ## Returns the transposed mapply() result of AppliedFun over the paired
  ## iterations (one row per iteration when every iteration yields the same
  ## variables). If the simulations differ in length, both are cut to the
  ## shorter one. Returns NULL with a warning when there is nothing to
  ## compare.
  n_common <- min(length(sim1), length(sim2))

  if (length(sim1) != length(sim2)) {
    cat("The simulations do not have the same length, subsetting to the shortest\n")
    ## seq_len() instead of 1:n so that n == 0 selects nothing rather than
    ## indexing with c(1, 0)
    sim1 <- sim1[seq_len(n_common)]
    sim2 <- sim2[seq_len(n_common)]
  }

  ## An empty (or NULL, e.g. from a reader that found no files) simulation
  ## used to surface as an obscure mapply()/t() error.
  if (n_common == 0) {
    warning("No iterations to compare: at least one simulation is empty")
    return(NULL)
  }

  ## match.fun() itself raises an error for anything that is not a function
  FUN <- match.fun(FUN)

  t(mapply(AppliedFun, sim1, sim2, MoreArgs = list(.fun = FUN)))
}
## Function to display the error progress between 2 simulations
ErrorProgress <- function(mat, ignore, colors, metric, ...) {
  ## Line plot of each column of `mat` against the row index, with the
  ## column names written in the right margin at the height of the last row.
  ##   ignore  optional character vector of column names to leave out
  ##   colors  one colour per column of `mat` (random rainbow if missing)
  ##   metric  y-axis label ("Metric" if missing)
  ##   ...     passed on to matplot()

  ## --- input validation ---------------------------------------------------
  if (is.null(mat))
    stop("Cannot plot: matrix is NULL")
  if (is.vector(mat))
    stop("Cannot plot: input is a vector (need at least 2 columns). Check that your data has multiple columns.")
  if (is.data.frame(mat))
    mat <- as.matrix(mat)
  if (nrow(mat) == 0 || ncol(mat) == 0)
    stop("Cannot plot: matrix is empty")

  ## --- defaults -----------------------------------------------------------
  if (missing(colors)) {
    colors <- sample(rainbow(ncol(mat)))
  }
  if (missing(metric)) {
    metric <- "Metric"
  }

  ## --- drop the ignored columns together with their colours ----------------
  if (!missing(ignore)) {
    hits <- match(ignore, colnames(mat))
    drop_idx <- hits[!is.na(hits)] # names not present in mat are skipped
    if (length(drop_idx) > 0) {
      mat <- mat[, -drop_idx, drop = FALSE]
      colors <- colors[-drop_idx]
    }
  }

  ## label positions: value of every series in the last row
  label_y <- mat[nrow(mat), ]

  par(mar = c(5, 4, 2, 8))
  matplot(mat, type = "l", lty = 1, lwd = 2, col = colors, xlab = "iteration", ylab = metric, ...)
  mtext(colnames(mat), side = 4, line = 0.5, outer = FALSE, at = label_y, adj = 0, col = colors, las = 2, cex = 0.7)
}
## Function which exports all simulations to ParaView's .vtu
## Requires package RcppVTK
ExportToParaview <- function(vtu, nameout, results) {
  ## Writes one .vtu file per element of `results`, named
  ## "<nameout>.<4-digit step>.vtu". For every step the template `vtu` is
  ## first copied to the target name and then ExportMatrixToVTU (RcppVTK) is
  ## called with the step's matrix and its column names.
  ## Returns, invisibly, the value of the last ExportMatrixToVTU call, or
  ## NULL when `results` is empty.
  if (!requireNamespace("RcppVTK", quietly = TRUE)) {
    stop("Package RcppVTK is required for this function")
  }
  require(RcppVTK)

  ## strip eventually present ".vtu" from nameout
  nameout <- sub(".vtu", "", nameout, fixed=TRUE)
  namesteps <- paste0(nameout, ".", sprintf("%04d", seq_along(results)), ".vtu")

  ## initialised so that an empty `results` returns cleanly instead of
  ## failing on an undefined `ret` (or on results[[1]])
  ret <- NULL
  for (step in seq_along(results)) {
    file.copy(from=vtu, to=namesteps[step], overwrite = TRUE)
    cat(paste("Saving step ", step, " in file ", namesteps[step], "\n"))
    ret <- ExportMatrixToVTU (fin=vtu, fout=namesteps[step], names=colnames(results[[step]]), mat=results[[step]])
  }
  invisible(ret)
}
## Version of Rmufits::PlotCartCellData with the ability to fix the
## "breaks" for color coding of 2D simulations
Plot2DCellData <- function (data, grid, nx, ny, contour = TRUE,
nlevels = 12, breaks, palette = "heat.colors",
rev.palette = TRUE, scale = TRUE, plot.axes=TRUE, ...) {
## Draws `data` (one value per cell) as a colour-coded nx-by-ny image,
## optionally with contour lines and a colour scale panel on the right.
##   data         numeric vector with one value per cell
##   grid         optional; when given, nx/ny and the axis coordinates are
##                taken from grid$cell$XCOORD / grid$cell$YCOORD
##   nx, ny       grid dimensions, used when `grid` is missing
##   contour      add contour lines on top of the image
##   nlevels      target number of classes when `breaks` is missing
##   breaks       class boundaries for the colour coding; fixing them keeps
##                colours comparable across several plots
##   palette      name of a function called with `n` that returns colours
##   rev.palette  reverse the colour order
##   scale        reserve a second panel for the colour scale
##   plot.axes    draw axes and annotation
##   ...          passed on to image()
## Returns, invisibly, the matrix handed to image().
if (!missing(grid)) {
xc <- unique(sort(grid$cell$XCOORD))
yc <- unique(sort(grid$cell$YCOORD))
nx <- length(xc)
ny <- length(yc)
if (!length(data) == nx * ny)
stop("Wrong nx, ny or grid")
} else {
## no grid: use plain cell indices as coordinates
xc <- seq(1, nx)
yc <- seq(1, ny)
}
## fill row by row (ny rows of nx values), then reverse the row order and
## transpose to obtain the matrix passed to image()
z <- matrix(round(data, 6), ncol = nx, nrow = ny, byrow = TRUE)
pp <- t(z[rev(seq(1, nrow(z))), ])
if (missing(breaks)) {
breaks <- pretty(data, n = nlevels)
}
## image() needs exactly one colour fewer than there are breaks
breakslen <- length(breaks)
colors <- do.call(palette, list(n = breakslen - 1))
if (rev.palette)
colors <- rev(colors)
if (scale) {
## two panels side by side with relative widths 4:1
par(mfrow = c(1, 2))
nf <- layout(matrix(c(1, 2), 1, 2, byrow = TRUE), widths = c(4,
1))
}
par(las = 1, mar = c(5, 5, 3, 1))
image(xc, yc, pp, xlab = "X [m]", ylab = "Y[m]", las = 1, asp = 1,
breaks = breaks, col = colors, axes = FALSE, ann=plot.axes,
...)
if (plot.axes) {
axis(1)
axis(2)
}
## NOTE(review): contour() selects its lines through `levels`; `breaks` is
## presumably swallowed by `...` here - confirm the intended argument
if (contour)
contour(unique(sort(xc)), unique(sort(yc)), pp, breaks = breaks,
add = TRUE)
if (scale) {
par(las = 1, mar = c(5, 1, 5, 5))
## the colour scale itself is only drawn when Rmufits is installed
if (requireNamespace("Rmufits", quietly = TRUE)) {
Rmufits::PlotImageScale(data, breaks = breaks, add.axis = FALSE,
axis.pos = 4, col = colors)
}
## NOTE(review): without Rmufits no new plot is started before this call,
## so the axis presumably lands on the image panel - verify
axis(4, at = breaks)
}
invisible(pp)
}
PlotAsMP4 <- function(data, nx, ny, to_plot, out_dir, name,
                      contour = FALSE, scale = FALSE, framerate = 30) {
  ## Renders the variable `to_plot` of every iteration in `data` to a PNG
  ## frame in `out_dir` and joins the frames into "<out_dir>/<name>.mp4" by
  ## calling the external `ffmpeg` program.
  ##   data       named list of iterations (each either carrying a `C`
  ##              element or being the data itself), ordered by natural
  ##              sorting of the names
  ##   nx, ny     grid dimensions forwarded to the plotting function
  ##   to_plot    name of the variable to draw
  ##   framerate  frames per second handed to ffmpeg
  ## Stops if `data` is empty; warns if ffmpeg returns a non-zero status.
  sort_data <- data[str_sort(names(data), numeric = TRUE)]
  plot_data <- lapply(sort_data, function(x) {
    if (!is.data.frame(x) && "C" %in% names(x)) {
      return(x$C[[to_plot]])
    } else {
      return(x[[to_plot]])
    }
  })

  n_frames <- length(plot_data)
  ## log10(0) is -Inf and 1:0 counts backwards, so bail out before any
  ## device or file is created
  if (n_frames == 0) {
    stop("Nothing to plot: 'data' contains no iterations")
  }
  ## minimum field width of the frame counter; at least 1 so the format
  ## stays well-formed for a single frame (file names are unchanged)
  pad_size <- max(1, ceiling(log10(n_frames)))

  dir.create(out_dir, showWarnings = FALSE)
  output_files <- paste0(out_dir, "/", name, "_%0", pad_size, "d.png")
  output_mp4 <- paste0(out_dir, "/", name, ".mp4")

  png(output_files,
    width = 297, height = 210, units = "mm",
    res = 100
  )
  ## the device must be closed before ffmpeg reads the frames, and must not
  ## stay open if one of the plots fails
  tryCatch(
    for (frame in plot_data) {
      if (requireNamespace("Rmufits", quietly = TRUE)) {
        Rmufits::PlotCartCellData(frame, nx = nx, ny = ny, contour = contour, scale = scale)
      } else {
        Plot2DCellData(frame, nx = nx, ny = ny, contour = contour, scale = scale)
      }
    },
    finally = dev.off()
  )

  ## quote the paths so directories containing spaces or shell
  ## metacharacters survive the trip through the shell
  ffmpeg_command <- paste(
    "ffmpeg -y -framerate", framerate, "-i", shQuote(output_files),
    "-c:v libx264 -crf 22", shQuote(output_mp4)
  )

  status <- system(ffmpeg_command)
  if (!identical(as.integer(status), 0L)) {
    warning(paste("ffmpeg exited with status", status, "- video may be missing:", output_mp4))
    return(invisible(NULL))
  }
  message(paste("Created video:", output_mp4))
}
## Banner printed once the whole file has been evaluated
cat(
  "\n=== RFun_Eval_qs2.R loaded successfully ===\n",
  "New functions for .qs2 files:\n",
  " - ReadRTSims_qs2(dir) : Read all iteration .qs2 files\n",
  " - ReadQS2(file) : Read single .qs2 file\n",
  "All other functions work as before!\n\n",
  sep = ""
)

View File

@ -0,0 +1,50 @@
# Evaluation helpers (ReadRTSims, ComputeErrors, ErrorProgress, PlotScatter, ...)
source("/mnt/beegfs/home/rastogi/poet/util/data_evaluation/RFun_Eval.R")

# Directory that holds the simulation output folders
base_dir <- "/mnt/beegfs/home/rastogi/poet/bin"

# The two runs to compare
sim1 <- ReadRTSims(file.path(base_dir, "proto2_eps01_no_rb_v2"))
sim2 <- ReadRTSims(file.path(base_dir, "proto2_eps0035_no_rb_v2"))

# ----------------------------------------
# Per-iteration error metrics between the runs
# ----------------------------------------
rmse_errors <- ComputeErrors(sim1, sim2, FUN = RMSE)
mape_errors <- ComputeErrors(sim1, sim2, FUN = MAPE)

# Short summary on the console
cat("RMSE errors computed for", ncol(rmse_errors), "variables\n")
cat("MAPE errors computed for", ncol(mape_errors), "variables\n")
cat("Number of iterations compared:", nrow(rmse_errors), "\n\n")

# ----------------------------------------
# All plots go into a single PDF next to the simulation folders
# ----------------------------------------
output_pdf <- file.path(base_dir, "comparison_plots.pdf")
cat("Saving plots to:", output_pdf, "\n")
pdf(output_pdf, width = 10, height = 6)

# RMSE of every variable over the iterations
cat("Creating ErrorProgress plot...\n")
ErrorProgress(rmse_errors, ignore = c("Charge"), metric = "RMSE")

# Scatter plot of iteration 10, only if both runs got that far
have_iter_10 <- length(sim1) >= 10 && length(sim2) >= 10
if (!have_iter_10) {
  cat("Not enough iterations for scatter plot\n")
} else {
  cat("Creating scatter plot for iteration 10...\n")
  PlotScatter(sim1[[10]], sim2[[10]],
    labs = c("Proto2 Eps=0.01%", "Proto2 Eps=0.0035%"),
    which = c("Ca", "Mg", "Cl", "C")
  )
}

dev.off()
cat("Plots saved successfully to:", output_pdf, "\n")

# ----------------------------------------
# DHT/PHT analysis (if you have snapshot files)
# ----------------------------------------
#dht_snaps <- ReadAllDHT("poet/bin/proto1_only_interp")
#pht_snaps <- ReadAllPHT("poet/bin/proto1_only_interp")