MDL: AI model seems correctly updated/stored

This commit is contained in:
Marco De Lucia 2024-05-30 13:37:28 +02:00
parent 5b48ccb21b
commit 59c3eb9fbd
4 changed files with 28 additions and 19 deletions

View File

@ -1,4 +1,4 @@
## Time-stamp: "Last modified 2024-05-29 10:51:35 delucia" ## Time-stamp: "Last modified 2024-05-30 13:34:14 delucia"
cols <- 50 cols <- 50
rows <- 50 rows <- 50
@ -10,10 +10,10 @@ grid_def <- matrix(2, nrow = rows, ncol = cols)
# Define grid configuration for POET model # Define grid configuration for POET model
grid_setup <- list( grid_setup <- list(
pqc_in_file = "./barite.pqi", pqc_in_file = "./barite.pqi",
pqc_db_file = "./db_barite.dat", # Path to the database file for Phreeqc pqc_db_file = "./db_barite.dat", ## Path to the database file for Phreeqc
grid_def = grid_def, # Definition of the grid, containing IDs according to the Phreeqc input script grid_def = grid_def, ## Definition of the grid, containing IDs according to the Phreeqc input script
grid_size = c(s_rows, s_cols), # Size of the grid in meters grid_size = c(s_rows, s_cols), ## Size of the grid in meters
constant_cells = c() # IDs of cells with constant concentration constant_cells = c() ## IDs of cells with constant concentration
) )
bound_length <- 2 bound_length <- 2
@ -37,14 +37,14 @@ diffusion_setup <- list(
dht_species <- c( dht_species <- c(
"H" = 4, "H" = 4,
"O" = 9, "O" = 10,
"Charge" = 4, "Charge" = 4,
"Ba" = 4, "Ba" = 7,
"Cl" = 4, "Cl" = 4,
"S(6)" = 7, "S(6)" = 7,
"Sr" = 4, "Sr" = 4,
"Barite" = 7, "Barite" = 2,
"Celestite" = 7 "Celestite" = 2
) )
chemistry_setup <- list( chemistry_setup <- list(

Binary file not shown.

View File

@ -1,4 +1,4 @@
## Time-stamp: "Last modified 2024-05-30 11:16:57 delucia" ## Time-stamp: "Last modified 2024-05-30 13:27:06 delucia"
## load a pretrained model from tensorflow file ## load a pretrained model from tensorflow file
## Use the global variable "ai_surrogate_base_path" when using file paths ## Use the global variable "ai_surrogate_base_path" when using file paths
@ -74,12 +74,17 @@ validate_predictions <- function(predictors, prediction) {
} }
training_step <- function(model, predictor, target, validity) { training_step <- function(model, predictor, target, validity) {
msgm("Training:") msgm("Starting incremental training:")
x <- as.matrix(predictor) ## x <- as.matrix(predictor)
y <- as.matrix(target[colnames(x)]) ## y <- as.matrix(target[colnames(x)])
model %>% keras3::fit(x, y) history <- model %>% keras3::fit(x = data.matrix(predictor),
y = data.matrix(target),
epochs = 10, verbose=1)
model %>% keras3::save_model(paste0(out_dir, "/current_model.keras"), overwrite=TRUE) keras3::save_model(model,
filepath = paste0(out_dir, "/current_model.keras"),
overwrite=TRUE)
return(model)
} }

View File

@ -336,7 +336,8 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
// TODO: Check how to get the correct columns // TODO: Check how to get the correct columns
R.parseEval("target_scaled <- preprocess(targets)"); R.parseEval("target_scaled <- preprocess(targets)");
R.parseEval("training_step(model, predictors_scaled, target_scaled, validity_vector)"); MSG("AI: incremental training");
R.parseEval("model <- training_step(model, predictors_scaled, target_scaled, validity_vector)");
double ai_end_t = MPI_Wtime(); double ai_end_t = MPI_Wtime();
R["ai_training_time"] = ai_end_t - ai_start_t; R["ai_training_time"] = ai_end_t - ai_start_t;
} }
@ -477,9 +478,12 @@ int main(int argc, char *argv[]) {
if (!ai_surrogate_input_script_path.empty()) { if (!ai_surrogate_input_script_path.empty()) {
R["ai_surrogate_base_path"] = ai_surrogate_input_script_path.substr(0, ai_surrogate_input_script_path.find_last_of('/') + 1); R["ai_surrogate_base_path"] = ai_surrogate_input_script_path.substr(0, ai_surrogate_input_script_path.find_last_of('/') + 1);
R.parseEvalQ("source('" + ai_surrogate_input_script_path + "')");
MSG("AI: sourcing user-provided script");
R.parseEvalQ("source('" + ai_surrogate_input_script_path + "')");
} }
R.parseEval("model <- initiate_model()"); MSG("AI: initialize AI model");
R.parseEval("model <- initiate_model()");
R.parseEval("gpu_info()"); R.parseEval("gpu_info()");
} }
@ -501,7 +505,7 @@ int main(int argc, char *argv[]) {
string r_vis_code; string r_vis_code;
r_vis_code = r_vis_code =
"saveRDS(profiling, file=paste0(setup$out_dir,'/timings.rds'));"; "saveRDS(profiling, file=paste0(setup$out_dir,'/timings.rds'));";
R.parseEval(r_vis_code); R.parseEval(r_vis_code);
MSG("Done! Results are stored as R objects into <" + run_params.out_dir + MSG("Done! Results are stored as R objects into <" + run_params.out_dir +