## Mirror of https://git.gfz-potsdam.de/naaice/poet.git
## (synced 2025-12-15 20:38:23 +01:00; 150 lines, 5.4 KiB, R)
## Time-stamp: "Last modified 2024-05-30 13:27:06 delucia"
|
|
|
|
## File paths are built relative to the input script via the global
## "ai_surrogate_base_path". The file name is appended with paste0(), so
## the base path presumably already ends with a path separator.
model_file_path <- local({
  model_file_name <- "barite_50ai_all.keras"
  normalizePath(paste0(ai_surrogate_base_path, model_file_name))
})
|
|
|
|
## Min-max scaling and its inverse.
##
## Forward (backtransform = FALSE): (x - min) / (max - min), mapping
## [min, max] onto [0, 1].
## Inverse (backtransform = TRUE): x * (max - min) + min.
## x may be a numeric vector or a data.frame; min and max follow R's
## usual arithmetic recycling rules.
## A zero-width range (max == min) would divide by zero in the forward
## direction. There the divisor is replaced by 1, so such entries become
## x - min (0 when x == min) instead of NaN/Inf.
scale_min_max <- function(x, min, max, backtransform = FALSE) {
  if (backtransform) {
    return((x * (max - min)) + min)
  }
  span <- max - min
  ## guard against division by zero; NA spans are left untouched
  span[!is.na(span) & span == 0] <- 1
  (x - min) / span
}
|
|
|
|
## Apply decimal logarithms handling negative values.
##
## If every non-missing |value| is below 1, a sign-encoding log is used:
## positive values map to log10(x) (< 0), negative values map to
## -log10(-x) (> 0), and exact zeros stay 0. Safeexp() inverts this
## branch.
## Otherwise log10(abs(x)) is returned. That discards the sign and is
## NOT inverted by Safeexp().
## NOTE(review): confirm back-transformed inputs never take that branch.
## NA/NaN entries are ignored when choosing the branch and stay missing
## in the result.
Safelog <- function (vec) {
  rv <- range(vec, na.rm = TRUE)
  if (max(abs(rv)) < 1) {
    ret <- vec
    ## which() drops NA positions, so missing values do not break the
    ## subscripted assignments below
    ret[which(vec == 0)] <- 0
    pos <- which(vec > 0)
    neg <- which(vec < 0)
    ret[pos] <- log10(vec[pos])
    ret[neg] <- -log10(-vec[neg])
  } else {
    ret <- log10(abs(vec))
  }
  ret
}
|
|
## Inverse of the sign-encoding branch of Safelog().
## Negative inputs u map to 10^u (positive), positive inputs u map to
## -10^(-u) (negative), and zeros stay 0. NA/NaN entries are returned
## unchanged.
Safeexp <- function (vec) {
  ret <- vec
  ## which() drops NA positions; a logical mask containing NA would make
  ## the multi-element assignments below fail
  ret[which(vec == 0)] <- 0
  pos <- which(vec > 0)
  neg <- which(vec < 0)
  ret[pos] <- -10^(-vec[pos])
  ret[neg] <- 10^(vec[neg])
  ret
}
|
|
|
|
##' @title Apply transformations to cols of a data.frame
##' @param df named data.frame (a matrix with column names also works)
##' @param tlist named list; each element is a character vector of
##'   function names applied in order to the column of the same name.
##'   Shorter vectors are padded with the identity "I".
##' @return data.frame with the columns of df that appear in tlist (in
##'   df's column order) holding the transformed numerical values. It
##'   keeps its shape for single-row and single-column input.
##' @author MDL
TransfList <- function (df, tlist) {
  vars <- intersect(colnames(df), names(tlist))
  ## number of transformation steps; 0 when no column is transformed
  n <- if (length(vars) > 0L) max(lengths(tlist[vars])) else 0L
  ## pad every chain to n steps with the identity "I"
  filledlist <- lapply(tlist[vars],
                       function(f) c(f, rep("I", n - length(f))))
  ## drop = FALSE keeps a table even when only one column matches
  tmp <- df[, vars, drop = FALSE]
  for (i in seq_len(n)) {
    cols <- lapply(vars, function(v)
      do.call(filledlist[[v]][i], list(tmp[, v])))
    names(cols) <- vars
    ## cbind() always yields a matrix, also for a single row, whereas
    ## sapply() would collapse one-row results into a plain vector
    tmp <- as.data.frame(do.call(cbind, cols))
  }
  tmp
}
|
|
|
|
##' This function applies some specified string substitutions such as
##' 's/log/exp/' so that, from a list of "forward" transformation
##' functions, it computes a "backward" one.
##' @title Apply back transformations to cols of a data.frame
##' @param df named data.frame (a matrix with column names also works)
##' @param tlist named list of forward function-name chains, as used by
##'   TransfList(); shorter chains are padded with the identity "I"
##' @return data.frame with the columns of df that appear in tlist (in
##'   df's column order) holding the backtransformed numerical values.
##'   It keeps its shape for single-row and single-column input.
##' @author MDL
BackTransfList <- function (df, tlist) {
  vars <- intersect(colnames(df), names(tlist))
  ## number of transformation steps; 0 when no column is transformed
  n <- if (length(vars) > 0L) max(lengths(tlist[vars])) else 0L
  filledlist <- lapply(tlist[vars],
                       function(f) c(f, rep("I", n - length(f))))
  ## undo the chain last-to-first, renaming each step to its inverse.
  ## "log" -> "exp" must run before "1p" -> "m1" so that "log1p" becomes
  ## "exp1p" and then "expm1" (and "Safelog" becomes "Safeexp").
  rlist <- lapply(filledlist, rev)
  rlist <- lapply(rlist, sub, pattern = "log", replacement = "exp")
  rlist <- lapply(rlist, sub, pattern = "1p", replacement = "m1")
  rlist <- lapply(rlist, sub, pattern = "Mul", replacement = "Div")
  ## drop = FALSE keeps a table even when only one column matches
  tmp <- df[, vars, drop = FALSE]
  for (i in seq_len(n)) {
    cols <- lapply(vars, function(v)
      do.call(rlist[[v]][i], list(tmp[, v])))
    names(cols) <- vars
    ## cbind() always yields a matrix, also for a single row, whereas
    ## sapply() would collapse one-row results into a plain vector
    tmp <- as.data.frame(do.call(cbind, cols))
  }
  tmp
}
|
|
|
|
## Forward transformation for each variable, applied by TransfList() in
## preprocess(): "Safelog" for Charge, "log1p" for every other variable.
tlist <- as.list(c(H = "log1p", O = "log1p", Charge = "Safelog",
                   Ba = "log1p", Cl = "log1p", `S(6)` = "log1p",
                   Sr = "log1p", Barite = "log1p", Celestite = "log1p"))
