poet/bench/barite/barite_50ai_surr_mdl.R

150 lines
5.4 KiB
R

## Time-stamp: "Last modified 2024-05-30 13:27:06 delucia"
## Use the global variable "ai_surrogate_base_path" when using file paths
## relative to the input script
model_file_path <- normalizePath(paste0(ai_surrogate_base_path,
"barite_50ai_all.keras"))
scale_min_max <- function(x, min, max, backtransform) {
if (backtransform) {
return((x * (max - min)) + min)
} else {
return((x - min) / (max - min))
}
}
## Apply decimal logarithms handling negative values
Safelog <- function (vec) {
rv <- range(vec)
if (max(abs(rv)) < 1) {
ret <- vec
ret[vec == 0] <- 0
ret[vec > 0] <- log10(vec[vec > 0])
ret[vec < 0] <- -log10(-vec[vec < 0])
} else {
ret <- log10(abs(vec))
}
ret
}
Safeexp <- function (vec) {
ret <- vec
ret[vec == 0] <- 0
ret[vec > 0] <- -10^(-vec[vec > 0])
ret[vec < 0] <- 10^(vec[vec < 0])
ret
}
##' @title Apply transformations to cols of a data.frame
##' @param df named data.frame
##' @param tlist list of function names
##' @return data.frame with the columns specified in tlist and the
##' transformed numerical values
##' @author MDL
TransfList <- function (df, tlist) {
vars <- intersect(colnames(df), names(tlist))
lens <- sapply(tlist[vars], length)
n <- max(lens)
filledlist <- lapply(tlist[vars],
function(x)
if (length(x) < n)
return(c(x, rep("I", n - length(x))))
else
return(x))
tmp <- df[, vars]
for (i in seq_len(n)) {
tmp <- as.data.frame(sapply(vars, function(x)
do.call(filledlist[[x]][i], list(tmp[, x]))))
}
return(tmp)
}
##' This function applies some specified string substitution such as
##' 's/log/exp/' so that from a list of "forward" transformation
##' functions it computes a "backward" one
##' @title Apply back transformations to cols of a data.frame
##' @param df named data.frame
##' @param tlist list of function names
##' @return data.frame with the columns specified in tlist and the
##' backtransformed numerical values
##' @author MDL
BackTransfList <- function (df, tlist) {
vars <- intersect(colnames(df), names(tlist))
lens <- sapply(tlist[vars], length)
n <- max(lens)
filledlist <- lapply(tlist[vars],
function(x)
if (length(x) < n)
return(c(x, rep("I", n - length(x))))
else
return(x))
rlist <- lapply(filledlist, rev)
rlist <- lapply(rlist, sub, pattern = "log", replacement = "exp")
rlist <- lapply(rlist, sub, pattern = "1p", replacement = "m1")
rlist <- lapply(rlist, sub, pattern = "Mul", replacement = "Div")
tmp <- df[, vars]
for (i in seq_len(n)) {
tmp <- as.data.frame(sapply(vars, function(x)
do.call(rlist[[x]][i], list(tmp[, x]))))
}
return(tmp)
}
tlist <- list("H" = "log1p", "O" = "log1p", "Charge" = "Safelog",
"Ba" = "log1p", "Cl" = "log1p", "S(6)" = "log1p",
"Sr" = "log1p", "Barite" = "log1p", "Celestite" = "log1p")
minmaxlog <- list(min = c(H = 4.71860987935512, O = 4.03435069501537,
Charge = -14.5337693764617, Ba = 1.87312878574393e-141,
Cl = 0, `S(6)` = 4.2422742065922e-07,
Sr = 0.00049370806741832, Barite = 0.000999043199940672,
Celestite = 0.218976382406766),
max = c(H = 4.71860988013054, O = 4.03439461483133,
Charge = 12.9230752782909, Ba = 0.086989273200771,
Cl = 0.178729802869381, `S(6)` = 0.000620582151722819,
Sr = 0.0543631841661421, Barite = 0.56348209786429,
Celestite = 0.69352422027466))
preprocess <- function(df) {
if (!is.data.frame(df))
df <- as.data.frame(df, check.names = FALSE)
tlog <- TransfList(df, tlist)
as.data.frame(lapply(colnames(tlog),
function(x) scale_min_max(x = tlog[x],
min = minmaxlog$min[x],
max = minmaxlog$max[x],
backtransform = FALSE)),
check.names = FALSE)
}
postprocess <- function(df) {
if (!is.data.frame(df))
df <- as.data.frame(df, check.names = FALSE)
ret <- as.data.frame(lapply(colnames(df),
function(x) scale_min_max(x = df[x],
min = minmaxlog$min[x],
max = minmaxlog$max[x],
backtransform = TRUE)),
check.names = FALSE)
BackTransfList(ret, tlist)
}
mass_balance <- function(x, y) {
dBa <- abs(y$Ba + y$Barite -
(x$Ba + x$Barite))
dSr <- abs(y$Sr + y$Celestite -
(x$Sr + x$Celestite))
return(dBa + dSr)
}
validate_predictions <- function(predictors, prediction) {
epsilon <- 1E-7
mb <- mass_balance(predictors, prediction)
msgm("Mass balance mean:", mean(mb))
msgm("Mass balance variance:", var(mb))
ret <- mb < epsilon
msgm("Rows where mass balance meets threshold", epsilon, ":",
sum(ret))
return(ret)
}