mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-16 12:54:50 +01:00
Merge branch 'origin/ai-surrogate-v03-temp-mdl' into ai_surrogate_merge
This commit is contained in:
commit
91b99066be
@ -5,8 +5,8 @@
|
|||||||
## in the variable "ai_surrogate_input_script". See the barite_200.R file as an
|
## in the variable "ai_surrogate_input_script". See the barite_200.R file as an
|
||||||
## example and the general README for more information.
|
## example and the general README for more information.
|
||||||
|
|
||||||
library(keras)
|
## library(keras3)
|
||||||
library(tensorflow)
|
## library(tensorflow)
|
||||||
|
|
||||||
initiate_model <- function() {
|
initiate_model <- function() {
|
||||||
hidden_layers <- c(48, 96, 24)
|
hidden_layers <- c(48, 96, 24)
|
||||||
@ -54,6 +54,10 @@ preprocess <- function(df, backtransform = FALSE, outputs = FALSE) {
|
|||||||
return(df)
|
return(df)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
postprocess <- function(df, backtransform = TRUE, outputs = TRUE) {
|
||||||
|
return(df)
|
||||||
|
}
|
||||||
|
|
||||||
set_valid_predictions <- function(temp_field, prediction, validity) {
|
set_valid_predictions <- function(temp_field, prediction, validity) {
|
||||||
temp_field[validity == 1, ] <- prediction[validity == 1, ]
|
temp_field[validity == 1, ] <- prediction[validity == 1, ]
|
||||||
return(temp_field)
|
return(temp_field)
|
||||||
|
|||||||
@ -38,7 +38,7 @@ mass_balance <- function(predictors, prediction) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
validate_predictions <- function(predictors, prediction) {
|
validate_predictions <- function(predictors, prediction) {
|
||||||
epsilon <- 0.000000003
|
epsilon <- 3e-5
|
||||||
mb <- mass_balance(predictors, prediction)
|
mb <- mass_balance(predictors, prediction)
|
||||||
msgm("Mass balance mean:", mean(mb))
|
msgm("Mass balance mean:", mean(mb))
|
||||||
msgm("Mass balance variance:", var(mb))
|
msgm("Mass balance variance:", var(mb))
|
||||||
|
|||||||
60
bench/barite/barite_50ai.R
Normal file
60
bench/barite/barite_50ai.R
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
## Time-stamp: "Last modified 2024-05-30 13:34:14 delucia"
|
||||||
|
cols <- 50
|
||||||
|
rows <- 50
|
||||||
|
|
||||||
|
s_cols <- 0.25
|
||||||
|
s_rows <- 0.25
|
||||||
|
|
||||||
|
grid_def <- matrix(2, nrow = rows, ncol = cols)
|
||||||
|
|
||||||
|
# Define grid configuration for POET model
|
||||||
|
grid_setup <- list(
|
||||||
|
pqc_in_file = "./barite.pqi",
|
||||||
|
pqc_db_file = "./db_barite.dat", ## Path to the database file for Phreeqc
|
||||||
|
grid_def = grid_def, ## Definition of the grid, containing IDs according to the Phreeqc input script
|
||||||
|
grid_size = c(s_rows, s_cols), ## Size of the grid in meters
|
||||||
|
constant_cells = c() ## IDs of cells with constant concentration
|
||||||
|
)
|
||||||
|
|
||||||
|
bound_length <- 2
|
||||||
|
|
||||||
|
bound_def <- list(
|
||||||
|
"type" = rep("constant", bound_length),
|
||||||
|
"sol_id" = rep(3, bound_length),
|
||||||
|
"cell" = seq(1, bound_length)
|
||||||
|
)
|
||||||
|
|
||||||
|
homogenous_alpha <- 1e-8
|
||||||
|
|
||||||
|
diffusion_setup <- list(
|
||||||
|
boundaries = list(
|
||||||
|
"W" = bound_def,
|
||||||
|
"N" = bound_def
|
||||||
|
),
|
||||||
|
alpha_x = homogenous_alpha,
|
||||||
|
alpha_y = homogenous_alpha
|
||||||
|
)
|
||||||
|
|
||||||
|
dht_species <- c(
|
||||||
|
"H" = 4,
|
||||||
|
"O" = 10,
|
||||||
|
"Charge" = 4,
|
||||||
|
"Ba" = 7,
|
||||||
|
"Cl" = 4,
|
||||||
|
"S(6)" = 7,
|
||||||
|
"Sr" = 4,
|
||||||
|
"Barite" = 2,
|
||||||
|
"Celestite" = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
chemistry_setup <- list(
|
||||||
|
dht_species = dht_species,
|
||||||
|
ai_surrogate_input_script = "./barite_50ai_surr_mdl.R"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Define a setup list for simulation configuration
|
||||||
|
setup <- list(
|
||||||
|
Grid = grid_setup, # Parameters related to the grid structure
|
||||||
|
Diffusion = diffusion_setup, # Parameters related to the diffusion process
|
||||||
|
Chemistry = chemistry_setup
|
||||||
|
)
|
||||||
BIN
bench/barite/barite_50ai.rds
Normal file
BIN
bench/barite/barite_50ai.rds
Normal file
Binary file not shown.
BIN
bench/barite/barite_50ai_all.keras
Normal file
BIN
bench/barite/barite_50ai_all.keras
Normal file
Binary file not shown.
9
bench/barite/barite_50ai_rt.R
Normal file
9
bench/barite/barite_50ai_rt.R
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
iterations <- 1000
|
||||||
|
|
||||||
|
dt <- 200
|
||||||
|
|
||||||
|
list(
|
||||||
|
timesteps = rep(dt, iterations),
|
||||||
|
store_result = TRUE,
|
||||||
|
out_save = c(1, 5, seq(20, iterations, by=20))
|
||||||
|
)
|
||||||
90
bench/barite/barite_50ai_surr_mdl.R
Normal file
90
bench/barite/barite_50ai_surr_mdl.R
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
## Time-stamp: "Last modified 2024-05-30 13:27:06 delucia"
|
||||||
|
|
||||||
|
## load a pretrained model from tensorflow file
|
||||||
|
## Use the global variable "ai_surrogate_base_path" when using file paths
|
||||||
|
## relative to the input script
|
||||||
|
initiate_model <- function() {
|
||||||
|
require(keras3)
|
||||||
|
require(tensorflow)
|
||||||
|
init_model <- normalizePath(paste0(ai_surrogate_base_path,
|
||||||
|
"barite_50ai_all.keras"))
|
||||||
|
Model <- keras3::load_model(init_model)
|
||||||
|
msgm("Loaded model:")
|
||||||
|
print(str(Model))
|
||||||
|
return(Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
scale_min_max <- function(x, min, max, backtransform) {
|
||||||
|
if (backtransform) {
|
||||||
|
return((x * (max - min)) + min)
|
||||||
|
} else {
|
||||||
|
return((x - min) / (max - min))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
minmax <- list(min = c(H = 111.012433592824, O = 55.5062185549492, Charge = -3.1028354471876e-08,
|
||||||
|
Ba = 1.87312878574393e-141, Cl = 0, `S(6)` = 4.24227510643685e-07,
|
||||||
|
Sr = 0.00049382996130541, Barite = 0.000999542409828586, Celestite = 0.244801877115968),
|
||||||
|
max = c(H = 111.012433679682, O = 55.5087003521685, Charge = 5.27666636082035e-07,
|
||||||
|
Ba = 0.0908849779513762, Cl = 0.195697626449355, `S(6)` = 0.000620774752665846,
|
||||||
|
Sr = 0.0558680070692722, Barite = 0.756779139057097, Celestite = 1.00075422160624
|
||||||
|
))
|
||||||
|
|
||||||
|
preprocess <- function(df) {
|
||||||
|
if (!is.data.frame(df))
|
||||||
|
df <- as.data.frame(df, check.names = FALSE)
|
||||||
|
|
||||||
|
as.data.frame(lapply(colnames(df),
|
||||||
|
function(x) scale_min_max(x=df[x],
|
||||||
|
min=minmax$min[x],
|
||||||
|
max=minmax$max[x],
|
||||||
|
backtransform=FALSE)),
|
||||||
|
check.names = FALSE)
|
||||||
|
}
|
||||||
|
|
||||||
|
postprocess <- function(df) {
|
||||||
|
if (!is.data.frame(df))
|
||||||
|
df <- as.data.frame(df, check.names = FALSE)
|
||||||
|
|
||||||
|
as.data.frame(lapply(colnames(df),
|
||||||
|
function(x) scale_min_max(x=df[x],
|
||||||
|
min=minmax$min[x],
|
||||||
|
max=minmax$max[x],
|
||||||
|
backtransform=TRUE)),
|
||||||
|
check.names = FALSE)
|
||||||
|
}
|
||||||
|
|
||||||
|
mass_balance <- function(predictors, prediction) {
|
||||||
|
dBa <- abs(prediction$Ba + prediction$Barite -
|
||||||
|
predictors$Ba - predictors$Barite)
|
||||||
|
dSr <- abs(prediction$Sr + prediction$Celestite -
|
||||||
|
predictors$Sr - predictors$Celestite)
|
||||||
|
return(dBa + dSr)
|
||||||
|
}
|
||||||
|
|
||||||
|
validate_predictions <- function(predictors, prediction) {
|
||||||
|
epsilon <- 1E-7
|
||||||
|
mb <- mass_balance(predictors, prediction)
|
||||||
|
msgm("Mass balance mean:", mean(mb))
|
||||||
|
msgm("Mass balance variance:", var(mb))
|
||||||
|
ret <- mb < epsilon
|
||||||
|
msgm("Rows where mass balance meets threshold", epsilon, ":",
|
||||||
|
sum(ret))
|
||||||
|
return(ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
training_step <- function(model, predictor, target, validity) {
|
||||||
|
msgm("Starting incremental training:")
|
||||||
|
|
||||||
|
## x <- as.matrix(predictor)
|
||||||
|
## y <- as.matrix(target[colnames(x)])
|
||||||
|
|
||||||
|
history <- model %>% keras3::fit(x = data.matrix(predictor),
|
||||||
|
y = data.matrix(target),
|
||||||
|
epochs = 10, verbose=1)
|
||||||
|
|
||||||
|
keras3::save_model(model,
|
||||||
|
filepath = paste0(out_dir, "/current_model.keras"),
|
||||||
|
overwrite=TRUE)
|
||||||
|
return(model)
|
||||||
|
}
|
||||||
@ -48,7 +48,7 @@ void poet::ChemistryModule::WorkerLoop() {
|
|||||||
case CHEM_FIELD_INIT: {
|
case CHEM_FIELD_INIT: {
|
||||||
ChemBCast(&this->prop_count, 1, MPI_UINT32_T);
|
ChemBCast(&this->prop_count, 1, MPI_UINT32_T);
|
||||||
if (this->ai_surrogate_enabled) {
|
if (this->ai_surrogate_enabled) {
|
||||||
this->ai_surrogate_validity_vector.reserve(this->n_cells);
|
this->ai_surrogate_validity_vector.resize(this->n_cells); // resize statt reserve?
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -152,8 +152,8 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
|||||||
// current simulation time ('age' of simulation)
|
// current simulation time ('age' of simulation)
|
||||||
current_sim_time = mpi_buffer[count + 3];
|
current_sim_time = mpi_buffer[count + 3];
|
||||||
|
|
||||||
/* 4th double value is currently a placeholder */
|
// current work package start location in field
|
||||||
// placeholder = mpi_buffer[count+4];
|
wp_start_index = mpi_buffer[count + 4];
|
||||||
|
|
||||||
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
|
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
|
||||||
s_curr_wp.input[wp_i] =
|
s_curr_wp.input[wp_i] =
|
||||||
|
|||||||
46
src/poet.cpp
46
src/poet.cpp
@ -286,25 +286,36 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
|||||||
std::to_string(chem.getField().GetRequestedVecSize()) + ")), TMP_PROPS)"));
|
std::to_string(chem.getField().GetRequestedVecSize()) + ")), TMP_PROPS)"));
|
||||||
R.parseEval("predictors <- predictors[ai_surrogate_species]");
|
R.parseEval("predictors <- predictors[ai_surrogate_species]");
|
||||||
|
|
||||||
// Predict
|
// Apply preprocessing
|
||||||
|
MSG("AI Preprocessing");
|
||||||
R.parseEval("predictors_scaled <- preprocess(predictors)");
|
R.parseEval("predictors_scaled <- preprocess(predictors)");
|
||||||
|
|
||||||
R.parseEval("prediction <- preprocess(prediction_step(model, predictors_scaled),\
|
// Predict
|
||||||
backtransform = TRUE,\
|
MSG("AI Predict");
|
||||||
outputs = TRUE)");
|
R.parseEval("aipreds_scaled <- prediction_step(model, predictors_scaled)");
|
||||||
|
|
||||||
|
// Apply postprocessing
|
||||||
|
MSG("AI Postprocesing");
|
||||||
|
R.parseEval("aipreds <- postprocess(aipreds_scaled)");
|
||||||
|
|
||||||
// Validate prediction and write valid predictions to chem field
|
// Validate prediction and write valid predictions to chem field
|
||||||
R.parseEval("validity_vector <- validate_predictions(predictors,\
|
MSG("AI Validate");
|
||||||
prediction)");
|
R.parseEval("validity_vector <- validate_predictions(predictors, aipreds)");
|
||||||
|
|
||||||
|
MSG("AI Marking accepted");
|
||||||
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
|
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
|
||||||
|
|
||||||
|
MSG("AI TempField");
|
||||||
std::vector<std::vector<double>> RTempField = R.parseEval("set_valid_predictions(predictors,\
|
std::vector<std::vector<double>> RTempField = R.parseEval("set_valid_predictions(predictors,\
|
||||||
prediction,\
|
aipreds,\
|
||||||
validity_vector)");
|
validity_vector)");
|
||||||
|
|
||||||
|
MSG("AI Set Field");
|
||||||
Field predictions_field = Field(R.parseEval("nrow(predictors)"),
|
Field predictions_field = Field(R.parseEval("nrow(predictors)"),
|
||||||
RTempField,
|
RTempField,
|
||||||
R.parseEval("names(predictors)"));
|
R.parseEval("colnames(predictors)"));
|
||||||
|
|
||||||
|
MSG("AI Update");
|
||||||
chem.getField().update(predictions_field);
|
chem.getField().update(predictions_field);
|
||||||
double ai_end_t = MPI_Wtime();
|
double ai_end_t = MPI_Wtime();
|
||||||
R["ai_prediction_time"] = ai_end_t - ai_start_t;
|
R["ai_prediction_time"] = ai_end_t - ai_start_t;
|
||||||
@ -323,9 +334,10 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
|||||||
R.parseEval("targets <- targets[ai_surrogate_species]");
|
R.parseEval("targets <- targets[ai_surrogate_species]");
|
||||||
|
|
||||||
// TODO: Check how to get the correct columns
|
// TODO: Check how to get the correct columns
|
||||||
R.parseEval("target_scaled <- preprocess(targets, outputs = TRUE)");
|
R.parseEval("target_scaled <- preprocess(targets)");
|
||||||
|
|
||||||
R.parseEval("training_step(model, predictors_scaled, target_scaled, validity_vector)");
|
MSG("AI: incremental training");
|
||||||
|
R.parseEval("model <- training_step(model, predictors_scaled, target_scaled, validity_vector)");
|
||||||
double ai_end_t = MPI_Wtime();
|
double ai_end_t = MPI_Wtime();
|
||||||
R["ai_training_time"] = ai_end_t - ai_start_t;
|
R["ai_training_time"] = ai_end_t - ai_start_t;
|
||||||
}
|
}
|
||||||
@ -464,14 +476,14 @@ int main(int argc, char *argv[]) {
|
|||||||
|
|
||||||
const std::string ai_surrogate_input_script = init_list.getChemistryInit().ai_surrogate_input_script;
|
const std::string ai_surrogate_input_script = init_list.getChemistryInit().ai_surrogate_input_script;
|
||||||
|
|
||||||
if (!ai_surrogate_input_script.empty()) {
|
if (!ai_surrogate_input_script_path.empty()) {
|
||||||
/* Incorporate user defined ai surrogate input script */
|
R["ai_surrogate_base_path"] = ai_surrogate_input_script_path.substr(0, ai_surrogate_input_script_path.find_last_of('/') + 1);
|
||||||
R.parseEvalQ(ai_surrogate_input_script);
|
|
||||||
|
|
||||||
std::string ai_surrogate_base_path = R["ai_surrogate_base_path"];
|
MSG("AI: sourcing user-provided script");
|
||||||
R["ai_surrogate_base_path"] = ai_surrogate_base_path.substr(0, ai_surrogate_base_path.find_last_of('/') + 1);
|
R.parseEvalQ("source('" + ai_surrogate_input_script_path + "')");
|
||||||
}
|
}
|
||||||
R.parseEval("model <- initiate_model()");
|
MSG("AI: initialize AI model");
|
||||||
|
R.parseEval("model <- initiate_model()");
|
||||||
R.parseEval("gpu_info()");
|
R.parseEval("gpu_info()");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -493,7 +505,7 @@ int main(int argc, char *argv[]) {
|
|||||||
|
|
||||||
string r_vis_code;
|
string r_vis_code;
|
||||||
r_vis_code =
|
r_vis_code =
|
||||||
"saveRDS(profiling, file=paste0(setup$out_dir,'/timings.rds'));";
|
"saveRDS(profiling, file=paste0(setup$out_dir,'/timings.rds'));";
|
||||||
R.parseEval(r_vis_code);
|
R.parseEval(r_vis_code);
|
||||||
|
|
||||||
MSG("Done! Results are stored as R objects into <" + run_params.out_dir +
|
MSG("Done! Results are stored as R objects into <" + run_params.out_dir +
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user