mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-15 20:38:23 +01:00
MDL: AI model seems correctly updated/stored
This commit is contained in:
parent
b974b96d27
commit
7f522157d1
@ -1,4 +1,4 @@
|
||||
## Time-stamp: "Last modified 2024-05-29 10:51:35 delucia"
|
||||
## Time-stamp: "Last modified 2024-05-30 13:34:14 delucia"
|
||||
cols <- 50
|
||||
rows <- 50
|
||||
|
||||
@ -10,10 +10,10 @@ grid_def <- matrix(2, nrow = rows, ncol = cols)
|
||||
# Define grid configuration for POET model
|
||||
grid_setup <- list(
|
||||
pqc_in_file = "./barite.pqi",
|
||||
pqc_db_file = "./db_barite.dat", # Path to the database file for Phreeqc
|
||||
grid_def = grid_def, # Definition of the grid, containing IDs according to the Phreeqc input script
|
||||
grid_size = c(s_rows, s_cols), # Size of the grid in meters
|
||||
constant_cells = c() # IDs of cells with constant concentration
|
||||
pqc_db_file = "./db_barite.dat", ## Path to the database file for Phreeqc
|
||||
grid_def = grid_def, ## Definition of the grid, containing IDs according to the Phreeqc input script
|
||||
grid_size = c(s_rows, s_cols), ## Size of the grid in meters
|
||||
constant_cells = c() ## IDs of cells with constant concentration
|
||||
)
|
||||
|
||||
bound_length <- 2
|
||||
@ -37,14 +37,14 @@ diffusion_setup <- list(
|
||||
|
||||
dht_species <- c(
|
||||
"H" = 4,
|
||||
"O" = 9,
|
||||
"O" = 10,
|
||||
"Charge" = 4,
|
||||
"Ba" = 4,
|
||||
"Ba" = 7,
|
||||
"Cl" = 4,
|
||||
"S(6)" = 7,
|
||||
"Sr" = 4,
|
||||
"Barite" = 7,
|
||||
"Celestite" = 7
|
||||
"Barite" = 2,
|
||||
"Celestite" = 2
|
||||
)
|
||||
|
||||
chemistry_setup <- list(
|
||||
|
||||
BIN
bench/barite/barite_50ai.rds
Normal file
BIN
bench/barite/barite_50ai.rds
Normal file
Binary file not shown.
@ -1,4 +1,4 @@
|
||||
## Time-stamp: "Last modified 2024-05-30 11:16:57 delucia"
|
||||
## Time-stamp: "Last modified 2024-05-30 13:27:06 delucia"
|
||||
|
||||
## load a pretrained model from tensorflow file
|
||||
## Use the global variable "ai_surrogate_base_path" when using file paths
|
||||
@ -74,12 +74,17 @@ validate_predictions <- function(predictors, prediction) {
|
||||
}
|
||||
|
||||
training_step <- function(model, predictor, target, validity) {
|
||||
msgm("Training:")
|
||||
msgm("Starting incremental training:")
|
||||
|
||||
x <- as.matrix(predictor)
|
||||
y <- as.matrix(target[colnames(x)])
|
||||
## x <- as.matrix(predictor)
|
||||
## y <- as.matrix(target[colnames(x)])
|
||||
|
||||
model %>% keras3::fit(x, y)
|
||||
history <- model %>% keras3::fit(x = data.matrix(predictor),
|
||||
y = data.matrix(target),
|
||||
epochs = 10, verbose=1)
|
||||
|
||||
model %>% keras3::save_model(paste0(out_dir, "/current_model.keras"), overwrite=TRUE)
|
||||
keras3::save_model(model,
|
||||
filepath = paste0(out_dir, "/current_model.keras"),
|
||||
overwrite=TRUE)
|
||||
return(model)
|
||||
}
|
||||
|
||||
12
src/poet.cpp
12
src/poet.cpp
@ -336,7 +336,8 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
// TODO: Check how to get the correct columns
|
||||
R.parseEval("target_scaled <- preprocess(targets)");
|
||||
|
||||
R.parseEval("training_step(model, predictors_scaled, target_scaled, validity_vector)");
|
||||
MSG("AI: incremental training");
|
||||
R.parseEval("model <- training_step(model, predictors_scaled, target_scaled, validity_vector)");
|
||||
double ai_end_t = MPI_Wtime();
|
||||
R["ai_training_time"] = ai_end_t - ai_start_t;
|
||||
}
|
||||
@ -477,9 +478,12 @@ int main(int argc, char *argv[]) {
|
||||
|
||||
if (!ai_surrogate_input_script_path.empty()) {
|
||||
R["ai_surrogate_base_path"] = ai_surrogate_input_script_path.substr(0, ai_surrogate_input_script_path.find_last_of('/') + 1);
|
||||
R.parseEvalQ("source('" + ai_surrogate_input_script_path + "')");
|
||||
|
||||
MSG("AI: sourcing user-provided script");
|
||||
R.parseEvalQ("source('" + ai_surrogate_input_script_path + "')");
|
||||
}
|
||||
R.parseEval("model <- initiate_model()");
|
||||
MSG("AI: initialize AI model");
|
||||
R.parseEval("model <- initiate_model()");
|
||||
R.parseEval("gpu_info()");
|
||||
}
|
||||
|
||||
@ -501,7 +505,7 @@ int main(int argc, char *argv[]) {
|
||||
|
||||
string r_vis_code;
|
||||
r_vis_code =
|
||||
"saveRDS(profiling, file=paste0(setup$out_dir,'/timings.rds'));";
|
||||
"saveRDS(profiling, file=paste0(setup$out_dir,'/timings.rds'));";
|
||||
R.parseEval(r_vis_code);
|
||||
|
||||
MSG("Done! Results are stored as R objects into <" + run_params.out_dir +
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user