Compare commits

...

2 Commits

Author SHA1 Message Date
rastogi
5c02308901 added infinity check to MAPE calculation 2025-11-23 13:40:54 +01:00
rastogi
10e375831b Updated .gitignore 2025-11-23 13:39:55 +01:00
9 changed files with 1756 additions and 4 deletions

9
.gitignore vendored
View File

@ -151,3 +151,12 @@ share/
lib/
include/
/.ai/
!bin/dolo_fgcs_3_rt.R
!bin/dolo_fgcs.pqi
!bin/phreeqc_kin.dat
!bin/dolo_fgcs_3.R
!bin/run_poet.sh
!bin/plot_metrics.R
!bin/sum_time.R

48
bin/dolo_fgcs.pqi Normal file
View File

@ -0,0 +1,48 @@
SOLUTION 1
units mol/kgw
water 1
temperature 25
pH 7
pe 4
PURE 1
Calcite 0.0 1
END
RUN_CELLS
-cells 1
END
COPY solution 1 2
#PURE 2
# O2g -0.1675 10
KINETICS 2
Calcite
-m 0.00207
-parms 0.05
-tol 1e-10
Dolomite
-m 0.0
-parms 0.01
-tol 1e-10
END
SOLUTION 3
units mol/kgw
water 1
temp 25
Mg 0.001
Cl 0.002
END
SOLUTION 4
units mol/kgw
water 1
temp 25
Mg 0.002
Cl 0.004
END
RUN_CELLS
-cells 2-4
END

135
bin/dolo_fgcs_3.R Normal file
View File

@ -0,0 +1,135 @@
rows <- 400
cols <- 400
grid_def <- matrix(2, nrow = rows, ncol = cols)
# Define grid configuration for POET model
grid_setup <- list(
pqc_in_file = "./dolo_fgcs.pqi",
pqc_db_file = "./phreeqc_kin.dat", # Path to the database file for Phreeqc
grid_def = grid_def, # Definition of the grid, containing IDs according to the Phreeqc input script
grid_size = c(5, 5), # Size of the grid in meters
constant_cells = c() # IDs of cells with constant concentration
)
bound_def_we <- list(
"type" = rep("constant", rows),
"sol_id" = rep(1, rows),
"cell" = seq(1, rows)
)
bound_def_ns <- list(
"type" = rep("constant", cols),
"sol_id" = rep(1, cols),
"cell" = seq(1, cols)
)
diffusion_setup <- list(
boundaries = list(
"W" = bound_def_we,
"E" = bound_def_we,
"N" = bound_def_ns,
"S" = bound_def_ns
),
inner_boundaries = list(
"row" = floor(rows / 2),
"col" = floor(cols / 2),
"sol_id" = c(3)
),
alpha_x = 1e-6,
alpha_y = 1e-6
)
check_sign_cal_dol_dht <- function(old, new) {
if ((old["Calcite"] == 0) != (new["Calcite"] == 0)) {
return(TRUE)
}
if ((old["Dolomite"] == 0) != (new["Dolomite"] == 0)) {
return(TRUE)
}
return(FALSE)
}
check_sign_cal_dol_interp <- function(to_interp, data_set) {
dht_species <- c(
"H" = 3,
"O" = 3,
"C" = 6,
"Ca" = 6,
"Cl" = 3,
"Mg" = 5,
"Calcite" = 4,
"Dolomite" = 4
)
data_set <- as.data.frame(do.call(rbind, data_set), check.names = FALSE, optional = TRUE)
names(data_set) <- names(dht_species)
cal <- (data_set$Calcite == 0) == (to_interp["Calcite"] == 0)
dol <- (data_set$Dolomite == 0) == (to_interp["Dolomite"] == 0)
cal_dol_same_sig <- cal == dol
return(rev(which(!cal_dol_same_sig)))
}
check_neg_cal_dol <- function(result) {
neg_sign <- (result["Calcite"] < 0) || (result["Dolomite"] < 0)
return(neg_sign)
}
# Optional when using Interpolation (example with less key species and custom
# significant digits)
pht_species <- c(
"C" = 3,
"Ca" = 3,
"Mg" = 3,
"Cl" = 3,
"Calcite" = 3,
"Dolomite" = 3
)
dht_species <- c(
"H" = 3,
"O" = 3,
"C" = 6,
"Ca" = 6,
"Cl" = 3,
"Mg" = 5,
"Calcite" = 4,
"Dolomite" = 4)
chemistry_setup <- list(
dht_species = dht_species,
pht_species = pht_species,
hooks = list(
dht_fill = check_sign_cal_dol_dht,
interp_pre = check_sign_cal_dol_interp,
interp_post = check_neg_cal_dol
)
)
# Define a setup list for simulation configuration
setup <- list(
Grid = grid_setup, # Parameters related to the grid structure
Diffusion = diffusion_setup, # Parameters related to the diffusion process
Chemistry = chemistry_setup # Parameters related to the chemistry process
)
iterations <- 15000
dt <- 200
checkpoint_interval <- 100
control_interval <- 100
mape_threshold <- rep(0.1, 13)
mape_threshold[5] <- 1 #Charge
out_save <- seq(1000, iterations, by = 1000)
#out_save = c(seq(1, 10), seq(10, 100, by= 10), seq(200, iterations, by=100))
list(
timesteps = rep(dt, iterations),
store_result = TRUE,
out_save = out_save,
checkpoint_interval = checkpoint_interval,
control_interval = control_interval,
mape_threshold = mape_threshold
)