|
|
|
|
## Per-variable bounds used for min-max scaling. preprocess() applies
## them to the output of TransfList(df, tlist), i.e. to the already
## transformed values, and postprocess() uses them to undo the scaling.
minmaxlog <- local({
  vars <- c("H", "O", "Charge", "Ba", "Cl", "S(6)", "Sr", "Barite",
            "Celestite")
  lower <- c(4.71860987935512, 4.03435069501537, -14.5337693764617,
             1.87312878574393e-141, 0, 4.2422742065922e-07,
             0.00049370806741832, 0.000999043199940672,
             0.218976382406766)
  upper <- c(4.71860988013054, 4.03439461483133, 12.9230752782909,
             0.086989273200771, 0.178729802869381, 0.000620582151722819,
             0.0543631841661421, 0.56348209786429, 0.69352422027466)
  names(lower) <- vars
  names(upper) <- vars
  list(min = lower, max = upper)
})
|
|
|
|
## Prepare raw input for the surrogate.
## First apply the forward transformations in tlist (via TransfList()),
## then min-max scale every resulting column with the bounds in
## minmaxlog. Returns a data.frame whose column names are kept verbatim.
preprocess <- function(df) {
  ## non-data.frame input (e.g. a matrix) is coerced first;
  ## check.names = FALSE keeps names such as "S(6)" unchanged
  if (!is.data.frame(df)) {
    df <- as.data.frame(df, check.names = FALSE)
  }
  transformed <- TransfList(df, tlist)
  scale_column <- function(col) {
    scale_min_max(x = transformed[col],
                  min = minmaxlog$min[col],
                  max = minmaxlog$max[col],
                  backtransform = FALSE)
  }
  as.data.frame(lapply(colnames(transformed), scale_column),
                check.names = FALSE)
}
|
|
|
|
## Map surrogate output back to physical values. This is the inverse of
## preprocess(): undo the min-max scaling using the bounds in minmaxlog,
## then apply the backward transformations derived from tlist (via
## BackTransfList()).
postprocess <- function(df) {
  ## non-data.frame input (e.g. a matrix) is coerced first;
  ## check.names = FALSE keeps names such as "S(6)" unchanged
  if (!is.data.frame(df)) {
    df <- as.data.frame(df, check.names = FALSE)
  }
  unscaled <- lapply(colnames(df), function(col) {
    scale_min_max(x = df[col],
                  min = minmaxlog$min[col],
                  max = minmaxlog$max[col],
                  backtransform = TRUE)
  })
  BackTransfList(as.data.frame(unscaled, check.names = FALSE), tlist)
}
|
|
|
|
## Absolute mass-balance error per row between two states x and y:
##   |(Ba + Barite)_y - (Ba + Barite)_x|
##     + |(Sr + Celestite)_y - (Sr + Celestite)_x|
## x, y: data.frames, named lists or matrices with columns "Ba",
##   "Barite", "Sr" and "Celestite". A missing column is an error.
## Returns a numeric vector with one entry per row.
mass_balance <- function(x, y) {
  ## exact column lookup: `$` on a data.frame/list would silently
  ## partial-match an absent "Ba" to "Barite", and fails on matrices
  getcol <- function(d, name) {
    nms <- if (is.list(d)) names(d) else colnames(d)
    if (!name %in% nms)
      stop("mass_balance: column '", name, "' not found", call. = FALSE)
    if (is.list(d)) d[[name]] else d[, name]
  }
  dBa <- abs(getcol(y, "Ba") + getcol(y, "Barite") -
             (getcol(x, "Ba") + getcol(x, "Barite")))
  dSr <- abs(getcol(y, "Sr") + getcol(y, "Celestite") -
             (getcol(x, "Sr") + getcol(x, "Celestite")))
  dBa + dSr
}
|
|
|
|
## Flag predicted rows that conserve mass within a tolerance.
## predictors, prediction: states passed to mass_balance() (input and
##   surrogate output, respectively).
## epsilon: rows whose mass-balance error is below this value pass;
##   defaults to 1E-7.
## Logs the mean and variance of the error and the number of passing
## rows through msgm() (a logging helper defined outside this file).
## Returns a logical vector with one TRUE/FALSE per row. Rows with an
## NA/NaN error are FALSE.
validate_predictions <- function(predictors, prediction, epsilon = 1E-7) {
  mb <- mass_balance(predictors, prediction)
  msgm("Mass balance mean:", mean(mb))
  msgm("Mass balance variance:", var(mb))
  ## a missing balance must not be reported as NA (or counted as a pass)
  ret <- !is.na(mb) & mb < epsilon
  msgm("Rows where mass balance meets threshold", epsilon, ":",
       sum(ret))
  ret
}
|