diff --git a/bench/barite/barite_50ai.R b/bench/barite/barite_50ai.R index 3944724a3..c2a674a85 100644 --- a/bench/barite/barite_50ai.R +++ b/bench/barite/barite_50ai.R @@ -1,4 +1,4 @@ -## Time-stamp: "Last modified 2024-05-29 10:51:35 delucia" +## Time-stamp: "Last modified 2024-05-30 13:34:14 delucia" cols <- 50 rows <- 50 @@ -10,10 +10,10 @@ grid_def <- matrix(2, nrow = rows, ncol = cols) # Define grid configuration for POET model grid_setup <- list( pqc_in_file = "./barite.pqi", - pqc_db_file = "./db_barite.dat", # Path to the database file for Phreeqc - grid_def = grid_def, # Definition of the grid, containing IDs according to the Phreeqc input script - grid_size = c(s_rows, s_cols), # Size of the grid in meters - constant_cells = c() # IDs of cells with constant concentration + pqc_db_file = "./db_barite.dat", ## Path to the database file for Phreeqc + grid_def = grid_def, ## Definition of the grid, containing IDs according to the Phreeqc input script + grid_size = c(s_rows, s_cols), ## Size of the grid in meters + constant_cells = c() ## IDs of cells with constant concentration ) bound_length <- 2 @@ -37,14 +37,14 @@ diffusion_setup <- list( dht_species <- c( "H" = 4, - "O" = 9, + "O" = 10, "Charge" = 4, - "Ba" = 4, + "Ba" = 7, "Cl" = 4, "S(6)" = 7, "Sr" = 4, - "Barite" = 7, - "Celestite" = 7 + "Barite" = 2, + "Celestite" = 2 ) chemistry_setup <- list( diff --git a/bench/barite/barite_50ai.rds b/bench/barite/barite_50ai.rds new file mode 100644 index 000000000..efc230f27 Binary files /dev/null and b/bench/barite/barite_50ai.rds differ diff --git a/bench/barite/barite_50ai_surr_mdl.R b/bench/barite/barite_50ai_surr_mdl.R index 01de5113b..237d5a0cd 100644 --- a/bench/barite/barite_50ai_surr_mdl.R +++ b/bench/barite/barite_50ai_surr_mdl.R @@ -1,4 +1,4 @@ -## Time-stamp: "Last modified 2024-05-30 11:16:57 delucia" +## Time-stamp: "Last modified 2024-05-30 13:27:06 delucia" ## load a pretrained model from tensorflow file ## Use the global variable "ai_surrogate_base_path" when using file paths @@ -74,12 +74,17 @@ validate_predictions <- function(predictors, prediction) { } training_step <- function(model, predictor, target, validity) { - msgm("Training:") + msgm("Starting incremental training:") - x <- as.matrix(predictor) - y <- as.matrix(target[colnames(x)]) + ## x <- as.matrix(predictor) + ## y <- as.matrix(target[colnames(x)]) - model %>% keras3::fit(x, y) + history <- model %>% keras3::fit(x = data.matrix(predictor), + y = data.matrix(target), + epochs = 10, verbose=1) - model %>% keras3::save_model(paste0(out_dir, "/current_model.keras"), overwrite=TRUE) + keras3::save_model(model, + filepath = paste0(out_dir, "/current_model.keras"), + overwrite=TRUE) + return(model) } diff --git a/src/poet.cpp b/src/poet.cpp index 5175732a5..4cbaa709e 100644 --- a/src/poet.cpp +++ b/src/poet.cpp @@ -336,7 +336,8 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms, // TODO: Check how to get the correct columns R.parseEval("target_scaled <- preprocess(targets)"); - R.parseEval("training_step(model, predictors_scaled, target_scaled, validity_vector)"); + MSG("AI: incremental training"); + R.parseEval("model <- training_step(model, predictors_scaled, target_scaled, validity_vector)"); double ai_end_t = MPI_Wtime(); R["ai_training_time"] = ai_end_t - ai_start_t; } @@ -477,9 +478,12 @@ int main(int argc, char *argv[]) { if (!ai_surrogate_input_script_path.empty()) { R["ai_surrogate_base_path"] = ai_surrogate_input_script_path.substr(0, ai_surrogate_input_script_path.find_last_of('/') + 1); - R.parseEvalQ("source('" + ai_surrogate_input_script_path + "')"); + + MSG("AI: sourcing user-provided script"); + R.parseEvalQ("source('" + ai_surrogate_input_script_path + "')"); } - R.parseEval("model <- initiate_model()"); + MSG("AI: initialize AI model"); + R.parseEval("model <- initiate_model()"); R.parseEval("gpu_info()"); } @@ -501,7 +505,7 @@ int main(int argc, char *argv[]) { string r_vis_code; r_vis_code = - "saveRDS(profiling, file=paste0(setup$out_dir,'/timings.rds'));"; + "saveRDS(profiling, file=paste0(setup$out_dir,'/timings.rds'));"; R.parseEval(r_vis_code); MSG("Done! Results are stored as R objects into <" + run_params.out_dir +