21
bin/dolo_fgcs_3_rt.R Normal file
View File

@ -0,0 +1,21 @@
iterations <- 15000
dt <- 200
checkpoint_interval <- 100
control_interval <- 100
mape_threshold <- rep(0.0035, 13)
zero_abs <- 0.0
#mape_threshold[5] <- 1 #Charge
#ctrl_cell_ids <- seq(0, (400*400)/2 - 1, by = 401)
#out_save <- seq(500, iterations, by = 500)
#out_save = c(seq(1, 10), seq(10, 100, by= 10), seq(200, iterations, by=100))
list(
timesteps = rep(dt, iterations),
store_result = FALSE,
#out_save = out_save,
checkpoint_interval = checkpoint_interval,
control_interval = control_interval,
mape_threshold = mape_threshold,
zero_abs = zero_abs
)

1307
bin/phreeqc_kin.dat Normal file

File diff suppressed because it is too large Load Diff

142
bin/plot_metrics.R Normal file
View File

@ -0,0 +1,142 @@
#!/usr/bin/env Rscript
suppressPackageStartupMessages({library(dplyr); library(ggplot2); library(tidyr)})
args <- commandArgs(trailingOnly = TRUE)
if (length(args) < 1) stop("Usage: Rscript plot_mape_stats.R <stats_overview_file1> [stats_overview_file2] ...")
cat("Reading", length(args), "stats file(s)...\n")
# Process all input files
all_data <- lapply(args, function(stats_file) {
if (!file.exists(stats_file)) {
warning("File not found: ", stats_file)
return(NULL)
}
cat(" -", stats_file, "\n")
lines <- readLines(stats_file)
data_lines <- lines[!grepl("^-+$", lines) & nchar(lines) > 0]
parsed <- lapply(data_lines, function(line) {
parts <- strsplit(trimws(line), "\\s+")[[1]]
if (length(parts) >= 5) {
data.frame(
Iteration = as.numeric(parts[1]),
Rollback = as.numeric(parts[2]),
Species = parts[3],
MAPE = as.numeric(parts[4]),
RRMSE = as.numeric(parts[5]),
stringsAsFactors = FALSE
)
}
})
df <- bind_rows(parsed) %>% filter(!is.na(Iteration))
species_list <- c("H", "O", "C", "Ca", "Cl", "Mg", "Calcite", "Dolomite")
#species_list <- "Dolomite"
df_filtered <- df %>%
filter(Species %in% species_list) %>%
group_by(Iteration) %>%
summarise(
MedianMAPE = median(MAPE, na.rm = TRUE),
MaxMAPE = max(MAPE, na.rm = TRUE),
Rollback = first(Rollback),
.groups = "drop"
) %>%
filter(Iteration %% 100 == 0) %>%
mutate(Folder = basename(dirname(stats_file)))
# Detect rollback changes
df_filtered <- df_filtered %>%
arrange(Iteration) %>%
mutate(RollbackChange = Rollback != lag(Rollback, default = first(Rollback)))
return(df_filtered)
})
combined_data <- bind_rows(all_data) %>%
filter(Iteration <= 3000) %>%
filter(is.finite(MedianMAPE) & MedianMAPE > 0) %>%
filter(is.finite(MaxMAPE) & MaxMAPE > 0)
# Identify rollback transitions for each folder
rollback_points <- combined_data %>%
filter(RollbackChange == TRUE) %>%
select(Folder, Iteration, Rollback)
cat("\nData summary:\n")
print(head(combined_data))
cat("\nLegend:", unique(combined_data$Folder), "\n")
cat("\nRollback transitions detected:\n")
print(rollback_points)
# A consistent style for both plots
pretty_theme <- theme_minimal(base_size = 14) +
theme(
plot.title = element_text(face = "bold", size = 16, hjust = 0.5),
axis.title = element_text(face = "bold"),
legend.position = "right",
panel.grid.minor = element_blank(),
panel.grid.major.x = element_line(color = "grey85"),
panel.grid.major.y = element_line(color = "grey85"),
axis.line = element_line(linewidth = 0.8, colour = "black"),
axis.ticks = element_line(colour = "black")
)
# Determine nice log-scale breaks (1e-1, 1e-2, 1e-3, etc.)
log_breaks <- 10^seq(max(-6, floor(log10(min(combined_data$MedianMAPE, combined_data$MaxMAPE, na.rm = TRUE)))),
ceiling(log10(max(combined_data$MedianMAPE, combined_data$MaxMAPE, na.rm = TRUE))),
by = 1)
# Common log label formatter
log_labels <- function(x) sprintf("1e%d", log10(x))
# Plot Median MAPE
p1 <- ggplot(combined_data, aes(x = Iteration, y = MedianMAPE, color = Folder)) +
geom_line(linewidth = 1) +
geom_point(size = 2) +
geom_vline(data = rollback_points, aes(xintercept = Iteration, color = Folder),
linetype = "dashed", alpha = 0.6, linewidth = 0.8) +
scale_x_continuous(breaks = seq(0, max(combined_data$Iteration), by = 1000)) +
scale_y_log10(breaks = log_breaks, labels = log_labels) +
labs(
title = "Median MAPE Across H, O, C, Ca, Cl, Mg, Calcite, Dolomite",
x = "Iteration",
y = "Median MAPE",
color = "Legend"
) +
pretty_theme
# Plot Max MAPE
p2 <- ggplot(combined_data, aes(x = Iteration, y = MaxMAPE, color = Folder)) +
geom_line(linewidth = 1) +
geom_point(size = 2) +
geom_vline(data = rollback_points, aes(xintercept = Iteration, color = Folder),
linetype = "dashed", alpha = 0.6, linewidth = 0.8) +
scale_x_continuous(breaks = seq(0, max(combined_data$Iteration), by = 1000)) +
scale_y_log10(breaks = log_breaks, labels = log_labels, limits = c(1e-5, NA)) +
labs(
title = "Max MAPE Across H, O, C, Ca, Cl, Mg, Calcite, Dolomite",
x = "Iteration",
y = "Max MAPE",
color = "Legend"
) +
pretty_theme
# Save plots
script_dir <- dirname(sub("--file=", "", grep("--file=", commandArgs(trailingOnly = FALSE), value = TRUE)))
if (length(script_dir) == 0 || script_dir == "") script_dir <- getwd()
ggsave(file.path(script_dir, "median_mape.pdf"), p1, width = 10, height = 6)
ggsave(file.path(script_dir, "max_mape.pdf"), p2, width = 10, height = 6)
cat("\nPlots saved:\n")
cat(" -", file.path(script_dir, "median_mape.pdf"), "\n")
cat(" -", file.path(script_dir, "max_mape.pdf"), "\n")
# Also save data
write.csv(combined_data, file.path(script_dir, "mape_summary.csv"), row.names = FALSE)
cat(" -", file.path(script_dir, "mape_summary.csv"), "\n")

