diff --git a/share/poet/barite/barite_het.qs2 b/share/poet/barite/barite_het.qs2 index c5b4ec46b..1534d8aa0 100644 Binary files a/share/poet/barite/barite_het.qs2 and b/share/poet/barite/barite_het.qs2 differ diff --git a/share/poet/surfex/PoetEGU_surfex_500.qs2 b/share/poet/surfex/PoetEGU_surfex_500.qs2 index 0f5ecd578..19c7da936 100644 Binary files a/share/poet/surfex/PoetEGU_surfex_500.qs2 and b/share/poet/surfex/PoetEGU_surfex_500.qs2 differ diff --git a/src/Chemistry/ChemistryModule.hpp b/src/Chemistry/ChemistryModule.hpp index 2f1cae78e..63e7d656c 100644 --- a/src/Chemistry/ChemistryModule.hpp +++ b/src/Chemistry/ChemistryModule.hpp @@ -473,7 +473,7 @@ protected: bool control_enabled{false}; bool stab_enabled{false}; std::unordered_set ctrl_cell_ids; - std::vector> control_batch; + std::vector> ctrl_batch; }; } // namespace poet diff --git a/src/Chemistry/MasterFunctions.cpp b/src/Chemistry/MasterFunctions.cpp index bc9a00052..c7846ed3c 100644 --- a/src/Chemistry/MasterFunctions.cpp +++ b/src/Chemistry/MasterFunctions.cpp @@ -368,7 +368,7 @@ inline void poet::ChemistryModule::MasterRecvPkgs(worker_list_t &w_list, std::vector cell_output( recv_buffer.begin() + this->prop_count * i, recv_buffer.begin() + this->prop_count * (i + 1)); - this->control_batch.push_back(std::move(cell_output)); + this->ctrl_batch.push_back(std::move(cell_output)); } break; } @@ -443,13 +443,11 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { MPI_INT); } - // if (control->shouldBcastFlags()) { ftype = CHEM_CTRL_FLAGS; PropagateFunctionType(ftype); uint32_t ctrl_flags = buildCtrlFlags(this->dht_enabled, this->interp_enabled, this->stab_enabled); ChemBCast(&ctrl_flags, 1, MPI_INT); - //} ftype = CHEM_WORK_LOOP; PropagateFunctionType(ftype); @@ -467,8 +465,6 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count, wp_sizes_vector.size()); - // this->mpi_surr_buffer.resize(mpi_buffer.size()); - /* setup local variables */ pkg_to_send = wp_sizes_vector.size(); pkg_to_recv = wp_sizes_vector.size(); @@ -523,39 +519,37 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { /* do master stuff */ - if (!this->control_batch.empty()) { - std::cout << "[Master] Processing " << this->control_batch.size() + if (!this->ctrl_batch.empty()) { + std::cout << "[Master] Processing " << this->ctrl_batch.size() << " control cells for comparison." << std::endl; - /* using mpi-buffer because we need cell-major layout*/ - std::vector> surrogate_batch; - surrogate_batch.reserve(this->control_batch.size()); + + std::vector> sur_batch; + sur_batch.reserve(this->ctrl_batch.size()); - for (const auto &element : this->control_batch) { + for (const auto &element : this->ctrl_batch) { + /* using mpi-buffer because we need cell-major layout*/ for (size_t i = 0; i < this->n_cells; i++) { uint32_t curr_cell_id = mpi_buffer[this->prop_count * i]; if (curr_cell_id == element[0]) { - std::vector surrogate_output( + std::vector sur_output( mpi_buffer.begin() + this->prop_count * i, mpi_buffer.begin() + this->prop_count * (i + 1)); - surrogate_batch.push_back(surrogate_output); + sur_batch.push_back(sur_output); break; } } } metrics_a = MPI_Wtime(); - control->computeErrorMetrics(this->control_batch, surrogate_batch, + control->computeErrorMetrics(this->ctrl_batch, sur_batch, prop_names, n_cells); - control->writeErrorMetrics(ctrl_file_out_dir, prop_names); - metrics_b = MPI_Wtime(); this->metrics_t += metrics_b - metrics_a; - // Clear for next control iteration - this->control_batch.clear(); + this->ctrl_batch.clear(); } /* start time measurement of master chemistry */ diff --git a/src/Control/ControlModule.cpp b/src/Control/ControlModule.cpp index f50d2300a..41d24eb8d 100644 --- a/src/Control/ControlModule.cpp +++ b/src/Control/ControlModule.cpp @@ -7,23 +7,21 @@ poet::ControlModule::ControlModule(const ControlConfig &config_) : config(config_) {} -void poet::ControlModule::beginIteration(ChemistryModule &chem, - const uint32_t &iter, +void poet::ControlModule::beginIteration(ChemistryModule &chem, uint32_t &iter, const bool &dht_enabled, const bool &interp_enabled) { - /* dht_enabled and inter_enabled are user settings set before startig the - * simulation*/ + global_iteration = iter; double prep_a, prep_b; prep_a = MPI_Wtime(); - global_iteration = iter; updateStabilizationPhase(chem, dht_enabled, interp_enabled); prep_b = MPI_Wtime(); this->prep_t += prep_b - prep_a; } +/* Disables dht and/or interp during stabilzation phase */ void poet::ControlModule::updateStabilizationPhase(ChemistryModule &chem, bool dht_enabled, bool interp_enabled) { @@ -31,16 +29,15 @@ void poet::ControlModule::updateStabilizationPhase(ChemistryModule &chem, if (disable_surr_counter > 0) { --disable_surr_counter; flush_request = false; - MSG("Rollback counter: " + std::to_string(disable_surr_counter)); + std::cout << "Rollback counter: " << std::to_string(disable_surr_counter) + << std::endl; } else { rollback_enabled = false; - MSG("Rollback stabilization complete, re-enabling surrogates"); + std::cout << "Rollback stabilization complete, re-enabling surrogate." + << std::endl; } } - bool prev_stab_state = chem.GetStabEnabled(); - - // user requested DHT/INTEP? keep them disabled but enable warmup-phase so if (global_iteration <= config.stab_interval || rollback_enabled) { chem.SetStabEnabled(true); chem.SetDhtEnabled(false); @@ -50,27 +47,19 @@ void poet::ControlModule::updateStabilizationPhase(ChemistryModule &chem, chem.SetDhtEnabled(dht_enabled); chem.SetInterpEnabled(interp_enabled); } - - // Mark that we need to broadcast flags if stab state changed - if (prev_stab_state != chem.GetStabEnabled()) { - stab_phase_ended = true; - } } void poet::ControlModule::writeCheckpoint(DiffusionModule &diffusion, uint32_t &iter, const std::string &out_dir) { - double w_check_a, w_check_b; - - w_check_a = MPI_Wtime(); if (global_iteration % config.checkpoint_interval == 0) { + double w_check_a = MPI_Wtime(); write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5", {.field = diffusion.getField(), .iteration = iter}); + double w_check_b = MPI_Wtime(); + this->w_check_t += w_check_b - w_check_a; + last_checkpoint_written = iter; } - w_check_b = MPI_Wtime(); - this->w_check_t += w_check_b - w_check_a; - - last_checkpoint_written = iter; } void poet::ControlModule::readCheckpoint(DiffusionModule &diffusion, @@ -90,18 +79,15 @@ void poet::ControlModule::readCheckpoint(DiffusionModule &diffusion, } void poet::ControlModule::writeErrorMetrics( - const std::string &out_dir, const std::vector &species) { - double stats_a, stats_b; + uint32_t &iter, const std::string &out_dir, + const std::vector &species) { - stats_a = MPI_Wtime(); + double stats_a = MPI_Wtime(); writeSpeciesStatsToCSV(metrics_history, species, out_dir, "species_overview.csv"); - // writeCellStatsToCSV(cell_metrics_history, species, out_dir, - // "cell_overview.csv"); write_metrics(cell_metrics_history, species, out_dir, "metrics_overview.hdf5"); - stats_b = MPI_Wtime(); - + double stats_b = MPI_Wtime(); this->stats_t += stats_b - stats_a; } @@ -109,19 +95,6 @@ uint32_t poet::ControlModule::getRollbackIter() { uint32_t last_iter = ((global_iteration - 1) / config.checkpoint_interval) * config.checkpoint_interval; - - /* - uint32_t rollback_iter = (last_iter <= last_checkpoint_written) - ? last_iter - : last_checkpoint_written; - - MSG("getRollbackIter: global_iteration=" + std::to_string(global_iteration) + - ", checkpoint_interval=" + std::to_string(config.checkpoint_interval) + - ", last_checkpoint_written=" + std::to_string(last_checkpoint_written) + - ", returning=" + std::to_string(last_checkpoint_written)); - - return last_checkpoint_written; - */ return last_iter; } @@ -129,55 +102,47 @@ std::optional poet::ControlModule::getRollbackTarget( const std::vector &species) { double r_check_a, r_check_b; - if (metrics_history.empty()) { - MSG("No error history yet, skipping rollback check."); - rollback_enabled = false; - return std::nullopt; - } - // Skip threshold checking if already in rollback/stabilization phase - if (rollback_enabled) { + /* Skip threshold checking if already in stabilization phase*/ + if (metrics_history.empty() || rollback_enabled) { return std::nullopt; } - const auto &s_mape = metrics_history.back().mape; + const auto &s_hist = metrics_history.back(); + /* skipping cell_id and id */ for (size_t sp_i = 2; sp_i < species.size(); sp_i++) { - // skip charge - if (s_mape[sp_i] == 0 || sp_i == 4) { + /* check bounds of threshold vector*/ + if (sp_i >= config.mape_threshold.size()) { + std::cerr << "Warning: No threshold defined for species " << species[sp_i] + << " at index " << std::to_string(sp_i) << std::endl; continue; } - if (s_mape[sp_i] > config.mape_threshold[sp_i]) { - if (last_checkpoint_written == 0) { - MSG(" Threshold exceeded but no checkpoint exists yet."); - return std::nullopt; - } - - const auto &c_mape = cell_metrics_history.back().mape; - const auto &c_id = cell_metrics_history.back().id; + if (s_hist.mape[sp_i] > config.mape_threshold[sp_i]) { + const auto &c_hist = cell_metrics_history.back(); auto max_it = std::max_element( - c_mape.begin(), c_mape.end(), + c_hist.mape.begin(), c_hist.mape.end(), [sp_i](const auto &a, const auto &b) { return a[sp_i] < b[sp_i]; }); - size_t max_idx = std::distance(c_mape.begin(), max_it); - uint32_t cell_id = c_id[max_idx]; + size_t max_idx = std::distance(c_hist.mape.begin(), max_it); + uint32_t cell_id = c_hist.id[max_idx]; double cell_mape = (*max_it)[sp_i]; rollback_enabled = true; flush_request = true; - MSG("Threshold exceeded for " + species[sp_i] + - " with species-level MAPE = " + std::to_string(s_mape[sp_i]) + - " exceeding threshold = " + - std::to_string(config.mape_threshold[sp_i]) + ". Worst cell: ID=" + - std::to_string(cell_id) + " with MAPE=" + std::to_string(cell_mape)); + std::cout << "Threshold exceeded for " << species[sp_i] + << " with species-level MAPE = " << std::to_string(s_hist.mape[sp_i]) + << " exceeding threshold = " + << std::to_string(config.mape_threshold[sp_i]) + << ". Worst cell: ID=" << std::to_string(cell_id) + << ", species= " << species[sp_i] + << " with MAPE=" << std::to_string(cell_mape) << std::endl; return getRollbackIter(); } } - rollback_enabled = false; - flush_request = false; return std::nullopt; } @@ -186,11 +151,6 @@ void poet::ControlModule::computeErrorMetrics( std::vector> &surrogate_values, const std::vector &species, const uint32_t size_per_prop) { - // Skip metric computation if already in rollback/stabilization phase - if (rollback_enabled) { - return; - } - const uint32_t n_cells = reference_values.size(); const uint32_t n_species = species.size(); const double ZERO_ABS = config.zero_abs; @@ -256,7 +216,11 @@ void poet::ControlModule::processCheckpoint( DiffusionModule &diffusion, uint32_t ¤t_iter, const std::string &out_dir, const std::vector &species) { - if (flush_request && rollback_count < 3) { + // Use max_rollbacks from config, default to 3 if not set + // uint32_t max_rollbacks = + // (config.max_rollbacks > 0) ? config.max_rollbacks : 3; + + if (flush_request /* && rollback_count < 3 */) { uint32_t target = getRollbackIter(); readCheckpoint(diffusion, current_iter, target, out_dir); @@ -264,21 +228,10 @@ void poet::ControlModule::processCheckpoint( rollback_count++; disable_surr_counter = config.stab_interval; - MSG("Restored checkpoint " + std::to_string(target) + - ", surrogates disabled for " + std::to_string(config.stab_interval)); + std::cout << "Restored checkpoint " << std::to_string(target) + << ", surrogate disabled for " + << std::to_string(config.stab_interval) << std::endl; } else { writeCheckpoint(diffusion, global_iteration, out_dir); } -} - -bool poet::ControlModule::shouldBcastFlags() { - if (global_iteration == 1) { - return true; - } - - if (stab_phase_ended) { - return true; - } - - return false; } \ No newline at end of file diff --git a/src/Control/ControlModule.hpp b/src/Control/ControlModule.hpp index a9ca0d907..209ab9a8e 100644 --- a/src/Control/ControlModule.hpp +++ b/src/Control/ControlModule.hpp @@ -19,7 +19,8 @@ class DiffusionModule; struct ControlConfig { uint32_t stab_interval = 0; - uint32_t checkpoint_interval = 0; + uint32_t checkpoint_interval = 0; // How often to write metrics files + //uint32_t max_rb = 0; // Maximum number of rollbacks allowed double zero_abs = 0.0; std::vector mape_threshold; }; @@ -54,10 +55,10 @@ class ControlModule { public: explicit ControlModule(const ControlConfig &config); - void beginIteration(ChemistryModule &chem, const uint32_t &iter, + void beginIteration(ChemistryModule &chem, uint32_t &iter, const bool &dht_enabled, const bool &interp_enaled); - void writeErrorMetrics(const std::string &out_dir, + void writeErrorMetrics(uint32_t &iter, const std::string &out_dir, const std::vector &species); std::optional getRollbackTarget(); diff --git a/src/poet.cpp b/src/poet.cpp index 5740c0f0b..51a9a4f8c 100644 --- a/src/poet.cpp +++ b/src/poet.cpp @@ -254,8 +254,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) { Rcpp::as(global_rt_setup->operator[]("checkpoint_interval")); params.stab_interval = Rcpp::as(global_rt_setup->operator[]("stab_interval")); - params.zero_abs = - Rcpp::as(global_rt_setup->operator[]("zero_abs")); + params.zero_abs = Rcpp::as(global_rt_setup->operator[]("zero_abs")); params.mape_threshold = Rcpp::as>( global_rt_setup->operator[]("mape_threshold")); params.ctrl_cell_ids = Rcpp::as>( @@ -415,9 +414,9 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, MSG("End of *coupling* iteration " + std::to_string(iter) + "/" + std::to_string(maxiter)); + control.writeErrorMetrics(iter, params.out_dir, chem.getField().GetProps()); control.processCheckpoint(diffusion, iter, params.out_dir, chem.getField().GetProps()); - // MSG(); } // END SIMULATION LOOP @@ -434,13 +433,22 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, Rcpp::List diffusion_profiling; diffusion_profiling["simtime"] = diffusion.getTransportTime(); - if (params.use_dht) { - chem_profiling["dht_hits"] = Rcpp::wrap(chem.GetWorkerDHTHits()); - chem_profiling["dht_evictions"] = Rcpp::wrap(chem.GetWorkerDHTEvictions()); - chem_profiling["dht_get_time"] = Rcpp::wrap(chem.GetWorkerDHTGetTimings()); - chem_profiling["dht_fill_time"] = - Rcpp::wrap(chem.GetWorkerDHTFillTimings()); - } + Rcpp::List ctrl_profiling; + ctrl_profiling["compute_metrics_master"] = chem.GetMasterCtrlMetricsTime(); + ctrl_profiling["unshuffle_field_master"] = chem.GetMasterUnshuffleTime(); + ctrl_profiling["w_checkpoint_master"] = control.getWriteCheckpointTime(); + ctrl_profiling["r_checkpoint_master"] = control.getReadCheckpointTime(); + ctrl_profiling["write_stats"] = control.getWriteMetricsTime(); + ctrl_profiling["ctrl_logic_master"] = control.getUpdateCtrlLogicTime(); + ctrl_profiling["recv_data_master"] = chem.GetMasterRecvCtrlDataTime(); + ctrl_profiling["worker"] = Rcpp::wrap(chem.GetWorkerControlTimings()); + + // if (params.use_dht) { + chem_profiling["dht_hits"] = Rcpp::wrap(chem.GetWorkerDHTHits()); + chem_profiling["dht_evictions"] = Rcpp::wrap(chem.GetWorkerDHTEvictions()); + chem_profiling["dht_get_time"] = Rcpp::wrap(chem.GetWorkerDHTGetTimings()); + chem_profiling["dht_fill_time"] = Rcpp::wrap(chem.GetWorkerDHTFillTimings()); + //} if (params.use_interp) { chem_profiling["interp_w"] = @@ -460,6 +468,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms, profiling["simtime"] = dSimTime; profiling["chemistry"] = chem_profiling; profiling["diffusion"] = diffusion_profiling; + profiling["control"] = ctrl_profiling; chem.MasterLoopBreak(); diff --git a/src/poet.hpp.in b/src/poet.hpp.in index fe7ccf79a..e9bd48d30 100644 --- a/src/poet.hpp.in +++ b/src/poet.hpp.in @@ -52,6 +52,7 @@ struct RuntimeParameters { bool print_progress = false; std::uint32_t stab_interval = 0; std::uint32_t checkpoint_interval = 0; + std::uint32_t max_rb = 0; double zero_abs = 0.0; std::vector mape_threshold; std::vector ctrl_cell_ids; diff --git a/util/data_evaluation/RFun_Eval_qs2.R b/util/data_evaluation/RFun_Eval_qs2.R new file mode 100644 index 000000000..a6ee27923 --- /dev/null +++ b/util/data_evaluation/RFun_Eval_qs2.R @@ -0,0 +1,452 @@ +## Simple library of functions to assess and visualize the results of the coupled simulations +## Modified to work with .qs2 files (qs2 package format) + +## Time-stamp: "Last modified 2025-11-25" + +require(qs2) ## for reading .qs2 files +require(stringr) + +# Note: RedModRphree, Rmufits, and Rcpp functions for DHT/PHT reading are kept +# but you'll need those packages only if you use ReadAllDHT/ReadAllPHT functions + +curdir <- dirname(sys.frame(1)$ofile) ##path.expand(".") + +print(paste("RFun_Eval_qs2.R is in ", curdir)) + +## ============================================================================ +## NEW: Functions for reading .qs2 simulation outputs +## ============================================================================ + +## function which reads all simulation results in a given directory (.qs2 format) +ReadRTSims_qs2 <- function(dir) { + pattern <- "^iter_.*\\.qs2$" + files_full <- list.files(dir, pattern = pattern, full.names = TRUE) + files_name <- list.files(dir, pattern = pattern, full.names = FALSE) + + if (length(files_full) == 0) { + warning(paste("No .qs2 files found in", dir, "with pattern", pattern)) + return(NULL) + } + + res <- lapply(files_full, qs2::qs_read) + names(res) <- gsub(".qs2", "", files_name, perl = TRUE) + + return(res[str_sort(names(res), numeric = TRUE)]) +} +## Read a single .qs2 file +ReadQS2 <- function(file) { + if (!file.exists(file)) { + stop(paste("File not found:", file)) + } + qs2::qs_read(file) +} + +## Extract chemistry field data from .qs2 iteration file +## Assumes structure similar to old .rds format with $C (chemistry) and $T (transport) +ExtractChemistry <- function(qs2_data) { + if ("C" %in% names(qs2_data)) { + return(qs2_data$C) + } else if (is.data.frame(qs2_data)) { + return(qs2_data) + } else { + warning("Could not find chemistry data in expected format") + return(qs2_data) + } +} + +## Extract transport field data from .qs2 iteration file +ExtractTransport <- function(qs2_data) { + if ("T" %in% names(qs2_data)) { + return(qs2_data$T) + } else { + warning("Could not find transport data in expected format") + return(NULL) + } +} + +## ============================================================================ +## ORIGINAL: DHT/PHT reading functions (kept for surrogate analysis) +## ============================================================================ + +# Only load these if needed (requires Rcpp compilation) +if (requireNamespace("Rcpp", quietly = TRUE) && file.exists(paste0(curdir, "/interpret_keys.cpp"))) { + library(Rcpp) + sourceCpp(file = paste0(curdir, "/interpret_keys.cpp")) + + # Wrapper around previous sourced Rcpp function + ConvertDHTKey <- function(value) { + rcpp_key_convert(value) + } + + ConvertToUInt64 <- function(double_data) { + rcpp_uint64_convert(double_data) + } +} else { + if (!requireNamespace("Rcpp", quietly = TRUE)) { + message("Note: Rcpp not available. DHT/PHT reading functions will not work.") + } + # Create dummy functions so the rest of the script doesn't break + ConvertDHTKey <- function(value) { + stop("Rcpp not available. Cannot convert DHT keys.") + } + ConvertToUInt64 <- function(double_data) { + stop("Rcpp not available. Cannot convert to UInt64.") + } +} + +## function which reads all successive DHT stored in a given directory +ReadAllDHT <- function(dir, new_scheme = TRUE) { + files_full <- list.files(dir, pattern="iter.*\\.dht$", full.names=TRUE) + files_name <- list.files(dir, pattern="iter.*\\.dht$", full.names=FALSE) + + if (length(files_full) == 0) { + warning(paste("No .dht files found in", dir)) + return(NULL) + } + + res <- lapply(files_full, ReadDHT, new_scheme = new_scheme) + names(res) <- gsub("\\.dht$","",files_name) + return(res) +} + +## function which reads one .dht file and gives a matrix +ReadDHT <- function(file, new_scheme = TRUE) { + conn <- file(file, "rb") ## open for reading in binary mode + if (!isSeekable(conn)) + stop("Connection not seekable") + + ## we first reposition ourselves to the end of the file... + tmp <- seek(conn, where=0, origin = "end") + ## ... and then back to the origin so to store the length in bytes + flen <- seek(conn, where=0, origin = "start") + + ## we read the first 2 integers (4 bytes each) containing dimensions in bytes + dims <- readBin(conn, what="integer", n=2) + + ## compute dimensions of the data + tots <- sum(dims) + ncol <- tots/8 + nrow <- (flen - 8)/tots ## 8 here is 2*sizeof("int") + buff <- readBin(conn, what="double", n=ncol*nrow) + ## close connection + close(conn) + res <- matrix(buff, nrow=nrow, ncol=ncol, byrow=TRUE) + + if (new_scheme) { + nkeys <- dims[1] / 8 + keys <- res[, 1:nkeys] + + conv <- apply(keys, 2, ConvertDHTKey) + res[, 1:nkeys] <- conv + } + + return(res) +} + +## function which reads all successive PHT stored in a given directory +ReadAllPHT <- function(dir, with_info = FALSE) { + files_full <- list.files(dir, pattern="iter.*\\.pht$", full.names=TRUE) + files_name <- list.files(dir, pattern="iter.*\\.pht$", full.names=FALSE) + + if (length(files_full) == 0) { + warning(paste("No .pht files found in", dir)) + return(NULL) + } + + res <- lapply(files_full, ReadPHT, with_info = with_info) + names(res) <- gsub("\\.pht$","",files_name) + return(res) +} + +## function which reads one .pht file and gives a matrix +ReadPHT <- function(file, with_info = FALSE) { + conn <- file(file, "rb") ## open for reading in binary mode + if (!isSeekable(conn)) + stop("Connection not seekable") + + ## we first reposition ourselves to the end of the file... + tmp <- seek(conn, where=0, origin = "end") + ## ... and then back to the origin so to store the length in bytes + flen <- seek(conn, where=0, origin = "start") + + ## we read the first 2 integers (4 bytes each) containing dimensions in bytes + dims <- readBin(conn, what="integer", n=2) + + ## compute dimensions of the data + tots <- sum(dims) + ncol <- tots/8 + nrow <- (flen - 8)/tots ## 8 here is 2*sizeof("int") + buff <- readBin(conn, what="double", n=ncol*nrow) + ## close connection + close(conn) + res <- matrix(buff, nrow=nrow, ncol=ncol, byrow=TRUE) + + nkeys <- dims[1] / 8 + keys <- res[, 1:nkeys] + + timesteps <- res[, nkeys + 1] + conv <- apply(keys, 2, ConvertDHTKey) + + ndata <- dims[2] / 8 + fill_rate <- ConvertToUInt64(res[, nkeys + 2]) + + buff <- c(conv, timesteps, fill_rate) + + if (with_info) { + ndata <- dims[2]/8 + visit_count <- ConvertToUInt64(res[, nkeys + ndata]) + buff <- c(buff, visit_count) + } + + res <- matrix(buff, nrow = nrow, byrow = FALSE) + + return(res) +} + +## ============================================================================ +## PLOTTING and ANALYSIS functions (work with both .rds and .qs2 data) +## ============================================================================ + +## Scatter plots of each variable in the iteration +PlotScatter <- function(sam1, sam2, which=NULL, labs=c("NO DHT", "DHT"), pch=".", cols=3, ...) { + if ((!is.data.frame(sam1)) & ("T" %in% names(sam1))) + sam1 <- sam1$C + if ((!is.data.frame(sam2)) & ("T" %in% names(sam2))) + sam2 <- sam2$C + if (is.numeric(which)) + inds <- which + else if (is.character(which)) + inds <- match(which, colnames(sam1)) + else if (is.null(which)) + inds <- seq_along(colnames(sam1)) + + rows <- ceiling(length(inds) / cols) + par(mfrow=c(rows, cols)) + a <- lapply(inds, function(x) { + plot(sam1[,x], sam2[,x], main=colnames(sam1)[x], xlab=labs[1], ylab=labs[2], pch=pch, col="red", ...) + abline(0,1, col="grey", lwd=1.5) + }) + invisible() +} + +##### Some metrics for relative comparison + +## Root Mean Square Error +RMSE <- function(pred, obs) + sqrt(mean((pred - obs)^2, na.rm = TRUE)) + +## Using range as norm +RranRMSE <- function(pred, obs) + sqrt(mean((pred - obs)^2, na.rm = TRUE))/abs(max(pred, na.rm = TRUE) - min(pred, na.rm = TRUE)) + +## Using max val as norm +RmaxRMSE <- function(pred, obs) + sqrt(mean((pred - obs)^2, na.rm = TRUE))/abs(max(pred, na.rm = TRUE)) + +## Using sd as norm +RsdRMSE <- function(pred, obs) + sqrt(mean((pred - obs)^2, na.rm = TRUE))/sd(pred, na.rm = TRUE) + +## Using mean as norm +RmeanRMSE <- function(pred, obs) + sqrt(mean((pred - obs)^2, na.rm = TRUE))/mean(pred, na.rm = TRUE) + +## Using mean as norm +RAEmax <- function(pred, obs) + mean(abs(pred - obs), na.rm = TRUE)/max(pred, na.rm = TRUE) + +## Max absolute error +MAE <- function(pred, obs) + max(abs(pred - obs), na.rm = TRUE) + +## Mean Absolute Percentage Error +MAPE <- function(pred, obs) + mean(abs((obs - pred) / obs) * 100, na.rm = TRUE) + +## workhorse function for ComputeErrors and its use with mapply +AppliedFun <- function(a, b, .fun) { + # Extract chemistry data if needed + if (!is.data.frame(a) && "C" %in% names(a)) a <- a$C + if (!is.data.frame(b) && "C" %in% names(b)) b <- b$C + + mapply(.fun, as.list(a), as.list(b)) +} + +## Compute the diffs between two simulation, iter by iter, +## with a given metric (passed in form of function name to this function) +ComputeErrors <- function(sim1, sim2, FUN=RMSE) { + if (length(sim1)!= length(sim2)) { + cat("The simulations do not have the same length, subsetting to the shortest\n") + a <- min(length(sim1), length(sim2)) + sim1 <- sim1[1:a] + sim2 <- sim2[1:a] + } + if (!is.function(match.fun(FUN))) { + stop("Invalid function\n") + } + + t(mapply(AppliedFun, sim1, sim2, MoreArgs=list(.fun=FUN))) +} + +## Function to display the error progress between 2 simulations +ErrorProgress <- function(mat, ignore, colors, metric, ...) { + if (is.null(mat)) { + stop("Cannot plot: matrix is NULL") + } + + # Convert to matrix if it's a vector or data frame + if (is.vector(mat)) { + stop("Cannot plot: input is a vector (need at least 2 columns). Check that your data has multiple columns.") + } + if (is.data.frame(mat)) { + mat <- as.matrix(mat) + } + + if (nrow(mat) == 0 || ncol(mat) == 0) { + stop("Cannot plot: matrix is empty") + } + + if (missing(colors)) + colors <- sample(rainbow(ncol(mat))) + + if (missing(metric)) + metric <- "Metric" + + ## if the optional argument "ignore" (a character vector) is + ## passed, we remove the matching column names + if (!missing(ignore)) { + to_remove <- match(ignore, colnames(mat)) + to_remove <- to_remove[!is.na(to_remove)] # Remove NAs + if (length(to_remove) > 0) { + mat <- mat[, -to_remove, drop = FALSE] + colors <- colors[-to_remove] + } + } + + yc <- mat[nrow(mat),] + par(mar=c(5,4,2,8)) + matplot(mat, type="l", lty=1, lwd=2, col=colors, xlab="iteration", ylab=metric, ...) + mtext(colnames(mat), side = 4, line = 0.5, outer = FALSE, at = yc, adj = 0, col = colors, las=2, cex=0.7) +} + +## Function which exports all simulations to ParaView's .vtu +## Requires package RcppVTK +ExportToParaview <- function(vtu, nameout, results) { + if (!requireNamespace("RcppVTK", quietly = TRUE)) { + stop("Package RcppVTK is required for this function") + } + require(RcppVTK) + + n <- length(results) + vars <- colnames(results[[1]]) + ## strip eventually present ".vtu" from nameout + nameout <- sub(".vtu", "", nameout, fixed=TRUE) + namesteps <- paste0(nameout, ".", sprintf("%04d",seq(1,n)), ".vtu") + for (step in seq_along(results)) { + file.copy(from=vtu, to=namesteps[step], overwrite = TRUE) + cat(paste("Saving step ", step, " in file ", namesteps[step], "\n")) + ret <- ExportMatrixToVTU (fin=vtu, fout=namesteps[step], names=colnames(results[[step]]), mat=results[[step]]) + } + invisible(ret) +} + + +## Version of Rmufits::PlotCartCellData with the ability to fix the +## "breaks" for color coding of 2D simulations +Plot2DCellData <- function (data, grid, nx, ny, contour = TRUE, + nlevels = 12, breaks, palette = "heat.colors", + rev.palette = TRUE, scale = TRUE, plot.axes=TRUE, ...) { + if (!missing(grid)) { + xc <- unique(sort(grid$cell$XCOORD)) + yc <- unique(sort(grid$cell$YCOORD)) + nx <- length(xc) + ny <- length(yc) + if (!length(data) == nx * ny) + stop("Wrong nx, ny or grid") + } else { + xc <- seq(1, nx) + yc <- seq(1, ny) + } + z <- matrix(round(data, 6), ncol = nx, nrow = ny, byrow = TRUE) + pp <- t(z[rev(seq(1, nrow(z))), ]) + + if (missing(breaks)) { + breaks <- pretty(data, n = nlevels) + } + + breakslen <- length(breaks) + colors <- do.call(palette, list(n = breakslen - 1)) + if (rev.palette) + colors <- rev(colors) + if (scale) { + par(mfrow = c(1, 2)) + nf <- layout(matrix(c(1, 2), 1, 2, byrow = TRUE), widths = c(4, + 1)) + } + par(las = 1, mar = c(5, 5, 3, 1)) + image(xc, yc, pp, xlab = "X [m]", ylab = "Y[m]", las = 1, asp = 1, + breaks = breaks, col = colors, axes = FALSE, ann=plot.axes, + ...) + + if (plot.axes) { + axis(1) + axis(2) + } + if (contour) + contour(unique(sort(xc)), unique(sort(yc)), pp, breaks = breaks, + add = TRUE) + if (scale) { + par(las = 1, mar = c(5, 1, 5, 5)) + if (requireNamespace("Rmufits", quietly = TRUE)) { + Rmufits::PlotImageScale(data, breaks = breaks, add.axis = FALSE, + axis.pos = 4, col = colors) + } + axis(4, at = breaks) + } + invisible(pp) +} + +PlotAsMP4 <- function(data, nx, ny, to_plot, out_dir, name, + contour = FALSE, scale = FALSE, framerate = 30) { + sort_data <- data[str_sort(names(data), numeric = TRUE)] + plot_data <- lapply(sort_data, function(x) { + if (!is.data.frame(x) && "C" %in% names(x)) { + return(x$C[[to_plot]]) + } else { + return(x[[to_plot]]) + } + }) + pad_size <- ceiling(log10(length(plot_data))) + + dir.create(out_dir, showWarnings = FALSE) + output_files <- paste0(out_dir, "/", name, "_%0", pad_size, "d.png") + output_mp4 <- paste0(out_dir, "/", name, ".mp4") + + png(output_files, + width = 297, height = 210, units = "mm", + res = 100 + ) + + for (i in 1:length(plot_data)) { + if (requireNamespace("Rmufits", quietly = TRUE)) { + Rmufits::PlotCartCellData(plot_data[[i]], nx = nx, ny = ny, contour = contour, scale = scale) + } else { + Plot2DCellData(plot_data[[i]], nx = nx, ny = ny, contour = contour, scale = scale) + } + } + dev.off() + + ffmpeg_command <- paste( + "ffmpeg -y -framerate", framerate, "-i", output_files, + "-c:v libx264 -crf 22", output_mp4 + ) + + system(ffmpeg_command) + message(paste("Created video:", output_mp4)) +} + +cat("\n=== RFun_Eval_qs2.R loaded successfully ===\n") +cat("New functions for .qs2 files:\n") +cat(" - ReadRTSims_qs2(dir) : Read all iteration .qs2 files\n") +cat(" - ReadQS2(file) : Read single .qs2 file\n") +cat("All other functions work as before!\n\n") diff --git a/util/data_evaluation/run_vis.R b/util/data_evaluation/run_vis.R new file mode 100644 index 000000000..bfcd876d8 --- /dev/null +++ b/util/data_evaluation/run_vis.R @@ -0,0 +1,50 @@ +# Load the new functions +source("/mnt/beegfs/home/rastogi/poet/util/data_evaluation/RFun_Eval.R") + +# Set base path +base_dir <- "/mnt/beegfs/home/rastogi/poet/bin" + +sim1 <- ReadRTSims(file.path(base_dir, "proto2_eps01_no_rb_v2")) +sim2 <- ReadRTSims(file.path(base_dir, "proto2_eps0035_no_rb_v2")) + + +# ======================================== +# Compare two simulations +# ======================================== +rmse_errors <- ComputeErrors(sim1, sim2, FUN = RMSE) +mape_errors <- ComputeErrors(sim1, sim2, FUN = MAPE) + +# Print summary +cat("RMSE errors computed for", ncol(rmse_errors), "variables\n") +cat("MAPE errors computed for", ncol(mape_errors), "variables\n") +cat("Number of iterations compared:", nrow(rmse_errors), "\n\n") + +# Set output path explicitly +output_pdf <- file.path(base_dir, "comparison_plots.pdf") +cat("Saving plots to:", output_pdf, "\n") + +# Save plots to PDF +pdf(output_pdf, width = 10, height = 6) + +# Plot error progression +cat("Creating ErrorProgress plot...\n") +ErrorProgress(rmse_errors, ignore = c("Charge"), metric = "RMSE") + +# Scatter plot for specific iteration (if that iteration exists) +if (length(sim1) >= 10 && length(sim2) >= 10) { + cat("Creating scatter plot for iteration 10...\n") + PlotScatter(sim1[[10]], sim2[[10]], + labs = c("Proto2 Eps=0.01%", "Proto2 Eps=0.0035%"), + which = c("Ca", "Mg", "Cl", "C")) +} else { + cat("Not enough iterations for scatter plot\n") +} + +dev.off() +cat("Plots saved successfully to:", output_pdf, "\n") + +# ======================================== +# DHT/PHT analysis (if you have snapshot files) +# ======================================== +#dht_snaps <- ReadAllDHT("poet/bin/proto1_only_interp") +#pht_snaps <- ReadAllPHT("poet/bin/proto1_only_interp") \ No newline at end of file