mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-16 12:54:50 +01:00
MDL: AI model seems correctly updated/stored
This commit is contained in:
parent
b974b96d27
commit
7f522157d1
@ -1,4 +1,4 @@
|
|||||||
## Time-stamp: "Last modified 2024-05-29 10:51:35 delucia"
|
## Time-stamp: "Last modified 2024-05-30 13:34:14 delucia"
|
||||||
cols <- 50
|
cols <- 50
|
||||||
rows <- 50
|
rows <- 50
|
||||||
|
|
||||||
@ -10,10 +10,10 @@ grid_def <- matrix(2, nrow = rows, ncol = cols)
|
|||||||
# Define grid configuration for POET model
|
# Define grid configuration for POET model
|
||||||
grid_setup <- list(
|
grid_setup <- list(
|
||||||
pqc_in_file = "./barite.pqi",
|
pqc_in_file = "./barite.pqi",
|
||||||
pqc_db_file = "./db_barite.dat", # Path to the database file for Phreeqc
|
pqc_db_file = "./db_barite.dat", ## Path to the database file for Phreeqc
|
||||||
grid_def = grid_def, # Definition of the grid, containing IDs according to the Phreeqc input script
|
grid_def = grid_def, ## Definition of the grid, containing IDs according to the Phreeqc input script
|
||||||
grid_size = c(s_rows, s_cols), # Size of the grid in meters
|
grid_size = c(s_rows, s_cols), ## Size of the grid in meters
|
||||||
constant_cells = c() # IDs of cells with constant concentration
|
constant_cells = c() ## IDs of cells with constant concentration
|
||||||
)
|
)
|
||||||
|
|
||||||
bound_length <- 2
|
bound_length <- 2
|
||||||
@ -37,14 +37,14 @@ diffusion_setup <- list(
|
|||||||
|
|
||||||
dht_species <- c(
|
dht_species <- c(
|
||||||
"H" = 4,
|
"H" = 4,
|
||||||
"O" = 9,
|
"O" = 10,
|
||||||
"Charge" = 4,
|
"Charge" = 4,
|
||||||
"Ba" = 4,
|
"Ba" = 7,
|
||||||
"Cl" = 4,
|
"Cl" = 4,
|
||||||
"S(6)" = 7,
|
"S(6)" = 7,
|
||||||
"Sr" = 4,
|
"Sr" = 4,
|
||||||
"Barite" = 7,
|
"Barite" = 2,
|
||||||
"Celestite" = 7
|
"Celestite" = 2
|
||||||
)
|
)
|
||||||
|
|
||||||
chemistry_setup <- list(
|
chemistry_setup <- list(
|
||||||
|
|||||||
BIN
bench/barite/barite_50ai.rds
Normal file
BIN
bench/barite/barite_50ai.rds
Normal file
Binary file not shown.
@ -1,4 +1,4 @@
|
|||||||
## Time-stamp: "Last modified 2024-05-30 11:16:57 delucia"
|
## Time-stamp: "Last modified 2024-05-30 13:27:06 delucia"
|
||||||
|
|
||||||
## load a pretrained model from tensorflow file
|
## load a pretrained model from tensorflow file
|
||||||
## Use the global variable "ai_surrogate_base_path" when using file paths
|
## Use the global variable "ai_surrogate_base_path" when using file paths
|
||||||
@ -74,12 +74,17 @@ validate_predictions <- function(predictors, prediction) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
training_step <- function(model, predictor, target, validity) {
|
training_step <- function(model, predictor, target, validity) {
|
||||||
msgm("Training:")
|
msgm("Starting incremental training:")
|
||||||
|
|
||||||
x <- as.matrix(predictor)
|
## x <- as.matrix(predictor)
|
||||||
y <- as.matrix(target[colnames(x)])
|
## y <- as.matrix(target[colnames(x)])
|
||||||
|
|
||||||
model %>% keras3::fit(x, y)
|
history <- model %>% keras3::fit(x = data.matrix(predictor),
|
||||||
|
y = data.matrix(target),
|
||||||
|
epochs = 10, verbose=1)
|
||||||
|
|
||||||
model %>% keras3::save_model(paste0(out_dir, "/current_model.keras"), overwrite=TRUE)
|
keras3::save_model(model,
|
||||||
|
filepath = paste0(out_dir, "/current_model.keras"),
|
||||||
|
overwrite=TRUE)
|
||||||
|
return(model)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -336,7 +336,8 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
|||||||
// TODO: Check how to get the correct columns
|
// TODO: Check how to get the correct columns
|
||||||
R.parseEval("target_scaled <- preprocess(targets)");
|
R.parseEval("target_scaled <- preprocess(targets)");
|
||||||
|
|
||||||
R.parseEval("training_step(model, predictors_scaled, target_scaled, validity_vector)");
|
MSG("AI: incremental training");
|
||||||
|
R.parseEval("model <- training_step(model, predictors_scaled, target_scaled, validity_vector)");
|
||||||
double ai_end_t = MPI_Wtime();
|
double ai_end_t = MPI_Wtime();
|
||||||
R["ai_training_time"] = ai_end_t - ai_start_t;
|
R["ai_training_time"] = ai_end_t - ai_start_t;
|
||||||
}
|
}
|
||||||
@ -477,8 +478,11 @@ int main(int argc, char *argv[]) {
|
|||||||
|
|
||||||
if (!ai_surrogate_input_script_path.empty()) {
|
if (!ai_surrogate_input_script_path.empty()) {
|
||||||
R["ai_surrogate_base_path"] = ai_surrogate_input_script_path.substr(0, ai_surrogate_input_script_path.find_last_of('/') + 1);
|
R["ai_surrogate_base_path"] = ai_surrogate_input_script_path.substr(0, ai_surrogate_input_script_path.find_last_of('/') + 1);
|
||||||
|
|
||||||
|
MSG("AI: sourcing user-provided script");
|
||||||
R.parseEvalQ("source('" + ai_surrogate_input_script_path + "')");
|
R.parseEvalQ("source('" + ai_surrogate_input_script_path + "')");
|
||||||
}
|
}
|
||||||
|
MSG("AI: initialize AI model");
|
||||||
R.parseEval("model <- initiate_model()");
|
R.parseEval("model <- initiate_model()");
|
||||||
R.parseEval("gpu_info()");
|
R.parseEval("gpu_info()");
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user