19
bin/run_poet.sh Normal file
View File

@ -0,0 +1,19 @@
#!/bin/bash
#SBATCH --job-name=proto1_only_interp_zeroabs
#SBATCH --output=proto1_only_interp_zeroabs_%j.out
#SBATCH --error=proto1_only_interp_zeroabs_%j.err
#SBATCH --partition=long
#SBATCH --nodes=6
#SBATCH --ntasks-per-node=24
#SBATCH --ntasks=144
#SBATCH --exclusive
#SBATCH --time=3-00:00:00
source /etc/profile.d/modules.sh
module purge
module load cmake gcc openmpi
#mpirun -n 144 ./poet dolo_fgcs_3.R dolo_fgcs_3.qs2 dolo_only_pqc
mpirun -n 144 ./poet --interp dolo_fgcs_3_rt.R dolo_fgcs_3.qs2 proto1_only_interp_zeroabs
#mpirun -n 144 ./poet --interp barite_fgcs_4_new/barite_fgcs_4_new_rt.R barite_fgcs_4_new/barite_fgcs_4_new.qs2 barite

61
bin/sum_time.R Normal file
View File

@ -0,0 +1,61 @@
#!/usr/bin/env Rscript
# Summarize timing vectors from timings.qs2 files.
# Usage: Rscript summarize_timings.R file1.qs2 [file2.qs2 ...]
if (!requireNamespace("qs2", quietly = TRUE)) {
stop("Package 'qs2' not installed. Install with: remotes::install_github('qs2io/qs2')")
}
args <- commandArgs(trailingOnly = TRUE)
if (length(args) < 1) {
stop("Usage: Rscript summarize_timings.R timings.qs2 [more.qs2 ...]")
}
chem_vectors <- c(
"idle_worker", "phreeqc_time", "dht_get_time", "dht_fill_time",
"interp_w", "interp_r", "interp_g", "interp_fc"
)
summaries <- lapply(args, function(f) {
if (!file.exists(f)) {
warning("File not found: ", f)
return(NULL)
}
obj <- qs2::qs_read(f)
chem <- obj$chemistry
ctrl <- obj$control_loop
# ---- sum chemistry vectors, round 2 digits ----
chem_sums <- sapply(chem_vectors, function(v) {
x <- chem[[v]]
if (!is.numeric(x)) return(NA_real_)
round(sum(x, na.rm = TRUE), 2)
})
# ---- sum worker ----
worker_sum <- {
x <- ctrl$worker
if (!is.numeric(x)) NA_real_ else round(sum(x, na.rm = TRUE), 2)
}
# ---- assemble a long-format table ----
data.frame(
file = basename(f),
vector = c(chem_vectors, "worker"),
sum = c(chem_sums, worker_sum),
stringsAsFactors = FALSE
)
})
summaries <- do.call(rbind, summaries)
# ---- print result ----
cat("\nTiming sums (rounded to 2 digits):\n\n")
print(summaries, row.names = FALSE)
# ---- save CSV ----
write.csv(summaries, "timings_summary.csv", row.names = FALSE)
cat("\nSaved summary to timings_summary.csv\n")

