mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-13 03:18:23 +01:00
Compare commits
15 Commits
80c51a14ae
...
53f1416b96
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
53f1416b96 | ||
|
|
9be9ba2818 | ||
|
|
a8bf09a730 | ||
|
|
950793b73e | ||
|
|
672e1801cd | ||
|
|
0d457e1f2a | ||
|
|
67053fc662 | ||
|
|
0a5f7010fe | ||
|
|
bb25f6b449 | ||
|
|
5c5c328b0b | ||
|
|
3be8cc1cb4 | ||
|
|
9a329da4b5 | ||
|
|
49135615d1 | ||
|
|
c7d1fc152c | ||
|
|
d825f33b4f |
@ -1,7 +1,6 @@
|
||||
function(ADD_BENCH_TARGET TARGET POET_BENCH_LIST RT_FILES OUT_PATH)
|
||||
function(ADD_BENCH_TARGET TARGET POET_BENCH_LIST RT_FILES WEIGHT_FILES OUT_PATH)
|
||||
set(bench_install_dir share/poet/${OUT_PATH})
|
||||
|
||||
# create empty list
|
||||
set(OUT_FILES_LIST "")
|
||||
|
||||
foreach(BENCH_FILE ${${POET_BENCH_LIST}})
|
||||
@ -28,16 +27,15 @@ function(ADD_BENCH_TARGET TARGET POET_BENCH_LIST RT_FILES OUT_PATH)
|
||||
|
||||
install(FILES ${OUT_FILES_LIST} DESTINATION ${bench_install_dir})
|
||||
|
||||
# install all ADD_FILES to the same location
|
||||
install(FILES ${${RT_FILES}} DESTINATION ${bench_install_dir})
|
||||
|
||||
install(FILES ${${WEIGHT_FILES}} DESTINATION ${bench_install_dir})
|
||||
endfunction()
|
||||
|
||||
|
||||
# define target name
|
||||
set(BENCHTARGET benchmarks)
|
||||
|
||||
add_custom_target(${BENCHTARGET} ALL)
|
||||
|
||||
add_subdirectory(barite)
|
||||
add_subdirectory(dolo)
|
||||
add_subdirectory(surfex)
|
||||
add_subdirectory(dolo)
|
||||
@ -1,19 +1,21 @@
|
||||
# Create a list of files
|
||||
set(bench_files
|
||||
barite_200.R
|
||||
barite_het.R
|
||||
)
|
||||
|
||||
set(runtime_files
|
||||
barite_200_rt.R
|
||||
barite_het_rt.R
|
||||
)
|
||||
|
||||
set(weight_files
|
||||
barite_trained.weights.h5)
|
||||
|
||||
# add_custom_target(barite_bench DEPENDS ${bench_files} ${runtime_files})
|
||||
|
||||
ADD_BENCH_TARGET(barite_bench
|
||||
bench_files
|
||||
runtime_files
|
||||
runtime_files
|
||||
weight_files
|
||||
"barite"
|
||||
)
|
||||
|
||||
|
||||
@ -1,48 +1,82 @@
|
||||
## load a pretrained model from tensorflow file
|
||||
## Use the global variable "ai_surrogate_base_path" when using file paths
|
||||
## relative to the input script
|
||||
initiate_model <- function() {
|
||||
init_model <- normalizePath(paste0(ai_surrogate_base_path,
|
||||
"model_min_max_float64.keras"))
|
||||
return(load_model_tf(init_model))
|
||||
}
|
||||
|
||||
scale_min_max <- function(x, min, max, backtransform) {
|
||||
if (backtransform) {
|
||||
return((x * (max - min)) + min)
|
||||
} else {
|
||||
return((x - min) / (max - min))
|
||||
}
|
||||
if (backtransform) {
|
||||
return((x * (max - min)) + min)
|
||||
} else {
|
||||
return((x - min) / (max - min))
|
||||
}
|
||||
}
|
||||
|
||||
preprocess <- function(df, backtransform = FALSE, outputs = FALSE) {
|
||||
minmax_file <- normalizePath(paste0(ai_surrogate_base_path,
|
||||
"min_max_bounds.rds"))
|
||||
global_minmax <- readRDS(minmax_file)
|
||||
for (column in colnames(df)) {
|
||||
df[column] <- lapply(df[column],
|
||||
scale_min_max,
|
||||
global_minmax$min[column],
|
||||
global_minmax$max[column],
|
||||
backtransform)
|
||||
}
|
||||
return(df)
|
||||
scale_standardizer <- function(x, mean, scale, backtransform) {
|
||||
if(backtransform){
|
||||
return(x * scale + mean)
|
||||
}
|
||||
else{
|
||||
return((x-mean) / scale)
|
||||
}
|
||||
}
|
||||
|
||||
standard <- list(mean = c(H = 111.01243361730982, O= 55.50673140754027, Ba= 0.0016161137065825058,
|
||||
Cl= 0.0534503766678322, S=0.00012864849674669584, Sr=0.0252377348949622,
|
||||
Barite_kin=0.05292312117000998, Celestite_kin=0.9475491659328229),
|
||||
scale = c(H=1.0, O=0.00048139729680698453, Ba=0.008945717576237102, Cl=0.03587363709464328,
|
||||
S=0.00012035100591827131, Sr=0.01523052668095922, Barite_kin=0.21668648247230615,
|
||||
Celestite_kin=0.21639449682671968))
|
||||
|
||||
minmax <- list(min = c(H = 111.012433592824, O = 55.5062185549492, Charge = -3.1028354471876e-08,
|
||||
Ba = 1.87312878574393e-141, Cl = 0, `S(6)` = 4.24227510643685e-07,
|
||||
Sr = 0.00049382996130541, Barite = 0.000999542409828586, Celestite = 0.244801877115968),
|
||||
max = c(H = 111.012433679682, O = 55.5087003521685, Charge = 5.27666636082035e-07,
|
||||
Ba = 0.0908849779513762, Cl = 0.195697626449355, `S(6)` = 0.000620774752665846,
|
||||
Sr = 0.0558680070692722, Barite = 0.756779139057097, Celestite = 1.00075422160624
|
||||
))
|
||||
|
||||
ai_surrogate_species_input = c("H", "O", "Ba", "Cl", "S", "Sr", "Barite_kin", "Celestite_kin")
|
||||
ai_surrogate_species_output = c("O", "Ba", "S", "Sr", "Barite_kin", "Celestite_kin")
|
||||
|
||||
|
||||
threshold <- list(species = "Cl", value = 2E-10)
|
||||
|
||||
preprocess <- function(df) {
|
||||
if (!is.data.frame(df))
|
||||
df <- as.data.frame(df, check.names = FALSE)
|
||||
|
||||
as.data.frame(lapply(colnames(df),
|
||||
function(x) scale_standardizer(x=df[x],
|
||||
mean=standard$mean[x],
|
||||
scale=standard$scale[x],
|
||||
backtransform=FALSE)),
|
||||
check.names = FALSE)
|
||||
}
|
||||
|
||||
postprocess <- function(df) {
|
||||
if (!is.data.frame(df))
|
||||
df <- as.data.frame(df, check.names = FALSE)
|
||||
|
||||
as.data.frame(lapply(colnames(df),
|
||||
function(x) scale_standardizer(x=df[x],
|
||||
mean=standard$mean[x],
|
||||
scale=standard$scale[x],
|
||||
backtransform=TRUE)),
|
||||
check.names = FALSE)
|
||||
}
|
||||
|
||||
mass_balance <- function(predictors, prediction) {
|
||||
dBa <- abs(prediction$Ba + prediction$Barite -
|
||||
predictors$Ba - predictors$Barite)
|
||||
dSr <- abs(prediction$Sr + prediction$Celestite -
|
||||
predictors$Sr - predictors$Celestite)
|
||||
return(dBa + dSr)
|
||||
dBa <- abs(prediction$Ba + prediction$Barite_kin -
|
||||
predictors$Ba - predictors$Barite_kin)
|
||||
dSr <- abs(prediction$Sr + prediction$Celestite_kin -
|
||||
predictors$Sr - predictors$Celestite_kin)
|
||||
dS <- abs(prediction$S + prediction$Celestite_kin + prediction$Barite_kin -
|
||||
predictors$S - predictors$Celestite_kin - predictors$Barite_kin)
|
||||
return(dBa + dSr + dS)
|
||||
}
|
||||
|
||||
validate_predictions <- function(predictors, prediction) {
|
||||
epsilon <- 3e-5
|
||||
mb <- mass_balance(predictors, prediction)
|
||||
msgm("Mass balance mean:", mean(mb))
|
||||
msgm("Mass balance variance:", var(mb))
|
||||
msgm("Rows where mass balance meets threshold", epsilon, ":",
|
||||
sum(mb < epsilon))
|
||||
return(mb < epsilon)
|
||||
}
|
||||
epsilon <- 1E-5
|
||||
mb <- mass_balance(predictors, prediction)
|
||||
msgm("Mass balance mean:", mean(mb))
|
||||
msgm("Mass balance variance:", var(mb))
|
||||
ret <- mb < epsilon
|
||||
msgm("Rows where mass balance meets threshold", epsilon, ":",
|
||||
sum(ret), "/", nrow(predictors))
|
||||
return(ret)
|
||||
}
|
||||
Binary file not shown.
@ -1,32 +0,0 @@
|
||||
grid_def <- matrix(c(2, 3), nrow = 2, ncol = 5)
|
||||
|
||||
# Define grid configuration for POET model
|
||||
grid_setup <- list(
|
||||
pqc_in_file = "./barite_het.pqi",
|
||||
pqc_db_file = "./db_barite.dat", # Path to the database file for Phreeqc
|
||||
grid_def = grid_def, # Definition of the grid, containing IDs according to the Phreeqc input script
|
||||
grid_size = c(ncol(grid_def), nrow(grid_def)), # Size of the grid in meters
|
||||
constant_cells = c() # IDs of cells with constant concentration
|
||||
)
|
||||
|
||||
diffusion_setup <- list(
|
||||
boundaries = list(
|
||||
"W" = list(
|
||||
"type" = rep("constant", nrow(grid_def)),
|
||||
"sol_id" = rep(4, nrow(grid_def)),
|
||||
"cell" = seq_len(nrow(grid_def))
|
||||
)
|
||||
),
|
||||
alpha_x = 1e-6,
|
||||
alpha_y = matrix(runif(10, 1e-8, 1e-7),
|
||||
nrow = nrow(grid_def),
|
||||
ncol = ncol(grid_def)
|
||||
)
|
||||
)
|
||||
|
||||
# Define a setup list for simulation configuration
|
||||
setup <- list(
|
||||
Grid = grid_setup, # Parameters related to the grid structure
|
||||
Diffusion = diffusion_setup, # Parameters related to the diffusion process
|
||||
Chemistry = list()
|
||||
)
|
||||
@ -1,80 +0,0 @@
|
||||
## Initial: everywhere equilibrium with Celestite NB: The aqueous
|
||||
## solution *resulting* from this calculation is to be used as initial
|
||||
## state everywhere in the domain
|
||||
SOLUTION 1
|
||||
units mol/kgw
|
||||
water 1
|
||||
temperature 25
|
||||
pH 7
|
||||
pe 4
|
||||
S(6) 1e-12
|
||||
Sr 1e-12
|
||||
Ba 1e-12
|
||||
Cl 1e-12
|
||||
PURE 1
|
||||
Celestite 0.0 1
|
||||
|
||||
SAVE SOLUTION 2 # <- phreeqc keyword to store and later reuse these results
|
||||
END
|
||||
|
||||
RUN_CELLS
|
||||
-cells 1
|
||||
|
||||
COPY solution 1 2-3
|
||||
|
||||
## Here a 5x2 domain:
|
||||
|
||||
|---+---+---+---+---|
|
||||
-> | 2 | 2 | 2 | 2 | 2 |
|
||||
4 |---+---+---+---+---|
|
||||
-> | 3 | 3 | 3 | 3 | 3 |
|
||||
|---+---+---+---+---|
|
||||
|
||||
## East boundary: "injection" of solution 4. North, W, S boundaries: closed
|
||||
|
||||
## Here the two distinct zones: nr 2 with kinetics Celestite (initial
|
||||
## amount is 0, is then allowed to precipitate) and nr 3 with kinetic
|
||||
## Celestite and Barite (both initially > 0) where the actual
|
||||
## replacement takes place
|
||||
|
||||
#USE SOLUTION 2 <- PHREEQC keyword to reuse the results from previous calculation
|
||||
KINETICS 2
|
||||
Celestite
|
||||
-m 0 # Allowed to precipitate
|
||||
-parms 10.0
|
||||
-tol 1e-9
|
||||
|
||||
END
|
||||
|
||||
#USE SOLUTION 2 <- PHREEQC keyword to reuse the results from previous calculation
|
||||
KINETICS 3
|
||||
Barite
|
||||
-m 0.001
|
||||
-parms 50.
|
||||
-tol 1e-9
|
||||
Celestite
|
||||
-m 1
|
||||
-parms 10.0
|
||||
-tol 1e-9
|
||||
END
|
||||
|
||||
## A BaCl2 solution (nr 4) is "injected" from the left boundary:
|
||||
SOLUTION 4
|
||||
units mol/kgw
|
||||
pH 7
|
||||
water 1
|
||||
temp 25
|
||||
Ba 0.1
|
||||
Cl 0.2
|
||||
END
|
||||
## NB: again, the *result* of the SOLUTION 4 script defines the
|
||||
## concentration of all elements (+charge, tot H, tot O)
|
||||
|
||||
## Ideally, in the initial state SOLUTION 1 we should not have to
|
||||
## define the 4 elemental concentrations (S(6), Sr, Ba and Cl) but
|
||||
## obtain them having run once the scripts with the aqueous solution
|
||||
## resulting from SOLUTION 1 once with KINETICS 2 and once with
|
||||
## KINETICS 3.
|
||||
|
||||
RUN_CELLS
|
||||
-cells 2-4
|
||||
@ -1,4 +0,0 @@
|
||||
list(
|
||||
timesteps = rep(50, 100),
|
||||
store_result = TRUE
|
||||
)
|
||||
BIN
bench/barite/barite_trained.weights.h5
Normal file
BIN
bench/barite/barite_trained.weights.h5
Normal file
Binary file not shown.
@ -1,17 +1,19 @@
|
||||
set(bench_files
|
||||
dolo_inner_large.R
|
||||
dolo_interp.R
|
||||
)
|
||||
|
||||
set(runtime_files
|
||||
dolo_inner_large_rt.R
|
||||
dolo_interp_rt.R
|
||||
dolo_interp_rt_dt2000.R
|
||||
)
|
||||
|
||||
set(weight_files
|
||||
dolomite_trained.weights.h5)
|
||||
|
||||
ADD_BENCH_TARGET(
|
||||
dolo_bench
|
||||
bench_files
|
||||
runtime_files
|
||||
weight_files
|
||||
"dolo"
|
||||
)
|
||||
|
||||
|
||||
74
bench/dolo/dolo_ai_surrogate_input_script.R
Normal file
74
bench/dolo/dolo_ai_surrogate_input_script.R
Normal file
@ -0,0 +1,74 @@
|
||||
scale_min_max <- function(x, min, max, backtransform) {
|
||||
if (backtransform) {
|
||||
return((x * (max - min)) + min)
|
||||
} else {
|
||||
return((x - min) / (max - min))
|
||||
}
|
||||
}
|
||||
|
||||
scale_standardizer <- function(x, mean, scale, backtransform) {
|
||||
if(backtransform){
|
||||
return(x * scale + mean)
|
||||
}
|
||||
else{
|
||||
return((x-mean) / scale)
|
||||
}
|
||||
}
|
||||
|
||||
standard <- list(mean = c(H = 111.0124335959659, O=55.5065739707202, 'C(-4)'=1.5788555695339323e-15, 'C(4)'=0.00011905649680154037,
|
||||
Ca= 0.00012525858032576948, Cl=0.00010368471137502122, Mg=4.5640346338857756e-05, Calcite_kin=0.0001798444527389999,
|
||||
Dolomite_kin=7.6152065281986634e-06),
|
||||
scale = c(H=1.0, O=3.54850912318837e-05, 'C(-4)'=2.675559053860093e-14, 'C(4)'=1.1829735682920146e-05, Ca=1.207381343127647e-05, Cl=0.00024586541554245565,
|
||||
Mg=0.00011794307217698012, Calcite_kin=5.946457663332385e-05, Dolomite_kin=2.688201435907049e-05))
|
||||
|
||||
|
||||
ai_surrogate_species_input = c("H", "O", "C(-4)", "C(4)", "Ca", "Cl", "Mg", "Calcite_kin", "Dolomite_kin")
|
||||
ai_surrogate_species_output = c("H", "O", "C(-4)", "C(4)", "Ca", "Mg", "Calcite_kin", "Dolomite_kin")
|
||||
|
||||
|
||||
threshold <- list(species = "Cl", value = 2E-10)
|
||||
|
||||
preprocess <- function(df) {
|
||||
if (!is.data.frame(df))
|
||||
df <- as.data.frame(df, check.names = FALSE)
|
||||
|
||||
as.data.frame(lapply(colnames(df),
|
||||
function(x) scale_standardizer(x=df[x],
|
||||
mean=standard$mean[x],
|
||||
scale=standard$scale[x],
|
||||
backtransform=FALSE)),
|
||||
check.names = FALSE)
|
||||
}
|
||||
|
||||
postprocess <- function(df) {
|
||||
if (!is.data.frame(df))
|
||||
df <- as.data.frame(df, check.names = FALSE)
|
||||
|
||||
as.data.frame(lapply(colnames(df),
|
||||
function(x) scale_standardizer(x=df[x],
|
||||
mean=standard$mean[x],
|
||||
scale=standard$scale[x],
|
||||
backtransform=TRUE)),
|
||||
check.names = FALSE)
|
||||
}
|
||||
|
||||
mass_balance <- function(predictors, prediction) {
|
||||
dCa <- abs(prediction$Ca + prediction$Calcite_kin + prediction$Dolomite_kin -
|
||||
predictors$Ca - predictors$Calcite_kin - predictors$Dolomite_kin)
|
||||
dC <- abs(prediction$'C(-4)' + prediction$'C(4)' + prediction$Calcite_kin + 2 * prediction$Dolomite_kin
|
||||
- predictors$'C(-4)' - predictors$'C(4)' - predictors$Calcite_kin - 2 * predictors$Dolomite_kin)
|
||||
dMg <- abs(prediction$Mg + prediction$Dolomite_kin -
|
||||
predictors$Mg - predictors$Dolomite_kin)
|
||||
return(dCa + dC + dMg)
|
||||
}
|
||||
|
||||
validate_predictions <- function(predictors, prediction) {
|
||||
epsilon <- 1E-8
|
||||
mb <- mass_balance(predictors, prediction)
|
||||
msgm("Mass balance mean:", mean(mb))
|
||||
msgm("Mass balance variance:", var(mb))
|
||||
ret <- mb < epsilon
|
||||
msgm("Rows where mass balance meets threshold", epsilon, ":",
|
||||
sum(ret))
|
||||
return(ret)
|
||||
}
|
||||
Binary file not shown.
@ -1,115 +0,0 @@
|
||||
rows <- 2000
|
||||
cols <- 1000
|
||||
|
||||
grid_def <- matrix(2, nrow = rows, ncol = cols)
|
||||
|
||||
# Define grid configuration for POET model
|
||||
grid_setup <- list(
|
||||
pqc_in_file = "./dol.pqi",
|
||||
pqc_db_file = "./phreeqc_kin.dat", # Path to the database file for Phreeqc
|
||||
grid_def = grid_def, # Definition of the grid, containing IDs according to the Phreeqc input script
|
||||
grid_size = c(cols, rows) / 100, # Size of the grid in meters
|
||||
constant_cells = c() # IDs of cells with constant concentration
|
||||
)
|
||||
|
||||
bound_size <- 2
|
||||
|
||||
diffusion_setup <- list(
|
||||
inner_boundaries = list(
|
||||
"row" = c(400, 1400, 1600),
|
||||
"col" = c(200, 800, 800),
|
||||
"sol_id" = c(3, 4, 4)
|
||||
),
|
||||
alpha_x = 1e-6,
|
||||
alpha_y = 1e-6
|
||||
)
|
||||
|
||||
check_sign_cal_dol_dht <- function(old, new) {
|
||||
if ((old["Calcite"] == 0) != (new["Calcite"] == 0)) {
|
||||
return(TRUE)
|
||||
}
|
||||
if ((old["Dolomite"] == 0) != (new["Dolomite"] == 0)) {
|
||||
return(TRUE)
|
||||
}
|
||||
return(FALSE)
|
||||
}
|
||||
|
||||
fuzz_input_dht_keys <- function(input) {
|
||||
dht_species <- c(
|
||||
"H" = 3,
|
||||
"O" = 3,
|
||||
"Charge" = 3,
|
||||
"C(4)" = 6,
|
||||
"Ca" = 6,
|
||||
"Cl" = 3,
|
||||
"Mg" = 5,
|
||||
"Calcite" = 4,
|
||||
"Dolomite" = 4
|
||||
)
|
||||
return(input[names(dht_species)])
|
||||
}
|
||||
|
||||
check_sign_cal_dol_interp <- function(to_interp, data_set) {
|
||||
dht_species <- c(
|
||||
"H" = 3,
|
||||
"O" = 3,
|
||||
"Charge" = 3,
|
||||
"C(4)" = 6,
|
||||
"Ca" = 6,
|
||||
"Cl" = 3,
|
||||
"Mg" = 5,
|
||||
"Calcite" = 4,
|
||||
"Dolomite" = 4
|
||||
)
|
||||
data_set <- as.data.frame(do.call(rbind, data_set), check.names = FALSE, optional = TRUE)
|
||||
names(data_set) <- names(dht_species)
|
||||
cal <- (data_set$Calcite == 0) == (to_interp["Calcite"] == 0)
|
||||
dol <- (data_set$Dolomite == 0) == (to_interp["Dolomite"] == 0)
|
||||
|
||||
cal_dol_same_sig <- cal == dol
|
||||
return(rev(which(!cal_dol_same_sig)))
|
||||
}
|
||||
|
||||
check_neg_cal_dol <- function(result) {
|
||||
neg_sign <- (result["Calcite"] < 0) || (result["Dolomite"] < 0)
|
||||
return(neg_sign)
|
||||
}
|
||||
|
||||
# Optional when using Interpolation (example with less key species and custom
|
||||
# significant digits)
|
||||
|
||||
pht_species <- c(
|
||||
"C(4)" = 3,
|
||||
"Ca" = 3,
|
||||
"Mg" = 2,
|
||||
"Calcite" = 2,
|
||||
"Dolomite" = 2
|
||||
)
|
||||
|
||||
chemistry_setup <- list(
|
||||
dht_species = c(
|
||||
"H" = 3,
|
||||
"O" = 3,
|
||||
"Charge" = 3,
|
||||
"C(4)" = 6,
|
||||
"Ca" = 6,
|
||||
"Cl" = 3,
|
||||
"Mg" = 5,
|
||||
"Calcite" = 4,
|
||||
"Dolomite" = 4
|
||||
),
|
||||
pht_species = pht_species,
|
||||
hooks = list(
|
||||
dht_fill = check_sign_cal_dol_dht,
|
||||
dht_fuzz = fuzz_input_dht_keys,
|
||||
interp_pre = check_sign_cal_dol_interp,
|
||||
interp_post = check_neg_cal_dol
|
||||
)
|
||||
)
|
||||
|
||||
# Define a setup list for simulation configuration
|
||||
setup <- list(
|
||||
Grid = grid_setup, # Parameters related to the grid structure
|
||||
Diffusion = diffusion_setup, # Parameters related to the diffusion process
|
||||
Chemistry = chemistry_setup # Parameters related to the chemistry process
|
||||
)
|
||||
@ -1,10 +0,0 @@
|
||||
iterations <- 500
|
||||
dt <- 50
|
||||
|
||||
out_save <- seq(5, iterations, by = 5)
|
||||
|
||||
list(
|
||||
timesteps = rep(dt, iterations),
|
||||
store_result = TRUE,
|
||||
out_save = out_save
|
||||
)
|
||||
@ -7,6 +7,7 @@ grid_def <- matrix(2, nrow = rows, ncol = cols)
|
||||
grid_setup <- list(
|
||||
pqc_in_file = "./dol.pqi",
|
||||
pqc_db_file = "./phreeqc_kin.dat",
|
||||
pqc_with_redox = TRUE,
|
||||
grid_def = grid_def,
|
||||
grid_size = c(5, 2.5),
|
||||
constant_cells = c()
|
||||
@ -120,7 +121,8 @@ chemistry_setup <- list(
|
||||
## dht_fuzz = fuzz_input_dht_keys,
|
||||
interp_pre = check_sign_cal_dol_interp,
|
||||
interp_post = check_neg_cal_dol
|
||||
)
|
||||
),
|
||||
ai_surrogate_input_script = "./dolo_ai_surrogate_input_script.R"
|
||||
)
|
||||
|
||||
## Define a setup list for simulation configuration
|
||||
|
||||
@ -1,10 +0,0 @@
|
||||
iterations <- 2000
|
||||
dt <- 200
|
||||
|
||||
out_save <- c(1, 10, 20, seq(40, iterations, by = 40))
|
||||
|
||||
list(
|
||||
timesteps = rep(dt, iterations),
|
||||
store_result = TRUE,
|
||||
out_save = out_save
|
||||
)
|
||||
BIN
bench/dolo/dolomite_trained.weights.h5
Normal file
BIN
bench/dolo/dolomite_trained.weights.h5
Normal file
Binary file not shown.
@ -1 +1 @@
|
||||
Subproject commit 2dd2b8881d6fe27b08a259d48ee8bca6188f049a
|
||||
Subproject commit 3631ecb08c10804ae9da158074bc9299fea7b810
|
||||
17
naaice_runs.md
Normal file
17
naaice_runs.md
Normal file
@ -0,0 +1,17 @@
|
||||
# NAAICE Measurements
|
||||
|
||||
### Barite
|
||||
|
||||
1. With AI Surrogate and single retraining
|
||||
```NAA_SPEC=10.3.10.42:12345:1:5 mpirun -n 24 ./poet --ai --ai-backend=2 --fn=1 ./barite_200_rt.R ./barite_200.qs2 output_barite_ai```
|
||||
|
||||
2. With copy/expert knowledge functionality
|
||||
```mpirun -n 24 ./poet -c ./barite_200_rt.R ./barite_200.qs2 output_barite_expert_knowledge```
|
||||
|
||||
### Dolomite
|
||||
|
||||
1. With AI Surrogate and single retraining
|
||||
```NAA_SPEC=10.3.10.42:12345:2:5 mpirun -n 24 ./poet --ai --ai-backend=2 --fn=2 ./dolo_interp_rt_dt2000.R ./dolo_interp.qs2 output_dolo_ai```
|
||||
|
||||
2. With copy/expert knowledge functionality
|
||||
```mpirun -n 24 ./poet -c ./dolo_interp_rt_dt2000.R ./dolo_interp.qs2 output_dolo_expert_knowledge```
|
||||
226
poet.yml
226
poet.yml
@ -1,4 +1,4 @@
|
||||
name: poet
|
||||
name: zib
|
||||
channels:
|
||||
- defaults
|
||||
- conda-forge
|
||||
@ -7,144 +7,126 @@ channels:
|
||||
dependencies:
|
||||
- _libgcc_mutex=0.1=conda_forge
|
||||
- _openmp_mutex=4.5=2_gnu
|
||||
- _r-mutex=1.0.1=anacondar_1
|
||||
- attr=2.5.2=h39aace5_0
|
||||
- binutils_impl_linux-64=2.43=h4bf12b8_5
|
||||
- binutils_linux-64=2.43=h4852527_5
|
||||
- bzip2=1.0.8=h5eee18b_6
|
||||
- c-ares=1.34.5=hb9d3cd8_0
|
||||
- ca-certificates=2025.11.12=hbd8a1cb_0
|
||||
- cached-property=1.5.2=py_0
|
||||
- cmake=4.1.2=hc946e07_0
|
||||
- eigen=3.4.0=h171cf75_1
|
||||
- expat=2.7.1=h6a678d5_0
|
||||
- binutils_impl_linux-64=2.45=default_hfdba357_104
|
||||
- binutils_linux-64=2.28.1=he4fe6c7_1
|
||||
- bwidget=1.10.1=ha770c72_1
|
||||
- bzip2=1.0.8=hda65f42_8
|
||||
- c-ares=1.34.6=hb03c661_0
|
||||
- ca-certificates=2025.12.2=h06a4308_0
|
||||
- cairo=1.18.4=h3394656_0
|
||||
- cmake=4.2.1=hc85cc9f_0
|
||||
- conda-gcc-specs=13.3.0=h1ae1ff9_2
|
||||
- curl=8.17.0=h4e3cde8_1
|
||||
- font-ttf-dejavu-sans-mono=2.37=hab24e00_0
|
||||
- font-ttf-inconsolata=3.000=h77eed37_0
|
||||
- font-ttf-source-code-pro=2.038=h77eed37_0
|
||||
- font-ttf-ubuntu=0.83=h77eed37_3
|
||||
- fontconfig=2.15.0=h7e30c49_1
|
||||
- fonts-conda-ecosystem=1=0
|
||||
- fonts-conda-forge=1=hc364b38_1
|
||||
- freetype=2.14.1=ha770c72_0
|
||||
- fribidi=1.0.16=hb03c661_0
|
||||
- gcc=13.3.0=h9576a4e_2
|
||||
- gcc_impl_linux-64=13.3.0=h1e990d8_2
|
||||
- gcc_linux-64=13.3.0=h6f18a23_11
|
||||
- gettext=0.21.0=hedfda30_2
|
||||
- gfortran_impl_linux-64=13.3.0=h84c1745_2
|
||||
- graphite2=1.3.14=hecca717_2
|
||||
- gsl=2.7=he838d99_0
|
||||
- gxx=13.3.0=h9576a4e_2
|
||||
- gxx_impl_linux-64=13.3.0=hae580e1_2
|
||||
- gxx_linux-64=13.3.0=hb14504d_11
|
||||
- h5py=3.14.0=nompi_py313hfaf8fd4_100
|
||||
- hdf5=1.14.6=nompi_h6e4c0c1_103
|
||||
- harfbuzz=12.2.0=h15599e2_0
|
||||
- hdf5=1.14.6=nompi_h1b119a7_104
|
||||
- highfive=2.10.1=he6560a2_2
|
||||
- icu=73.1=h6a678d5_0
|
||||
- kernel-headers_linux-64=3.10.0=he073ed8_18
|
||||
- krb5=1.21.3=h143b758_0
|
||||
- ld_impl_linux-64=2.43=h712a8e2_5
|
||||
- icu=75.1=he02047a_0
|
||||
- kernel-headers_linux-64=4.18.0=he073ed8_8
|
||||
- keyutils=1.6.3=hb9d3cd8_0
|
||||
- krb5=1.21.3=h659f571_0
|
||||
- ld_impl_linux-64=2.45=default_hbd61a6d_104
|
||||
- lerc=4.0.0=h0aef613_1
|
||||
- libaec=1.1.4=h3f801dc_0
|
||||
- libblas=3.9.0=32_h59b9bed_openblas
|
||||
- libblas=3.11.0=4_h4a7cf45_openblas
|
||||
- libcap=2.77=h3ff7636_0
|
||||
- libcblas=3.9.0=32_he106b2a_openblas
|
||||
- libcurl=8.16.0=heebcbe5_0
|
||||
- libedit=3.1.20230828=h5eee18b_0
|
||||
- libev=4.33=h7f8727e_1
|
||||
- libevent=2.1.12=hf998b51_1
|
||||
- libexpat=2.7.1=hecca717_0
|
||||
- libfabric=2.3.1=ha770c72_1
|
||||
- libfabric1=2.3.1=h6c8fc0a_1
|
||||
- libffi=3.4.4=h6a678d5_1
|
||||
- libgcc=15.1.0=h767d61c_2
|
||||
- libcblas=3.11.0=4_h0358290_openblas
|
||||
- libcurl=8.17.0=h4e3cde8_1
|
||||
- libdeflate=1.25=h17f619e_0
|
||||
- libedit=3.1.20250104=pl5321h7949ede_0
|
||||
- libev=4.33=hd590300_2
|
||||
- libexpat=2.7.3=hecca717_0
|
||||
- libffi=3.5.2=h9ec8514_0
|
||||
- libfreetype=2.14.1=ha770c72_0
|
||||
- libfreetype6=2.14.1=h73754d4_0
|
||||
- libgcc=15.2.0=he0feb66_16
|
||||
- libgcc-devel_linux-64=13.3.0=hc03c837_102
|
||||
- libgcc-ng=15.1.0=h69a702a_2
|
||||
- libgfortran=15.1.0=h69a702a_2
|
||||
- libgfortran-ng=15.1.0=h69a702a_2
|
||||
- libgfortran5=15.1.0=hcea5267_2
|
||||
- libgomp=15.1.0=h767d61c_2
|
||||
- libhwloc=2.12.1=default_h3d81e11_1000
|
||||
- libiconv=1.16=h5eee18b_3
|
||||
- libidn2=2.3.8=hf80d704_0
|
||||
- liblapack=3.9.0=32_h7ac8fdf_openblas
|
||||
- libgcc-ng=15.2.0=h69a702a_16
|
||||
- libgfortran=15.2.0=h69a702a_16
|
||||
- libgfortran-ng=15.2.0=h69a702a_16
|
||||
- libgfortran5=15.2.0=h68bc16d_16
|
||||
- libglib=2.86.3=h6548e54_0
|
||||
- libgomp=15.2.0=he0feb66_16
|
||||
- libiconv=1.18=h3b78370_2
|
||||
- libjpeg-turbo=3.1.2=hb03c661_0
|
||||
- liblapack=3.11.0=4_h47877c9_openblas
|
||||
- liblzma=5.8.1=hb9d3cd8_2
|
||||
- liblzma-devel=5.8.1=hb9d3cd8_2
|
||||
- libmpdec=4.0.0=h5eee18b_0
|
||||
- libnghttp2=1.64.0=h161d5f1_0
|
||||
- libnghttp2=1.67.0=had1ee68_0
|
||||
- libnl=3.11.0=hb9d3cd8_0
|
||||
- libopenblas=0.3.30=pthreads_h94d23a6_0
|
||||
- libpmix=5.0.8=h4bd6b51_2
|
||||
- libopenblas=0.3.30=pthreads_h94d23a6_4
|
||||
- libpng=1.6.53=h421ea60_0
|
||||
- libsanitizer=13.3.0=he8ea267_2
|
||||
- libsqlite=3.50.2=h6cd9bfd_0
|
||||
- libssh2=1.11.1=hcf80075_0
|
||||
- libstdcxx=15.1.0=h8f9b012_2
|
||||
- libstdcxx=15.2.0=h934c35e_16
|
||||
- libstdcxx-devel_linux-64=13.3.0=hc03c837_102
|
||||
- libstdcxx-ng=15.1.0=h4852527_2
|
||||
- libstdcxx-ng=15.2.0=hdf11a46_16
|
||||
- libsystemd0=257.10=hd0affe5_2
|
||||
- libtiff=4.7.1=h9d88235_1
|
||||
- libudev1=257.10=hd0affe5_2
|
||||
- libunistring=1.3=hb25bd0a_0
|
||||
- libuuid=2.38.1=h0b41bf4_0
|
||||
- libuv=1.48.0=h5eee18b_0
|
||||
- libxcb=1.17.0=h9b100fa_0
|
||||
- libxml2=2.13.9=h2c43086_0
|
||||
- libuuid=2.41.2=h5347b49_1
|
||||
- libuv=1.51.0=hb03c661_1
|
||||
- libwebp-base=1.6.0=hd42ef1d_0
|
||||
- libxcb=1.17.0=h8a09558_0
|
||||
- libzlib=1.3.1=hb9d3cd8_2
|
||||
- mpi=1.0.1=openmpi
|
||||
- make=4.4.1=hb9d3cd8_2
|
||||
- ncurses=6.5=h2d0b736_3
|
||||
- numpy=2.3.0=py313h17eae1a_0
|
||||
- openmpi=5.0.8=h2fe1745_108
|
||||
- openssl=3.6.0=h26f9b46_0
|
||||
- pip=25.1=pyhc872135_2
|
||||
- pthread-stubs=0.3=h0ce48e5_1
|
||||
- pybind11=2.13.6=py313hdb19cb5_1
|
||||
- pybind11-global=2.13.6=py313hdb19cb5_1
|
||||
- python=3.13.2=hf636f53_101_cp313
|
||||
- python_abi=3.13=0_cp313
|
||||
- pango=1.56.4=hadf4263_0
|
||||
- pcre2=10.47=haa7fec5_0
|
||||
- pixman=0.46.4=h54a6638_1
|
||||
- pthread-stubs=0.4=hb9d3cd8_1002
|
||||
- r=4.5=r45hd8ed1ab_1009
|
||||
- r-base=4.5.2=h835929b_2
|
||||
- r-boot=1.3_32=r45hc72bb7e_1
|
||||
- r-class=7.3_23=r45h54b55ab_1
|
||||
- r-cluster=2.1.8.1=r45heaba542_1
|
||||
- r-codetools=0.2_20=r45hc72bb7e_2
|
||||
- r-foreign=0.8_90=r45h54b55ab_1
|
||||
- r-kernsmooth=2.23_26=r45ha0a88a1_1
|
||||
- r-lattice=0.22_7=r45h54b55ab_1
|
||||
- r-mass=7.3_65=r45h54b55ab_0
|
||||
- r-matrix=1.7_4=r45h0e4624f_1
|
||||
- r-mgcv=1.9_4=r45h0e4624f_0
|
||||
- r-nlme=3.1_168=r45heaba542_1
|
||||
- r-nnet=7.3_20=r45h54b55ab_1
|
||||
- r-recommended=4.5=r45hd8ed1ab_1008
|
||||
- r-rpart=4.1.24=r45h54b55ab_1
|
||||
- r-spatial=7.3_18=r45h54b55ab_1
|
||||
- r-survival=3.8_3=r45h54b55ab_1
|
||||
- rdma-core=60.0=hecca717_0
|
||||
- readline=8.2=h5eee18b_0
|
||||
- rhash=1.4.6=ha914fed_0
|
||||
- setuptools=78.1.1=py313h06a4308_0
|
||||
- sqlite=3.31.1=h7b6447c_0
|
||||
- sysroot_linux-64=2.17=h0157908_18
|
||||
- tk=8.6.13=noxft_hd72426e_102
|
||||
- ucc=1.6.0=hb729f83_0
|
||||
- ucx=1.19.0=h63b5c0b_5
|
||||
- wheel=0.45.1=py313h06a4308_0
|
||||
- xorg-libx11=1.8.12=h9b100fa_1
|
||||
- xorg-libxau=1.0.12=h9b100fa_0
|
||||
- xorg-libxdmcp=1.1.5=h9b100fa_0
|
||||
- xorg-xorgproto=2024.1=h5eee18b_1
|
||||
- xz=5.8.1=hbcc6ac9_2
|
||||
- xz-gpl-tools=5.8.1=hbcc6ac9_2
|
||||
- xz-tools=5.8.1=hb9d3cd8_2
|
||||
- zlib=1.3.1=hb9d3cd8_2
|
||||
- zstd=1.5.7=hb8e6e7a_2
|
||||
- pip:
|
||||
- absl-py==2.3.1
|
||||
- ai-benchmarks-utils==0.1.0
|
||||
- astunparse==1.6.3
|
||||
- certifi==2025.11.12
|
||||
- charset-normalizer==3.4.4
|
||||
- flatbuffers==25.9.23
|
||||
- gast==0.6.0
|
||||
- google-pasta==0.2.0
|
||||
- grpcio==1.76.0
|
||||
- idna==3.11
|
||||
- joblib==1.5.2
|
||||
- keras==3.12.0
|
||||
- libclang==18.1.1
|
||||
- markdown==3.10
|
||||
- markdown-it-py==4.0.0
|
||||
- markupsafe==3.0.3
|
||||
- mdurl==0.1.2
|
||||
- ml-dtypes==0.5.4
|
||||
- namex==0.1.0
|
||||
- opt-einsum==3.4.0
|
||||
- optree==0.18.0
|
||||
- packaging==25.0
|
||||
- pandas==2.3.3
|
||||
- pillow==12.0.0
|
||||
- protobuf==6.33.1
|
||||
- pygments==2.19.2
|
||||
- python-dateutil==2.9.0.post0
|
||||
- pytz==2025.2
|
||||
- requests==2.32.5
|
||||
- rich==14.2.0
|
||||
- scikit-learn==1.7.2
|
||||
- scipy==1.16.3
|
||||
- six==1.17.0
|
||||
- tensorboard==2.20.0
|
||||
- tensorboard-data-server==0.7.2
|
||||
- tensorflow==2.20.0
|
||||
- termcolor==3.2.0
|
||||
- threadpoolctl==3.6.0
|
||||
- typing-extensions==4.15.0
|
||||
- tzdata==2025.2
|
||||
- urllib3==2.5.0
|
||||
- werkzeug==3.1.3
|
||||
- wrapt==2.0.1
|
||||
prefix: /mnt/scratch/miniconda3/envs/poet-dummy
|
||||
- readline=8.2=h8c095d6_2
|
||||
- rhash=1.4.6=hb9d3cd8_1
|
||||
- sed=4.9=h6688a6e_0
|
||||
- sysroot_linux-64=2.28=h4ee821c_8
|
||||
- tk=8.6.13=noxft_ha0e22de_103
|
||||
- tktable=2.10=h8d826fa_7
|
||||
- tzdata=2025b=h78e105d_0
|
||||
- xorg-libice=1.1.2=hb9d3cd8_0
|
||||
- xorg-libsm=1.2.6=he73a12e_0
|
||||
- xorg-libx11=1.8.12=h4f16b4b_0
|
||||
- xorg-libxau=1.0.12=hb03c661_1
|
||||
- xorg-libxdmcp=1.1.5=hb03c661_1
|
||||
- xorg-libxext=1.3.6=hb9d3cd8_0
|
||||
- xorg-libxrender=0.9.12=hb9d3cd8_0
|
||||
- xorg-libxt=1.3.1=hb9d3cd8_0
|
||||
- zstd=1.5.7=hb78ec9c_6
|
||||
prefix: /mnt/scratch/miniconda3/envs/zib
|
||||
|
||||
105
src/poet.cpp
105
src/poet.cpp
@ -46,7 +46,6 @@
|
||||
|
||||
#include <Model.hpp>
|
||||
#include <NAABackend.hpp>
|
||||
#include <PythonBackend.hpp>
|
||||
#include <TrainingBackend.hpp>
|
||||
#include <TrainingData.hpp>
|
||||
|
||||
@ -166,6 +165,9 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) {
|
||||
app.add_flag("-c,--copy-non-reactive", params.copy_non_reactive_regions,
|
||||
"Copy non-reactive regions instead of computing them");
|
||||
|
||||
app.add_option("-f,--fn", params.function_code, "Function code for the NAA")
|
||||
->check(CLI::PositiveNumber);
|
||||
|
||||
app.add_flag("--rds", params.as_rds,
|
||||
"Save output as .rds file instead of default .qs2");
|
||||
|
||||
@ -308,11 +310,17 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps());
|
||||
|
||||
std::unique_ptr<AIContext> ai_ctx = nullptr;
|
||||
size_t retrain_counter = 0;
|
||||
size_t field_size = 0;
|
||||
|
||||
if (params.ai) {
|
||||
|
||||
ai_ctx = std::make_unique<AIContext>(
|
||||
"/mnt/scratch/signer/poet/bench/barite/barite_trained.weights.h5");
|
||||
if (params.function_code == 1) {
|
||||
ai_ctx = std::make_unique<AIContext>("./barite_trained.weights.h5");
|
||||
} else if (params.function_code == 2) {
|
||||
ai_ctx = std::make_unique<AIContext>("./dolomite_trained.weights.h5");
|
||||
}
|
||||
|
||||
R.parseEval(
|
||||
"mean <- as.numeric(standard$mean[ai_surrogate_species_output])");
|
||||
R.parseEval(
|
||||
@ -321,17 +329,19 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
std::vector<float> mean = R["mean"];
|
||||
std::vector<float> scale = R["scale"];
|
||||
|
||||
field_size = chem.getField().GetRequestedVecSize();
|
||||
std::cout << field_size << std::endl;
|
||||
|
||||
ai_ctx->scaler.set_scaler(mean, scale);
|
||||
|
||||
// initialzie training backens only if retraining is desired
|
||||
if (params.ai_backend == PYTHON_BACKEND) {
|
||||
MSG("AI Surrogate with Python/keras backend enabled.")
|
||||
ai_ctx->training_backend =
|
||||
std::make_unique<PythonBackend<ai_type_t>>(4 * params.batch_size);
|
||||
std::cerr << "Not implemented" << std::endl;
|
||||
} else if (params.ai_backend == NAA_BACKEND) {
|
||||
MSG("AI Surrogate with NAA backend enabled.")
|
||||
ai_ctx->training_backend =
|
||||
std::make_unique<NAABackend<ai_type_t>>(4 * params.batch_size);
|
||||
std::make_unique<NAABackend<ai_type_t>>(4 * field_size);
|
||||
}
|
||||
|
||||
if (!params.disable_retraining) {
|
||||
@ -339,7 +349,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
ai_ctx->design_buffer, ai_ctx->results_buffer, ai_ctx->model,
|
||||
ai_ctx->meta_params, ai_ctx->scaler, ai_ctx->data_semaphore_write,
|
||||
ai_ctx->data_semaphore_read, ai_ctx->model_semaphore,
|
||||
ai_ctx->training_is_running, 1);
|
||||
ai_ctx->training_is_running, params.function_code);
|
||||
}
|
||||
}
|
||||
|
||||
@ -361,6 +371,21 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
/* run transport */
|
||||
diffusion.simulate(dt);
|
||||
|
||||
// validity_vector:
|
||||
// vector with length = number of grid cells (1 indicates
|
||||
// that values are copied into the next time step (no chemical reaction)
|
||||
// or the ai surrogate prediction was good enoug. In both cases the
|
||||
// workers skip the exact simulation with PHREEQC) predictor_idx
|
||||
// ai_validity_vector
|
||||
// predictors:
|
||||
// data frame with elements that are used as input for the ai surrogate
|
||||
// model. The data frame can be smaller than the grid size if
|
||||
// copy_non_reactive_regions option is enabled. In this case only the
|
||||
// reactive data are used for ai prediction.
|
||||
// targets:
|
||||
// data frame with elements that are used as ai outputs for the
|
||||
// retraining step.
|
||||
|
||||
if (params.ai || params.copy_non_reactive_regions) {
|
||||
|
||||
chem.getField().update(diffusion.getField());
|
||||
@ -374,42 +399,39 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
R.parseEval("validity_vector <- rep(FALSE, nrow(field))");
|
||||
|
||||
if (params.copy_non_reactive_regions) {
|
||||
R.parseEval("validity_vector <- field$Cl < 1e-14");
|
||||
R.parseEval(
|
||||
"validity_vector <- field[[threshold$species]] < threshold$value");
|
||||
}
|
||||
}
|
||||
|
||||
// MSG("Chemistry start");
|
||||
MSG("Chemistry start");
|
||||
if (params.ai) {
|
||||
double ai_start_t = MPI_Wtime();
|
||||
|
||||
// deep copy field
|
||||
R.parseEval("predictors <- data.frame(field)");
|
||||
R.parseEval("predictors <- field");
|
||||
// get only ai related species
|
||||
R.parseEval("predictors <- predictors[ai_surrogate_species_input]");
|
||||
|
||||
// remove already copied values
|
||||
R.parseEval("predictors <- predictors[!validity_vector,]");
|
||||
|
||||
R.parseEval(
|
||||
"print(paste('Length of predictors:', length(predictors$H)))");
|
||||
|
||||
// store row names of predictors
|
||||
R.parseEval("predictor_idx <- row.names(predictors)");
|
||||
|
||||
R.parseEval("print(head(predictors))");
|
||||
R.parseEval("predictors_scaled <- preprocess(predictors)");
|
||||
|
||||
std::vector<std::vector<float>> predictors_scaled =
|
||||
R["predictors_scaled"];
|
||||
|
||||
std::vector<float> predictions_scaled =
|
||||
ai_ctx->model.predict(predictors_scaled, params.batch_size,
|
||||
ai_ctx->model_semaphore); // features per cell
|
||||
|
||||
int n_samples = R.parseEval("nrow(predictors)");
|
||||
int n_output_features = ai_ctx->model.weight_matrices.back().cols();
|
||||
std::cout << "n_output_features: " << n_output_features << std::endl;
|
||||
std::vector<double> predictions_scaled_double(predictions_scaled.begin(),
|
||||
predictions_scaled.end());
|
||||
|
||||
R["TMP"] = predictions_scaled_double;
|
||||
R["n_samples"] = n_samples;
|
||||
R["n_output"] = n_output_features;
|
||||
@ -417,26 +439,20 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
R.parseEval("predictions_scaled <- setNames(data.frame(matrix(TMP, "
|
||||
"nrow=n_samples, ncol=n_output, byrow=TRUE)), "
|
||||
"ai_surrogate_species_output)");
|
||||
// R.parseEval("print(head(predictions_scaled))");
|
||||
|
||||
R.parseEval("predictions <- postprocess(predictions_scaled)");
|
||||
// R.parseEval("print(head(predictions))");
|
||||
|
||||
MSG("AI Validation");
|
||||
|
||||
R.parseEval("ai_validity_vector <- validate_predictions(predictors, "
|
||||
"predictions) ");
|
||||
|
||||
R.parseEval("print(length(predictor_idx))");
|
||||
R.parseEval("print(length(ai_validity_vector))");
|
||||
|
||||
// get only indices where prediction was valid
|
||||
R.parseEval("predictor_idx <- predictor_idx[ai_validity_vector]");
|
||||
|
||||
// set in global validity vector all elements to true, where prediction
|
||||
// was possible
|
||||
R.parseEval("validity_vector[predictor_idx] <- TRUE");
|
||||
|
||||
R.parseEval("print(head(validity_vector))");
|
||||
R.parseEval("validity_vector[as.numeric(predictor_idx)] <- TRUE");
|
||||
|
||||
MSG("AI TempField");
|
||||
// maybe row.names was overwritten by function calls ??
|
||||
@ -444,7 +460,8 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
// subset predictions to ai_validity_vector == TRUE
|
||||
R.parseEval("predictions <- predictions[ai_validity_vector,]");
|
||||
// merge predicted values into field stored in R
|
||||
R.parseEval("field[row.names(predictions),ai_surrogate_species_output] "
|
||||
R.parseEval("field[as.numeric(row.names(predictions)),ai_surrogate_"
|
||||
"species_output] "
|
||||
"<- predictions");
|
||||
|
||||
MSG("AI Set Field");
|
||||
@ -461,7 +478,6 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
|
||||
if (params.copy_non_reactive_regions || params.ai) {
|
||||
MSG("Set copied or predicted values for the workers");
|
||||
|
||||
R.parseEval(
|
||||
"print(paste('Number of valid cells:', sum(validity_vector)))");
|
||||
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
|
||||
@ -481,15 +497,11 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
|
||||
R.parseEval("predictors_retraining <- "
|
||||
"get_invalid_values(predictors_scaled, ai_validity_vector)");
|
||||
R.parseEval("print(head(predictors_retraining))");
|
||||
R.parseEval("targets <- targets[predictor_idx, ]");
|
||||
R.parseEval("targets_retraining <- "
|
||||
"get_invalid_values(targets[ai_surrogate_species_output], "
|
||||
"ai_validity_vector)");
|
||||
R.parseEval("print(length(predictors_scaled$H))");
|
||||
R.parseEval("print(length(ai_validity_vector))");
|
||||
R.parseEval("targets <- "
|
||||
"targets[as.numeric(row.names(predictors_retraining)), "
|
||||
"ai_surrogate_species_output]");
|
||||
|
||||
R.parseEval("targets_retraining <- preprocess(targets_retraining)");
|
||||
R.parseEval("targets_retraining <- preprocess(targets)");
|
||||
|
||||
std::vector<std::vector<float>> predictors_retraining =
|
||||
R["predictors_retraining"];
|
||||
@ -500,14 +512,11 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
|
||||
ai_ctx->data_semaphore_write.acquire();
|
||||
|
||||
std::cout << "size of predictors " << predictors_retraining[0].size()
|
||||
<< std::endl;
|
||||
std::cout << "size of targets " << targets_retraining[0].size()
|
||||
<< std::endl;
|
||||
|
||||
ai_ctx->design_buffer.addData(predictors_retraining);
|
||||
ai_ctx->results_buffer.addData(targets_retraining);
|
||||
|
||||
if (predictors_retraining[0].size() > 0 &&
|
||||
targets_retraining[0].size() > 0 && retrain_counter == 0) {
|
||||
ai_ctx->design_buffer.addData(predictors_retraining);
|
||||
ai_ctx->results_buffer.addData(targets_retraining);
|
||||
}
|
||||
size_t elements_design_buffer =
|
||||
ai_ctx->design_buffer.getSize() /
|
||||
(predictors_retraining.size() * sizeof(float));
|
||||
@ -520,11 +529,11 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
std::cout << "results_buffer_size: " << elements_results_buffer
|
||||
<< std::endl;
|
||||
|
||||
if (elements_design_buffer >=
|
||||
20 * params.batch_size && // TODO: change to 4 * grid_size
|
||||
elements_results_buffer >= 20 * params.batch_size &&
|
||||
ai_ctx->training_is_running == false) {
|
||||
if (elements_design_buffer >= 4 * field_size &&
|
||||
elements_results_buffer >= 4 * field_size &&
|
||||
ai_ctx->training_is_running == false && retrain_counter == 0) {
|
||||
ai_ctx->data_semaphore_read.release();
|
||||
retrain_counter++;
|
||||
} else {
|
||||
ai_ctx->data_semaphore_write.release();
|
||||
}
|
||||
@ -553,7 +562,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
|
||||
std::cout << std::endl;
|
||||
|
||||
if (!params.disable_retraining) {
|
||||
if (params.ai && !params.disable_retraining) {
|
||||
ai_ctx->training_backend->stop_training(ai_ctx->data_semaphore_read);
|
||||
}
|
||||
|
||||
@ -769,7 +778,7 @@ int main(int argc, char *argv[]) {
|
||||
R["out_ext"] = run_params.out_ext;
|
||||
R["out_dir"] = run_params.out_dir;
|
||||
|
||||
if (run_params.ai) {
|
||||
if (run_params.ai || run_params.copy_non_reactive_regions) {
|
||||
/* Incorporate ai surrogate from R */
|
||||
R.parseEvalQ(ai_surrogate_r_library);
|
||||
/* Use dht species for model input and output */
|
||||
|
||||
@ -23,6 +23,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
@ -84,9 +85,11 @@ struct RuntimeParameters {
|
||||
bool ai = false;
|
||||
bool disable_retraining = false;
|
||||
static constexpr std::uint8_t AI_BACKEND_DEFAULT = 1;
|
||||
std::uint8_t ai_backend = 1; // 1 - python, 2 - naa
|
||||
std::uint8_t ai_backend = AI_BACKEND_DEFAULT; // 1 - python, 2 - naa
|
||||
bool train_only_invalid = true;
|
||||
int batch_size = 1000;
|
||||
int batch_size = 2500;
|
||||
static constexpr std::uint8_t DEFAULT_FUNCTION_CODE = 0;
|
||||
std::uint8_t function_code = DEFAULT_FUNCTION_CODE;
|
||||
|
||||
static constexpr bool COPY_NON_REACTIVE_REGIONS = false;
|
||||
bool copy_non_reactive_regions = COPY_NON_REACTIVE_REGIONS;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user