mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-15 12:28:22 +01:00
feat: Add AI Surrogate functions to V.03
This commit is contained in:
parent
f5f2cb4b9c
commit
95cb95998e
71
R_lib/ai_surrogate_model.R
Normal file
71
R_lib/ai_surrogate_model.R
Normal file
@ -0,0 +1,71 @@
|
||||
## This file contains default function implementations for the ai surrogate.
|
||||
## To load pretrained models, use pre-/postprocessing or change hyperparameters
|
||||
## it is recommended to override these functions with custom implementations via
|
||||
## the input script. The path to the R-file containing the functions must be set
|
||||
## in the variable "ai_surrogate_input_script". See the barite_200.R file as an
|
||||
## example and the general README for more information.
|
||||
|
||||
library(keras)
|
||||
library(tensorflow)
|
||||
|
||||
## Build the default surrogate network from scratch.
##
## The network is a stack of dense float32 layers: a first layer as wide as
## the number of surrogate species, hidden layers of 48, 96 and 24 units, and
## an output layer as wide as the number of species again. Every layer
## (including the output layer) uses ReLU. The model is compiled with mean
## squared error loss and the Adam optimizer.
##
## Reads the global variable "ai_surrogate_species" to size the input and
## output layers. Returns the compiled keras model.
initiate_model <- function() {
  n_species <- length(ai_surrogate_species)
  layer_activation <- "relu"
  hidden_units <- c(48, 96, 24)

  ## Start with an empty sequential model; layers are appended in place
  net <- keras_model_sequential()

  ## First layer: fixes the input shape to one value per species
  net %>% layer_dense(units = n_species,
                      activation = layer_activation,
                      input_shape = n_species,
                      dtype = "float32")

  ## Hidden layers
  for (n_units in hidden_units) {
    net %>% layer_dense(units = n_units,
                        activation = layer_activation,
                        dtype = "float32")
  }

  ## Output layer: one value per species
  net %>% layer_dense(units = n_species,
                      activation = layer_activation,
                      dtype = "float32")

  net %>% compile(loss = "mean_squared_error",
                  optimizer = "adam")
  net
}
|
||||
|
||||
## Log the result of tf_gpu_configured() through msgm().
gpu_info <- function() {
  gpu_status <- tf_gpu_configured()
  msgm(gpu_status)
}
|
||||
|
||||
## Run the model on the predictors and return the result as a data frame.
##
## model:      object with a predict() method
## predictors: data frame of model inputs; converted to a matrix for predict()
##
## The output columns are labelled with the predictor column names, so the
## model output is expected to have the same number of columns as the input.
prediction_step <- function(model, predictors) {
  raw_output <- predict(model, as.matrix(predictors))
  colnames(raw_output) <- colnames(predictors)
  as.data.frame(raw_output)
}
|
||||
|
||||
## Default pre-/postprocessing hook: returns the data unchanged.
##
## df:            data frame to transform
## backtransform: ignored by this default implementation
## outputs:       ignored by this default implementation
##
## Override this function in the input script to apply scaling.
preprocess <- function(df, backtransform = FALSE, outputs = FALSE) {
  df
}
|
||||
|
||||
## Overwrite the rows of temp_field that have a valid prediction.
##
## temp_field: data frame holding the current field values
## prediction: data frame of predictions, row-aligned with temp_field
## validity:   vector with one entry per row; a row is taken from
##             `prediction` only when its entry equals 1 (or TRUE)
##
## Returns temp_field with the valid rows replaced.
##
## Entries of `validity` that are NA (e.g. from comparing a NaN mass balance
## against a threshold) are treated as invalid. A logical subscript containing
## NA would make the data frame assignment fail, so the row positions are
## resolved with which(), which drops NA.
set_valid_predictions <- function(temp_field, prediction, validity) {
  valid_rows <- which(validity == 1)

  ## Nothing to replace: hand the field back untouched
  if (length(valid_rows) == 0) {
    return(temp_field)
  }

  temp_field[valid_rows, ] <- prediction[valid_rows, ]
  temp_field
}
|
||||
|
||||
## Fit the model on one batch of data and save it to the output directory.
##
## model:     keras model to train
## predictor: data frame of model inputs
## target:    data frame of training targets; reordered/subset to the
##            predictor column names before fitting
## validity:  accepted for interface compatibility but not used by this
##            default implementation
##
## The model is written to "<out_dir>/current_model.keras", where "out_dir"
## is read from a global variable.
training_step <- function(model, predictor, target, validity) {
  msgm("Training:")

  inputs <- as.matrix(predictor)
  ## Select target columns by predictor name so both matrices line up
  expected <- as.matrix(target[colnames(inputs)])

  fit(model, inputs, expected)

  model_path <- paste0(out_dir, "/current_model.keras")
  save_model_tf(model, model_path)
}
|
||||
@ -53,4 +53,4 @@ add_missing_transport_species <- function(init_grid, new_names) {
|
||||
new_grid <- cbind(new_grid, append_df)
|
||||
|
||||
return(new_grid)
|
||||
}
|
||||
}
|
||||
|
||||
@ -70,14 +70,19 @@ master_iteration_end <- function(setup, state_T, state_C) {
|
||||
if (setup$store_result) {
|
||||
if (iter %in% setup$out_save) {
|
||||
nameout <- paste0(setup$out_dir, "/iter_", sprintf(fmt = fmt, iter), ".rds")
|
||||
|
||||
state_T <- data.frame(state_T, check.names = FALSE)
|
||||
state_C <- data.frame(state_C, check.names = FALSE)
|
||||
|
||||
|
||||
ai_surrogate_info <- list(
|
||||
prediction_time = if(exists("ai_prediction_time")) as.integer(ai_prediction_time) else NULL,
|
||||
training_time = if(exists("ai_training_time")) as.integer(ai_training_time) else NULL,
|
||||
valid_predictions = if(exists("validity_vector")) validity_vector else NULL)
|
||||
saveRDS(list(
|
||||
T = state_T,
|
||||
C = state_C,
|
||||
simtime = as.integer(setup$simulation_time)
|
||||
simtime = as.integer(setup$simulation_time),
|
||||
totaltime = as.integer(totaltime),
|
||||
ai_surrogate_info = ai_surrogate_info
|
||||
), file = nameout)
|
||||
msgm("results stored in <", nameout, ">")
|
||||
}
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
function(ADD_BENCH_TARGET TARGET POET_BENCH_LIST RT_FILES OUT_PATH)
|
||||
set(bench_install_dir share/poet/${OUT_PATH})
|
||||
|
||||
@ -41,4 +40,3 @@ add_custom_target(${BENCHTARGET} ALL)
|
||||
add_subdirectory(barite)
|
||||
add_subdirectory(dolo)
|
||||
add_subdirectory(surfex)
|
||||
|
||||
|
||||
@ -47,7 +47,8 @@ dht_species <- c(
|
||||
)
|
||||
|
||||
chemistry_setup <- list(
|
||||
dht_species = dht_species
|
||||
dht_species = dht_species,
|
||||
ai_surrogate_input_script = "./barite_200ai_surrogate_input_script.R"
|
||||
)
|
||||
|
||||
# Define a setup list for simulation configuration
|
||||
|
||||
48
bench/barite/barite_200ai_surrogate_input_script.R
Normal file
48
bench/barite/barite_200ai_surrogate_input_script.R
Normal file
@ -0,0 +1,48 @@
|
||||
## load a pretrained model from tensorflow file
|
||||
## Use the global variable "ai_surrogate_base_path" when using file paths
|
||||
## relative to the input script
|
||||
## Load the pretrained model "model_min_max_float64.keras" that sits next to
## this input script (located via the global "ai_surrogate_base_path").
initiate_model <- function() {
  model_file <- paste0(ai_surrogate_base_path, "model_min_max_float64.keras")
  load_model_tf(normalizePath(model_file))
}
|
||||
|
||||
## Min-max scaling and its inverse.
##
## x:             numeric values to transform
## min, max:      bounds of the original value range
## backtransform: if TRUE, map scaled values back to the original range;
##                otherwise map original values onto the scaled range
scale_min_max <- function(x, min, max, backtransform) {
  span <- max - min
  if (backtransform) {
    return((x * span) + min)
  }
  (x - min) / span
}
|
||||
|
||||
## Scale every column of df with the min/max bounds stored in
## "min_max_bounds.rds" next to this input script.
##
## df:            data frame whose columns are scaled one by one
## backtransform: passed on to scale_min_max(); TRUE reverses the scaling
## outputs:       accepted but not used by this implementation
##
## The bounds file is read on every call. Bounds are looked up by column
## name in the "min" and "max" elements of the loaded object.
preprocess <- function(df, backtransform = FALSE, outputs = FALSE) {
  bounds_path <- paste0(ai_surrogate_base_path, "min_max_bounds.rds")
  global_minmax <- readRDS(normalizePath(bounds_path))

  for (column in colnames(df)) {
    lower <- global_minmax$min[column]
    upper <- global_minmax$max[column]
    df[[column]] <- scale_min_max(df[[column]], lower, upper, backtransform)
  }
  df
}
|
||||
|
||||
## Per-row mass balance error between predictors and prediction.
##
## For each row, the change in (Ba + Barite) and the change in
## (Sr + Celestite) are computed; the returned value is the sum of the
## absolute values of both changes.
mass_balance <- function(predictors, prediction) {
  ba_drift <- prediction$Ba + prediction$Barite -
    predictors$Ba - predictors$Barite
  sr_drift <- prediction$Sr + prediction$Celestite -
    predictors$Sr - predictors$Celestite
  abs(ba_drift) + abs(sr_drift)
}
|
||||
|
||||
## Flag the rows whose mass balance error is below a fixed threshold.
##
## Logs the mean and variance of the mass balance error and the number of
## rows under the threshold, then returns a logical vector with one entry
## per row (TRUE where the error is below the threshold).
validate_predictions <- function(predictors, prediction) {
  threshold <- 0.00003
  imbalance <- mass_balance(predictors, prediction)
  within_tolerance <- imbalance < threshold

  msgm("Mass balance mean:", mean(imbalance))
  msgm("Mass balance variance:", var(imbalance))
  msgm("Rows where mass balance meets threshold", threshold, ":",
       sum(within_tolerance))

  within_tolerance
}
|
||||
BIN
bench/barite/min_max_bounds.rds
Normal file
BIN
bench/barite/min_max_bounds.rds
Normal file
Binary file not shown.
BIN
bench/barite/model_min_max_float64.keras
Normal file
BIN
bench/barite/model_min_max_float64.keras
Normal file
Binary file not shown.
@ -70,6 +70,7 @@ target_compile_definitions(POETLib PUBLIC STRICT_R_HEADERS OMPI_SKIP_MPICXX)
|
||||
|
||||
file(READ "${PROJECT_SOURCE_DIR}/R_lib/kin_r_library.R" R_KIN_LIB )
|
||||
file(READ "${PROJECT_SOURCE_DIR}/R_lib/init_r_lib.R" R_INIT_LIB)
|
||||
file(READ "${PROJECT_SOURCE_DIR}/R_lib/ai_surrogate_model.R" R_AI_SURROGATE_LIB)
|
||||
|
||||
configure_file(poet.hpp.in poet.hpp @ONLY)
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@
|
||||
namespace poet {
|
||||
|
||||
enum DHT_PROP_TYPES { DHT_TYPE_DEFAULT, DHT_TYPE_CHARGE, DHT_TYPE_TOTAL };
|
||||
enum CHEMISTRY_OUT_SOURCE { CHEM_PQC, CHEM_DHT, CHEM_INTERP };
|
||||
enum CHEMISTRY_OUT_SOURCE { CHEM_PQC, CHEM_DHT, CHEM_INTERP, CHEM_AISURR };
|
||||
|
||||
struct WorkPackage {
|
||||
std::size_t size;
|
||||
|
||||
@ -371,3 +371,7 @@ void poet::ChemistryModule::unshuffleField(const std::vector<double> &in_buffer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void poet::ChemistryModule::set_ai_surrogate_validity_vector(std::vector<int> r_vector) {
|
||||
this->ai_surrogate_validity_vector = r_vector;
|
||||
}
|
||||
|
||||
@ -83,6 +83,7 @@ public:
|
||||
std::uint32_t interp_bucket_size;
|
||||
std::uint32_t interp_size_mb;
|
||||
std::uint32_t interp_min_entries;
|
||||
bool ai_surrogate_enabled;
|
||||
};
|
||||
|
||||
void masterEnableSurrogates(const SurrogateSetup &setup) {
|
||||
@ -92,6 +93,7 @@ public:
|
||||
|
||||
this->dht_enabled = setup.dht_enabled;
|
||||
this->interp_enabled = setup.interp_enabled;
|
||||
this->ai_surrogate_enabled = setup.ai_surrogate_enabled;
|
||||
|
||||
if (this->dht_enabled || this->interp_enabled) {
|
||||
this->initializeDHT(setup.dht_size_mb, this->params.dht_species);
|
||||
@ -219,6 +221,11 @@ public:
|
||||
this->print_progessbar = enabled;
|
||||
};
|
||||
|
||||
/**
|
||||
* **Master only** Set the ai surrogate validity vector from R
|
||||
*/
|
||||
void set_ai_surrogate_validity_vector(std::vector<int> r_vector);
|
||||
|
||||
std::vector<uint32_t> GetWorkerInterpolationCalls() const;
|
||||
|
||||
std::vector<double> GetWorkerInterpolationWriteTimings() const;
|
||||
@ -228,6 +235,8 @@ public:
|
||||
|
||||
std::vector<uint32_t> GetWorkerPHTCacheHits() const;
|
||||
|
||||
std::vector<int> ai_surrogate_validity_vector;
|
||||
|
||||
protected:
|
||||
void initializeDHT(uint32_t size_mb,
|
||||
const NamedVector<std::uint32_t> &key_species);
|
||||
@ -249,7 +258,8 @@ protected:
|
||||
CHEM_IP_SIGNIF_VEC,
|
||||
CHEM_WORK_LOOP,
|
||||
CHEM_PERF,
|
||||
CHEM_BREAK_MAIN_LOOP
|
||||
CHEM_BREAK_MAIN_LOOP,
|
||||
CHEM_AI_BCAST_VALIDITY
|
||||
};
|
||||
|
||||
enum { LOOP_WORK, LOOP_END };
|
||||
@ -348,6 +358,8 @@ protected:
|
||||
bool interp_enabled{false};
|
||||
std::unique_ptr<poet::InterpolationModule> interp;
|
||||
|
||||
bool ai_surrogate_enabled{false};
|
||||
|
||||
static constexpr uint32_t BUFFER_OFFSET = 5;
|
||||
|
||||
inline void ChemBCast(void *buf, int count, MPI_Datatype datatype) const {
|
||||
|
||||
@ -159,6 +159,20 @@ std::vector<uint32_t> poet::ChemistryModule::GetWorkerPHTCacheHits() const {
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Reorder in_vector so that elements with the same index modulo wp_count
// become contiguous: first indices 0, wp_count, 2*wp_count, ..., then
// 1, 1 + wp_count, ..., and so on.
//
// in_vector:     values to reorder
// size_per_prop: number of leading elements of in_vector to visit
// wp_count:      stride between consecutive elements of one group
//
// The result has the same size as in_vector; positions that are never
// written keep their value-initialized content.
inline std::vector<int> shuffleVector(const std::vector<int> &in_vector,
                                      uint32_t size_per_prop,
                                      uint32_t wp_count) {
  std::vector<int> shuffled(in_vector.size());
  uint32_t dest = 0;
  for (uint32_t offset = 0; offset < wp_count; ++offset) {
    for (uint32_t src = offset; src < size_per_prop; src += wp_count) {
      shuffled[dest++] = in_vector[src];
    }
  }
  return shuffled;
}
|
||||
|
||||
inline std::vector<double> shuffleField(const std::vector<double> &in_field,
|
||||
uint32_t size_per_prop,
|
||||
uint32_t prop_count,
|
||||
@ -247,8 +261,10 @@ inline void poet::ChemistryModule::MasterSendPkgs(
|
||||
send_buffer[end_of_wp + 2] = dt;
|
||||
// current time of simulation (age) in seconds
|
||||
send_buffer[end_of_wp + 3] = this->simtime;
|
||||
// placeholder for work_package_count
|
||||
send_buffer[end_of_wp + 4] = 0.;
|
||||
// current work package start location in field
|
||||
uint32_t wp_start_index = std::accumulate(wp_sizes_vector.begin(), std::next(wp_sizes_vector.begin(), count_pkgs), 0);
|
||||
send_buffer[end_of_wp + 4] = wp_start_index;
|
||||
|
||||
|
||||
/* ATTENTION Worker p has rank p+1 */
|
||||
// MPI_Send(send_buffer, end_of_wp + BUFFER_OFFSET, MPI_DOUBLE, p + 1,
|
||||
@ -352,8 +368,21 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
||||
int pkg_to_send, pkg_to_recv;
|
||||
int free_workers;
|
||||
int i_pkgs;
|
||||
int ftype;
|
||||
|
||||
int ftype = CHEM_WORK_LOOP;
|
||||
const std::vector<uint32_t> wp_sizes_vector =
|
||||
CalculateWPSizesVector(this->n_cells, this->wp_size);
|
||||
|
||||
if (this->ai_surrogate_enabled) {
|
||||
ftype = CHEM_AI_BCAST_VALIDITY;
|
||||
PropagateFunctionType(ftype);
|
||||
this->ai_surrogate_validity_vector = shuffleVector(this->ai_surrogate_validity_vector,
|
||||
this->n_cells,
|
||||
wp_sizes_vector.size());
|
||||
ChemBCast(&this->ai_surrogate_validity_vector.front(), this->n_cells, MPI_INT);
|
||||
}
|
||||
|
||||
ftype = CHEM_WORK_LOOP;
|
||||
PropagateFunctionType(ftype);
|
||||
|
||||
MPI_Barrier(this->group_comm);
|
||||
@ -363,9 +392,6 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
||||
/* start time measurement of sequential part */
|
||||
seq_a = MPI_Wtime();
|
||||
|
||||
const std::vector<uint32_t> wp_sizes_vector =
|
||||
CalculateWPSizesVector(this->n_cells, this->wp_size);
|
||||
|
||||
/* shuffle grid */
|
||||
// grid.shuffleAndExport(mpi_buffer);
|
||||
std::vector<double> mpi_buffer =
|
||||
|
||||
@ -47,6 +47,14 @@ void poet::ChemistryModule::WorkerLoop() {
|
||||
switch (func_type) {
|
||||
case CHEM_FIELD_INIT: {
|
||||
ChemBCast(&this->prop_count, 1, MPI_UINT32_T);
|
||||
if (this->ai_surrogate_enabled) {
|
||||
this->ai_surrogate_validity_vector.reserve(this->n_cells);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case CHEM_AI_BCAST_VALIDITY: {
|
||||
// Receive the index vector of valid ai surrogate predictions
|
||||
MPI_Bcast(&this->ai_surrogate_validity_vector.front(), this->n_cells, MPI_INT, 0, this->group_comm);
|
||||
break;
|
||||
}
|
||||
case CHEM_WORK_LOOP: {
|
||||
@ -118,7 +126,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||
uint32_t iteration;
|
||||
double dt;
|
||||
double current_sim_time;
|
||||
|
||||
uint32_t wp_start_index;
|
||||
int count = double_count;
|
||||
std::vector<double> mpi_buffer(count);
|
||||
|
||||
@ -170,6 +178,16 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||
interp->tryInterpolation(s_curr_wp);
|
||||
}
|
||||
|
||||
if (this->ai_surrogate_enabled) {
|
||||
// Map valid predictions from the ai surrogate in the workpackage
|
||||
for (int i = 0; i < s_curr_wp.size; i++) {
|
||||
if (this->ai_surrogate_validity_vector[wp_start_index + i] == 1) {
|
||||
s_curr_wp.mapping[i] = CHEM_AISURR;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
phreeqc_time_start = MPI_Wtime();
|
||||
|
||||
WorkerRunWorkPackage(s_curr_wp, current_sim_time, dt);
|
||||
|
||||
@ -32,6 +32,11 @@ void InitialList::initChemistry(const Rcpp::List &chem) {
|
||||
}
|
||||
}
|
||||
|
||||
if (chem.containsElementNamed("ai_surrogate_input_script")) {
|
||||
std::string ai_surrogate_input_script_path = chem["ai_surrogate_input_script"];
|
||||
this->ai_surrogate_input_script = Rcpp::as<std::string>(Rcpp::Function("normalizePath")(Rcpp::wrap(ai_surrogate_input_script_path)));
|
||||
}
|
||||
|
||||
this->field_header =
|
||||
Rcpp::as<std::vector<std::string>>(this->initial_grid.names());
|
||||
this->field_header.erase(this->field_header.begin());
|
||||
@ -65,6 +70,7 @@ InitialList::ChemistryInit InitialList::getChemistryInit() const {
|
||||
|
||||
chem_init.dht_species = dht_species;
|
||||
chem_init.interp_species = interp_species;
|
||||
chem_init.ai_surrogate_input_script = ai_surrogate_input_script;
|
||||
|
||||
if (this->chem_hooks.size() > 0) {
|
||||
if (this->chem_hooks.containsElementNamed("dht_fill")) {
|
||||
|
||||
@ -82,6 +82,8 @@ void InitialList::importList(const Rcpp::List &setup, bool minimal) {
|
||||
|
||||
this->chem_hooks =
|
||||
Rcpp::as<Rcpp::List>(setup[static_cast<int>(ExportList::CHEM_HOOKS)]);
|
||||
|
||||
this->ai_surrogate_input_script = Rcpp::as<std::string>(setup[static_cast<int>(ExportList::AI_SURROGATE_INPUT_SCRIPT)]);
|
||||
}
|
||||
|
||||
Rcpp::List InitialList::exportList() {
|
||||
@ -129,6 +131,7 @@ Rcpp::List InitialList::exportList() {
|
||||
out[static_cast<int>(ExportList::CHEM_INTERP_SPECIES)] =
|
||||
Rcpp::wrap(this->interp_species);
|
||||
out[static_cast<int>(ExportList::CHEM_HOOKS)] = this->chem_hooks;
|
||||
out[static_cast<int>(ExportList::AI_SURROGATE_INPUT_SCRIPT)] = this->ai_surrogate_input_script;
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -35,7 +35,7 @@ public:
|
||||
void importList(const Rcpp::List &setup, bool minimal = false);
|
||||
Rcpp::List exportList();
|
||||
|
||||
Field getInitialGrid() const { return Field(this->initial_grid); }
|
||||
Field getInitialGrid() const { return Field(this->initial_grid); }
|
||||
|
||||
private:
|
||||
RInside &R;
|
||||
@ -66,8 +66,9 @@ private:
|
||||
CHEM_DHT_SPECIES,
|
||||
CHEM_INTERP_SPECIES,
|
||||
CHEM_HOOKS,
|
||||
ENUM_SIZE
|
||||
};
|
||||
AI_SURROGATE_INPUT_SCRIPT,
|
||||
ENUM_SIZE // Hack: Last element of the enum to show enum size
|
||||
};
|
||||
|
||||
// Grid members
|
||||
static constexpr const char *grid_key = "Grid";
|
||||
@ -203,6 +204,9 @@ private:
|
||||
NamedVector<std::uint32_t> dht_species;
|
||||
|
||||
NamedVector<std::uint32_t> interp_species;
|
||||
|
||||
// Path to R script that the user defines in the input file
|
||||
std::string ai_surrogate_input_script;
|
||||
|
||||
Rcpp::List chem_hooks;
|
||||
|
||||
@ -233,6 +237,8 @@ public:
|
||||
NamedVector<std::uint32_t> dht_species;
|
||||
NamedVector<std::uint32_t> interp_species;
|
||||
ChemistryHookFunctions hooks;
|
||||
|
||||
std::string ai_surrogate_input_script;
|
||||
};
|
||||
|
||||
ChemistryInit getChemistryInit() const;
|
||||
|
||||
95
src/poet.cpp
95
src/poet.cpp
@ -146,10 +146,13 @@ ParseRet parseInitValues(char **argv, RuntimeParameters ¶ms) {
|
||||
cmdl("interp-min", 5) >> params.interp_min_entries;
|
||||
cmdl("interp-bucket-entries", 20) >> params.interp_bucket_entries;
|
||||
|
||||
params.use_ai_surrogate = cmdl["ai-surrogate"];
|
||||
|
||||
if (MY_RANK == 0) {
|
||||
// MSG("Complete results storage is " + BOOL_PRINT(simparams.store_result));
|
||||
MSG("Work Package Size: " + std::to_string(params.work_package_size));
|
||||
MSG("DHT is " + BOOL_PRINT(params.use_dht));
|
||||
MSG("AI Surrogate is " + BOOL_PRINT(params.use_ai_surrogate));
|
||||
|
||||
if (params.use_dht) {
|
||||
// MSG("DHT strategy is " + std::to_string(simparams.dht_strategy));
|
||||
@ -253,9 +256,9 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
if (params.print_progressbar) {
|
||||
chem.setProgressBarPrintout(true);
|
||||
}
|
||||
|
||||
R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps());
|
||||
|
||||
/* SIMULATION LOOP */
|
||||
|
||||
double dSimTime{0};
|
||||
for (uint32_t iter = 1; iter < maxiter + 1; iter++) {
|
||||
double start_t = MPI_Wtime();
|
||||
@ -273,12 +276,71 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
|
||||
chem.getField().update(diffusion.getField());
|
||||
|
||||
// chem.getfield().update(diffusion.getfield());
|
||||
|
||||
MSG("Chemistry step");
|
||||
if (params.use_ai_surrogate) {
|
||||
double ai_start_t = MPI_Wtime();
|
||||
// Save current values from the tug field as predictor for the ai step
|
||||
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
|
||||
R.parseEval(std::string(
|
||||
"predictors <- setNames(data.frame(matrix(TMP, nrow=" +
|
||||
std::to_string(chem.getField().GetRequestedVecSize()) + ")), TMP_PROPS)"));
|
||||
R.parseEval("predictors <- predictors[ai_surrogate_species]");
|
||||
|
||||
// Predict
|
||||
R.parseEval("predictors_scaled <- preprocess(predictors)");
|
||||
|
||||
R.parseEval("print('PREDICTORS:')");
|
||||
R.parseEval("print(head(predictors))");
|
||||
|
||||
R.parseEval("prediction <- preprocess(prediction_step(model, predictors_scaled),\
|
||||
backtransform = TRUE,\
|
||||
outputs = TRUE)");
|
||||
|
||||
// Validate prediction and write valid predictions to chem field
|
||||
R.parseEval("validity_vector <- validate_predictions(predictors,\
|
||||
prediction)");
|
||||
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
|
||||
|
||||
std::vector<std::vector<double>> RTempField = R.parseEval("set_valid_predictions(predictors,\
|
||||
prediction,\
|
||||
validity_vector)");
|
||||
|
||||
Field predictions_field = Field(R.parseEval("nrow(predictors)"),
|
||||
RTempField,
|
||||
R.parseEval("names(predictors)"));
|
||||
chem.getField().update(predictions_field);
|
||||
double ai_end_t = MPI_Wtime();
|
||||
R["ai_prediction_time"] = ai_end_t - ai_start_t;
|
||||
}
|
||||
|
||||
chem.simulate(dt);
|
||||
|
||||
/* AI surrogate iterative training*/
|
||||
if (params.use_ai_surrogate) {
|
||||
double ai_start_t = MPI_Wtime();
|
||||
|
||||
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
|
||||
R.parseEval(std::string(
|
||||
"targets <- setNames(data.frame(matrix(TMP, nrow=" +
|
||||
std::to_string(chem.getField().GetRequestedVecSize()) + ")), TMP_PROPS)"));
|
||||
R.parseEval("targets <- targets[ai_surrogate_species]");
|
||||
|
||||
// TODO: Check how to get the correct columns
|
||||
R.parseEval("target_scaled <- preprocess(targets, outputs = TRUE)");
|
||||
|
||||
R.parseEval("print('TARGET:')");
|
||||
R.parseEval("print(head(target_scaled))");
|
||||
|
||||
R.parseEval("training_step(model, predictors_scaled, target_scaled, validity_vector)");
|
||||
double ai_end_t = MPI_Wtime();
|
||||
R["ai_training_time"] = ai_end_t - ai_start_t;
|
||||
}
|
||||
|
||||
// MPI_Barrier(MPI_COMM_WORLD);
|
||||
double end_t = MPI_Wtime();
|
||||
dSimTime += end_t - start_t;
|
||||
R["totaltime"] = dSimTime;
|
||||
|
||||
// MDL master_iteration_end just writes on disk state_T and
|
||||
// state_C after every iteration if the cmdline option
|
||||
// --ignore-results is not given (and thus the R variable
|
||||
@ -290,10 +352,6 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" +
|
||||
std::to_string(maxiter));
|
||||
MSG();
|
||||
|
||||
// MPI_Barrier(MPI_COMM_WORLD);
|
||||
double end_t = MPI_Wtime();
|
||||
dSimTime += end_t - start_t;
|
||||
} // END SIMULATION LOOP
|
||||
|
||||
Rcpp::List chem_profiling;
|
||||
@ -384,16 +442,15 @@ int main(int argc, char *argv[]) {
|
||||
run_params.use_interp,
|
||||
run_params.interp_bucket_entries,
|
||||
run_params.interp_size,
|
||||
run_params.interp_min_entries};
|
||||
run_params.interp_min_entries,
|
||||
run_params.use_ai_surrogate};
|
||||
|
||||
chemistry.masterEnableSurrogates(surr_setup);
|
||||
|
||||
if (MY_RANK > 0) {
|
||||
chemistry.WorkerLoop();
|
||||
} else {
|
||||
|
||||
init_global_functions(R);
|
||||
|
||||
// R.parseEvalQ("mysetup <- setup");
|
||||
// // if (MY_RANK == 0) { // get timestep vector from
|
||||
// // grid_init function ... //
|
||||
@ -404,6 +461,22 @@ int main(int argc, char *argv[]) {
|
||||
// MDL: store all parameters
|
||||
// MSG("Calling R Function to store calling parameters");
|
||||
// R.parseEvalQ("StoreSetup(setup=mysetup)");
|
||||
if (run_params.use_ai_surrogate) {
|
||||
/* Incorporate ai surrogate from R */
|
||||
R.parseEvalQ(ai_surrogate_r_library);
|
||||
/* Use dht species for model input and output */
|
||||
R["ai_surrogate_species"] = init_list.getChemistryInit().dht_species.getNames();
|
||||
R["out_dir"] = run_params.out_dir;
|
||||
|
||||
const std::string ai_surrogate_input_script_path = init_list.getChemistryInit().ai_surrogate_input_script;
|
||||
|
||||
if (!ai_surrogate_input_script_path.empty()) {
|
||||
R["ai_surrogate_base_path"] = ai_surrogate_input_script_path.substr(0, ai_surrogate_input_script_path.find_last_of('/') + 1);
|
||||
R.parseEvalQ("source('" + ai_surrogate_input_script_path + "')");
|
||||
}
|
||||
R.parseEval("model <- initiate_model()");
|
||||
R.parseEval("gpu_info()");
|
||||
}
|
||||
|
||||
MSG("Init done on process with rank " + std::to_string(MY_RANK));
|
||||
|
||||
|
||||
@ -20,7 +20,7 @@
|
||||
** Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <set>
|
||||
@ -35,11 +35,11 @@ static const char *poet_version = "@POET_VERSION@";
|
||||
static const inline std::string kin_r_library = R"(@R_KIN_LIB@)";
|
||||
|
||||
static const inline std::string init_r_library = R"(@R_INIT_LIB@)";
|
||||
|
||||
static const inline std::string ai_surrogate_r_library = R"(@R_AI_SURROGATE_LIB@)";
|
||||
static const inline std::string r_runtime_parameters = "mysetup";
|
||||
|
||||
const std::set<std::string> flaglist{"ignore-result", "dht", "P", "progress",
|
||||
"interp"};
|
||||
"interp", "ai-surrogate"};
|
||||
const std::set<std::string> paramlist{
|
||||
"work-package-size", "dht-strategy", "dht-size", "dht-snaps",
|
||||
"dht-file", "interp-size", "interp-min", "interp-bucket-entries"};
|
||||
@ -66,6 +66,7 @@ struct RuntimeParameters {
|
||||
std::uint32_t interp_min_entries;
|
||||
std::uint32_t interp_bucket_entries;
|
||||
|
||||
bool use_ai_surrogate;
|
||||
struct ChemistryParams {
|
||||
// std::string database_path;
|
||||
// std::string input_script;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user