View File

@ -164,8 +164,13 @@ void poet::ControlModule::computeErrorMetrics(
if (std::isnan(ref_value) || std::isnan(sur_value)) {
continue;
}
if (std::abs(ref_value) < ZERO_ABS) {
if (std::abs(sur_value) >= ZERO_ABS) {
if (!std::isfinite(ref_value) || !std::isfinite(sur_value)) {
continue;
}
if (std::abs(ref_value) == ZERO_ABS) {
if (std::abs(sur_value) != ZERO_ABS) {
err_sum += 1.0;
sqr_err_sum += 1.0;
}
@ -173,12 +178,17 @@ void poet::ControlModule::computeErrorMetrics(
// Both zero: skip
else {
double alpha = 1.0 - (sur_value / ref_value);
if (!std::isfinite(alpha)) {
continue; // protects against inf/NaN due to extreme values
}
err_sum += std::abs(alpha);
sqr_err_sum += alpha * alpha;
}
}
metrics.mape[i] = 100.0 * (err_sum / size_per_prop);
metrics.rrmse[i] = std::sqrt(sqr_err_sum / size_per_prop);
metrics.mape[i] = 100.0 * (err_sum / static_cast<double>(size_per_prop));
metrics.rrmse[i] =
std::sqrt(sqr_err_sum / static_cast<double>(size_per_prop));
}
metrics_history.push_back(metrics);
}