feat: Add AI Surrogate functions to V.03

This commit is contained in:
hans 2024-05-27 09:09:01 +02:00
parent f5f2cb4b9c
commit 95cb95998e
19 changed files with 306 additions and 33 deletions

View File

@ -0,0 +1,71 @@
## This file contains default function implementations for the AI surrogate.
## To load pretrained models, use pre-/postprocessing, or change hyperparameters,
## it is recommended to override these functions with custom implementations via
## the input script. The path to the R file containing the functions must be set
## in the variable "ai_surrogate_input_script". See the barite_200.R file as an
## example and the general README for more information.
library(keras)
library(tensorflow)
## Build and compile the default surrogate network from scratch.
##
## Reads the global `ai_surrogate_species`; its length is used for both the
## width of the first dense layer / input shape and the width of the output
## layer. Every layer (including the output layer) uses the "relu" activation
## and "float32" dtype. The model is compiled with mean squared error loss and
## the "adam" optimizer.
##
## Returns the compiled keras sequential model.
initiate_model <- function() {
  layer_activation <- "relu"
  hidden_units <- c(48, 96, 24)
  loss_function <- "mean_squared_error"
  n_inputs <- length(ai_surrogate_species)
  n_outputs <- length(ai_surrogate_species)

  ## Start from an empty sequential model. The pipe results below are not
  ## reassigned, so this relies on keras adding layers to `model` in place.
  model <- keras_model_sequential()

  ## First layer: as wide as the input and carrying the input shape
  model %>% layer_dense(units = n_inputs,
                        activation = layer_activation,
                        input_shape = n_inputs,
                        dtype = "float32")

  ## Hidden layers, one dense layer per entry of `hidden_units`
  for (idx in seq_along(hidden_units)) {
    model %>% layer_dense(units = hidden_units[idx],
                          activation = layer_activation,
                          dtype = "float32")
  }

  ## Output layer: one unit per surrogate species
  model %>% layer_dense(units = n_outputs,
                        activation = layer_activation,
                        dtype = "float32")

  model %>% compile(loss = loss_function,
                    optimizer = "adam")
  model
}
## Report the result of `tf_gpu_configured()` through the project logger
## `msgm()`.
gpu_info <- function() {
  gpu_status <- tf_gpu_configured()
  msgm(gpu_status)
}
## Run the model on the predictors and return the result as a data frame.
##
## model:      object accepted by `predict()`
## predictors: table of input values; it is converted with `as.matrix()`
##
## The prediction is labelled with the column names of `predictors`, so the
## model output must have as many columns as the input.
prediction_step <- function(model, predictors) {
  input_matrix <- as.matrix(predictors)
  raw_output <- predict(model, input_matrix)
  colnames(raw_output) <- colnames(predictors)
  as.data.frame(raw_output)
}
## Default pre-/postprocessing hook: hands the data back unchanged.
##
## df:            data to (back)transform
## backtransform: ignored by this default implementation
## outputs:       ignored by this default implementation
##
## Intended to be overridden from the input script when scaling is needed.
preprocess <- function(df, backtransform = FALSE, outputs = FALSE) {
  df
}
## Overwrite the rows of `temp_field` that have a valid prediction.
##
## temp_field: table holding the current values, one row per cell
## prediction: table of predicted values with the same layout as `temp_field`
## validity:   per-row flag; a row is taken from `prediction` when its entry
##             equals 1 (TRUE also qualifies)
##
## Returns `temp_field` with the valid rows replaced.
set_valid_predictions <- function(temp_field, prediction, validity) {
  ## Entries that are NA (for example a validity check computed from NaN
  ## predictions) are treated as invalid. Without this, an NA in the row index
  ## aborts the subscripted assignment on a data frame.
  is_valid <- !is.na(validity) & validity == 1
  temp_field[is_valid, ] <- prediction[is_valid, ]
  return(temp_field)
}
## Fit the model on one batch of data and save it to the output directory.
##
## model:     keras model to train
## predictor: table of input values
## target:    table of target values; only the columns named like the
##            predictor columns are used, in predictor order
## validity:  accepted for interface compatibility; not used here
##
## The model is written to "<out_dir>/current_model.keras", where `out_dir` is
## read from the global environment.
training_step <- function(model, predictor, target, validity) {
  msgm("Training:")
  features <- as.matrix(predictor)
  ## Select and order the target columns to match the predictor columns
  labels <- as.matrix(target[colnames(features)])
  model %>% fit(features, labels)
  model_path <- paste0(out_dir, "/current_model.keras")
  model %>% save_model_tf(model_path)
}

View File

@ -53,4 +53,4 @@ add_missing_transport_species <- function(init_grid, new_names) {
new_grid <- cbind(new_grid, append_df)
return(new_grid)
}
}

View File

@ -70,14 +70,19 @@ master_iteration_end <- function(setup, state_T, state_C) {
if (setup$store_result) {
if (iter %in% setup$out_save) {
nameout <- paste0(setup$out_dir, "/iter_", sprintf(fmt = fmt, iter), ".rds")
state_T <- data.frame(state_T, check.names = FALSE)
state_C <- data.frame(state_C, check.names = FALSE)
ai_surrogate_info <- list(
prediction_time = if(exists("ai_prediction_time")) as.integer(ai_prediction_time) else NULL,
training_time = if(exists("ai_training_time")) as.integer(ai_training_time) else NULL,
valid_predictions = if(exists("validity_vector")) validity_vector else NULL)
saveRDS(list(
T = state_T,
C = state_C,
simtime = as.integer(setup$simulation_time)
simtime = as.integer(setup$simulation_time),
totaltime = as.integer(totaltime),
ai_surrogate_info = ai_surrogate_info
), file = nameout)
msgm("results stored in <", nameout, ">")
}

View File

@ -1,4 +1,3 @@
function(ADD_BENCH_TARGET TARGET POET_BENCH_LIST RT_FILES OUT_PATH)
set(bench_install_dir share/poet/${OUT_PATH})
@ -41,4 +40,3 @@ add_custom_target(${BENCHTARGET} ALL)
add_subdirectory(barite)
add_subdirectory(dolo)
add_subdirectory(surfex)

View File

@ -47,7 +47,8 @@ dht_species <- c(
)
chemistry_setup <- list(
dht_species = dht_species
dht_species = dht_species,
ai_surrogate_input_script = "./barite_200ai_surrogate_input_script.R"
)
# Define a setup list for simulation configuration

View File

@ -0,0 +1,48 @@
## Load a pretrained model from a TensorFlow model file.
## File paths relative to the input script are built from the global variable
## "ai_surrogate_base_path".
initiate_model <- function() {
  model_file <- paste0(ai_surrogate_base_path,
                       "model_min_max_float64.keras")
  load_model_tf(normalizePath(model_file))
}
## Min-max scaling and its inverse.
##
## x:             values to transform
## min, max:      bounds of the unscaled range
## backtransform: FALSE maps x to (x - min) / (max - min);
##                TRUE maps scaled values back via x * (max - min) + min
##
## No guard is applied for max == min.
scale_min_max <- function(x, min, max, backtransform) {
  span <- max - min
  if (!backtransform) {
    return((x - min) / span)
  }
  (x * span) + min
}
## Scale (or back-transform) every column of `df` with stored min/max bounds.
##
## df:            table whose columns are looked up by name in the bounds
## backtransform: passed on to `scale_min_max()`; TRUE undoes the scaling
## outputs:       accepted for interface compatibility; not used here
##
## The bounds are read from "min_max_bounds.rds" next to the input script
## (located via the global "ai_surrogate_base_path") on every call. The file
## is expected to provide `min` and `max` entries indexable by column name.
preprocess <- function(df, backtransform = FALSE, outputs = FALSE) {
  bounds_path <- paste0(ai_surrogate_base_path, "min_max_bounds.rds")
  bounds <- readRDS(normalizePath(bounds_path))
  for (species in colnames(df)) {
    lower <- bounds$min[species]
    upper <- bounds$max[species]
    df[species] <- lapply(df[species],
                          scale_min_max,
                          lower,
                          upper,
                          backtransform)
  }
  df
}
## Per-row mass balance error between predictors and prediction.
##
## predictors: table with columns Ba, Barite, Sr and Celestite (input state)
## prediction: table with the same columns (predicted state)
##
## Returns |d(Ba + Barite)| + |d(Sr + Celestite)| for every row.
## Stops with an error if any of the four columns is missing from either
## input.
mass_balance <- function(predictors, prediction) {
  required <- c("Ba", "Barite", "Sr", "Celestite")
  absent <- union(setdiff(required, names(predictors)),
                  setdiff(required, names(prediction)))
  if (length(absent) > 0) {
    stop("mass_balance: missing column(s): ",
         paste(absent, collapse = ", "), call. = FALSE)
  }
  ## `[[` matches names exactly. `$Ba` would partially match a "Barite" column
  ## when "Ba" is absent and silently compute a wrong balance.
  dBa <- abs(prediction[["Ba"]] + prediction[["Barite"]] -
             predictors[["Ba"]] - predictors[["Barite"]])
  dSr <- abs(prediction[["Sr"]] + prediction[["Celestite"]] -
             predictors[["Sr"]] - predictors[["Celestite"]])
  dBa + dSr
}
## Flag the predictions whose mass balance error is below a fixed threshold.
##
## predictors: input state passed on to `mass_balance()`
## prediction: predicted state passed on to `mass_balance()`
##
## Logs mean and variance of the mass balance error and the number of rows
## below the threshold, then returns a logical vector with one entry per row
## (TRUE where the error is below the threshold).
validate_predictions <- function(predictors, prediction) {
  epsilon <- 3e-5
  mb <- mass_balance(predictors, prediction)
  msgm("Mass balance mean:", mean(mb))
  msgm("Mass balance variance:", var(mb))
  within_tolerance <- mb < epsilon
  msgm("Rows where mass balance meets threshold", epsilon, ":",
       sum(within_tolerance))
  within_tolerance
}

Binary file not shown.

Binary file not shown.

View File

@ -70,6 +70,7 @@ target_compile_definitions(POETLib PUBLIC STRICT_R_HEADERS OMPI_SKIP_MPICXX)
file(READ "${PROJECT_SOURCE_DIR}/R_lib/kin_r_library.R" R_KIN_LIB )
file(READ "${PROJECT_SOURCE_DIR}/R_lib/init_r_lib.R" R_INIT_LIB)
file(READ "${PROJECT_SOURCE_DIR}/R_lib/ai_surrogate_model.R" R_AI_SURROGATE_LIB)
configure_file(poet.hpp.in poet.hpp @ONLY)

View File

@ -6,7 +6,7 @@
namespace poet {
enum DHT_PROP_TYPES { DHT_TYPE_DEFAULT, DHT_TYPE_CHARGE, DHT_TYPE_TOTAL };
enum CHEMISTRY_OUT_SOURCE { CHEM_PQC, CHEM_DHT, CHEM_INTERP };
enum CHEMISTRY_OUT_SOURCE { CHEM_PQC, CHEM_DHT, CHEM_INTERP, CHEM_AISURR };
struct WorkPackage {
std::size_t size;

View File

@ -371,3 +371,7 @@ void poet::ChemistryModule::unshuffleField(const std::vector<double> &in_buffer,
}
}
}
void poet::ChemistryModule::set_ai_surrogate_validity_vector(std::vector<int> r_vector) {
this->ai_surrogate_validity_vector = r_vector;
}

View File

@ -83,6 +83,7 @@ public:
std::uint32_t interp_bucket_size;
std::uint32_t interp_size_mb;
std::uint32_t interp_min_entries;
bool ai_surrogate_enabled;
};
void masterEnableSurrogates(const SurrogateSetup &setup) {
@ -92,6 +93,7 @@ public:
this->dht_enabled = setup.dht_enabled;
this->interp_enabled = setup.interp_enabled;
this->ai_surrogate_enabled = setup.ai_surrogate_enabled;
if (this->dht_enabled || this->interp_enabled) {
this->initializeDHT(setup.dht_size_mb, this->params.dht_species);
@ -219,6 +221,11 @@ public:
this->print_progessbar = enabled;
};
/**
* **Master only** Set the ai surrogate validity vector from R
*/
void set_ai_surrogate_validity_vector(std::vector<int> r_vector);
std::vector<uint32_t> GetWorkerInterpolationCalls() const;
std::vector<double> GetWorkerInterpolationWriteTimings() const;
@ -228,6 +235,8 @@ public:
std::vector<uint32_t> GetWorkerPHTCacheHits() const;
std::vector<int> ai_surrogate_validity_vector;
protected:
void initializeDHT(uint32_t size_mb,
const NamedVector<std::uint32_t> &key_species);
@ -249,7 +258,8 @@ protected:
CHEM_IP_SIGNIF_VEC,
CHEM_WORK_LOOP,
CHEM_PERF,
CHEM_BREAK_MAIN_LOOP
CHEM_BREAK_MAIN_LOOP,
CHEM_AI_BCAST_VALIDITY
};
enum { LOOP_WORK, LOOP_END };
@ -348,6 +358,8 @@ protected:
bool interp_enabled{false};
std::unique_ptr<poet::InterpolationModule> interp;
bool ai_surrogate_enabled{false};
static constexpr uint32_t BUFFER_OFFSET = 5;
inline void ChemBCast(void *buf, int count, MPI_Datatype datatype) const {

View File

@ -159,6 +159,20 @@ std::vector<uint32_t> poet::ChemistryModule::GetWorkerPHTCacheHits() const {
return ret;
}
// Reorder an integer vector with a stride of `wp_count`.
//
// For every start offset 0 .. wp_count-1 the elements offset,
// offset + wp_count, offset + 2 * wp_count, ... (below `size_per_prop`) are
// copied consecutively to the output.
//
// in_vector:     source values; must hold at least `size_per_prop` elements
// size_per_prop: number of leading elements of `in_vector` to reorder
// wp_count:      stride / number of start offsets
//
// Returns a vector of the same size as `in_vector`; positions that are never
// written keep their value-initialised 0.
inline std::vector<int> shuffleVector(const std::vector<int> &in_vector,
                                      uint32_t size_per_prop,
                                      uint32_t wp_count) {
  std::vector<int> shuffled(in_vector.size());
  auto dest = shuffled.begin();
  for (uint32_t offset = 0; offset < wp_count; ++offset) {
    for (uint32_t src = offset; src < size_per_prop; src += wp_count) {
      *dest = in_vector[src];
      ++dest;
    }
  }
  return shuffled;
}
inline std::vector<double> shuffleField(const std::vector<double> &in_field,
uint32_t size_per_prop,
uint32_t prop_count,
@ -247,8 +261,10 @@ inline void poet::ChemistryModule::MasterSendPkgs(
send_buffer[end_of_wp + 2] = dt;
// current time of simulation (age) in seconds
send_buffer[end_of_wp + 3] = this->simtime;
// placeholder for work_package_count
send_buffer[end_of_wp + 4] = 0.;
// current work package start location in field
uint32_t wp_start_index = std::accumulate(wp_sizes_vector.begin(), std::next(wp_sizes_vector.begin(), count_pkgs), 0);
send_buffer[end_of_wp + 4] = wp_start_index;
/* ATTENTION Worker p has rank p+1 */
// MPI_Send(send_buffer, end_of_wp + BUFFER_OFFSET, MPI_DOUBLE, p + 1,
@ -352,8 +368,21 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
int pkg_to_send, pkg_to_recv;
int free_workers;
int i_pkgs;
int ftype;
int ftype = CHEM_WORK_LOOP;
const std::vector<uint32_t> wp_sizes_vector =
CalculateWPSizesVector(this->n_cells, this->wp_size);
if (this->ai_surrogate_enabled) {
ftype = CHEM_AI_BCAST_VALIDITY;
PropagateFunctionType(ftype);
this->ai_surrogate_validity_vector = shuffleVector(this->ai_surrogate_validity_vector,
this->n_cells,
wp_sizes_vector.size());
ChemBCast(&this->ai_surrogate_validity_vector.front(), this->n_cells, MPI_INT);
}
ftype = CHEM_WORK_LOOP;
PropagateFunctionType(ftype);
MPI_Barrier(this->group_comm);
@ -363,9 +392,6 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
/* start time measurement of sequential part */
seq_a = MPI_Wtime();
const std::vector<uint32_t> wp_sizes_vector =
CalculateWPSizesVector(this->n_cells, this->wp_size);
/* shuffle grid */
// grid.shuffleAndExport(mpi_buffer);
std::vector<double> mpi_buffer =

View File

@ -47,6 +47,14 @@ void poet::ChemistryModule::WorkerLoop() {
switch (func_type) {
case CHEM_FIELD_INIT: {
ChemBCast(&this->prop_count, 1, MPI_UINT32_T);
if (this->ai_surrogate_enabled) {
this->ai_surrogate_validity_vector.reserve(this->n_cells);
}
break;
}
case CHEM_AI_BCAST_VALIDITY: {
// Receive the index vector of valid ai surrogate predictions
MPI_Bcast(&this->ai_surrogate_validity_vector.front(), this->n_cells, MPI_INT, 0, this->group_comm);
break;
}
case CHEM_WORK_LOOP: {
@ -118,7 +126,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
uint32_t iteration;
double dt;
double current_sim_time;
uint32_t wp_start_index;
int count = double_count;
std::vector<double> mpi_buffer(count);
@ -170,6 +178,16 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
interp->tryInterpolation(s_curr_wp);
}
if (this->ai_surrogate_enabled) {
// Map valid predictions from the ai surrogate in the workpackage
for (int i = 0; i < s_curr_wp.size; i++) {
if (this->ai_surrogate_validity_vector[wp_start_index + i] == 1) {
s_curr_wp.mapping[i] = CHEM_AISURR;
}
}
}
phreeqc_time_start = MPI_Wtime();
WorkerRunWorkPackage(s_curr_wp, current_sim_time, dt);

View File

@ -32,6 +32,11 @@ void InitialList::initChemistry(const Rcpp::List &chem) {
}
}
if (chem.containsElementNamed("ai_surrogate_input_script")) {
std::string ai_surrogate_input_script_path = chem["ai_surrogate_input_script"];
this->ai_surrogate_input_script = Rcpp::as<std::string>(Rcpp::Function("normalizePath")(Rcpp::wrap(ai_surrogate_input_script_path)));
}
this->field_header =
Rcpp::as<std::vector<std::string>>(this->initial_grid.names());
this->field_header.erase(this->field_header.begin());
@ -65,6 +70,7 @@ InitialList::ChemistryInit InitialList::getChemistryInit() const {
chem_init.dht_species = dht_species;
chem_init.interp_species = interp_species;
chem_init.ai_surrogate_input_script = ai_surrogate_input_script;
if (this->chem_hooks.size() > 0) {
if (this->chem_hooks.containsElementNamed("dht_fill")) {

View File

@ -82,6 +82,8 @@ void InitialList::importList(const Rcpp::List &setup, bool minimal) {
this->chem_hooks =
Rcpp::as<Rcpp::List>(setup[static_cast<int>(ExportList::CHEM_HOOKS)]);
this->ai_surrogate_input_script = Rcpp::as<std::string>(setup[static_cast<int>(ExportList::AI_SURROGATE_INPUT_SCRIPT)]);
}
Rcpp::List InitialList::exportList() {
@ -129,6 +131,7 @@ Rcpp::List InitialList::exportList() {
out[static_cast<int>(ExportList::CHEM_INTERP_SPECIES)] =
Rcpp::wrap(this->interp_species);
out[static_cast<int>(ExportList::CHEM_HOOKS)] = this->chem_hooks;
out[static_cast<int>(ExportList::AI_SURROGATE_INPUT_SCRIPT)] = this->ai_surrogate_input_script;
return out;
}

View File

@ -35,7 +35,7 @@ public:
void importList(const Rcpp::List &setup, bool minimal = false);
Rcpp::List exportList();
Field getInitialGrid() const { return Field(this->initial_grid); }
Field getInitialGrid() const { return Field(this->initial_grid); }
private:
RInside &R;
@ -66,8 +66,9 @@ private:
CHEM_DHT_SPECIES,
CHEM_INTERP_SPECIES,
CHEM_HOOKS,
ENUM_SIZE
};
AI_SURROGATE_INPUT_SCRIPT,
ENUM_SIZE // Hack: Last element of the enum to show enum size
};
// Grid members
static constexpr const char *grid_key = "Grid";
@ -203,6 +204,9 @@ private:
NamedVector<std::uint32_t> dht_species;
NamedVector<std::uint32_t> interp_species;
// Path to R script that the user defines in the input file
std::string ai_surrogate_input_script;
Rcpp::List chem_hooks;
@ -233,6 +237,8 @@ public:
NamedVector<std::uint32_t> dht_species;
NamedVector<std::uint32_t> interp_species;
ChemistryHookFunctions hooks;
std::string ai_surrogate_input_script;
};
ChemistryInit getChemistryInit() const;

View File

@ -146,10 +146,13 @@ ParseRet parseInitValues(char **argv, RuntimeParameters &params) {
cmdl("interp-min", 5) >> params.interp_min_entries;
cmdl("interp-bucket-entries", 20) >> params.interp_bucket_entries;
params.use_ai_surrogate = cmdl["ai-surrogate"];
if (MY_RANK == 0) {
// MSG("Complete results storage is " + BOOL_PRINT(simparams.store_result));
MSG("Work Package Size: " + std::to_string(params.work_package_size));
MSG("DHT is " + BOOL_PRINT(params.use_dht));
MSG("AI Surrogate is " + BOOL_PRINT(params.use_ai_surrogate));
if (params.use_dht) {
// MSG("DHT strategy is " + std::to_string(simparams.dht_strategy));
@ -253,9 +256,9 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
if (params.print_progressbar) {
chem.setProgressBarPrintout(true);
}
R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps());
/* SIMULATION LOOP */
double dSimTime{0};
for (uint32_t iter = 1; iter < maxiter + 1; iter++) {
double start_t = MPI_Wtime();
@ -273,12 +276,71 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
chem.getField().update(diffusion.getField());
// chem.getfield().update(diffusion.getfield());
MSG("Chemistry step");
if (params.use_ai_surrogate) {
double ai_start_t = MPI_Wtime();
// Save current values from the tug field as predictor for the ai step
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
R.parseEval(std::string(
"predictors <- setNames(data.frame(matrix(TMP, nrow=" +
std::to_string(chem.getField().GetRequestedVecSize()) + ")), TMP_PROPS)"));
R.parseEval("predictors <- predictors[ai_surrogate_species]");
// Predict
R.parseEval("predictors_scaled <- preprocess(predictors)");
R.parseEval("print('PREDICTORS:')");
R.parseEval("print(head(predictors))");
R.parseEval("prediction <- preprocess(prediction_step(model, predictors_scaled),\
backtransform = TRUE,\
outputs = TRUE)");
// Validate prediction and write valid predictions to chem field
R.parseEval("validity_vector <- validate_predictions(predictors,\
prediction)");
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
std::vector<std::vector<double>> RTempField = R.parseEval("set_valid_predictions(predictors,\
prediction,\
validity_vector)");
Field predictions_field = Field(R.parseEval("nrow(predictors)"),
RTempField,
R.parseEval("names(predictors)"));
chem.getField().update(predictions_field);
double ai_end_t = MPI_Wtime();
R["ai_prediction_time"] = ai_end_t - ai_start_t;
}
chem.simulate(dt);
/* AI surrogate iterative training*/
if (params.use_ai_surrogate) {
double ai_start_t = MPI_Wtime();
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
R.parseEval(std::string(
"targets <- setNames(data.frame(matrix(TMP, nrow=" +
std::to_string(chem.getField().GetRequestedVecSize()) + ")), TMP_PROPS)"));
R.parseEval("targets <- targets[ai_surrogate_species]");
// TODO: Check how to get the correct columns
R.parseEval("target_scaled <- preprocess(targets, outputs = TRUE)");
R.parseEval("print('TARGET:')");
R.parseEval("print(head(target_scaled))");
R.parseEval("training_step(model, predictors_scaled, target_scaled, validity_vector)");
double ai_end_t = MPI_Wtime();
R["ai_training_time"] = ai_end_t - ai_start_t;
}
// MPI_Barrier(MPI_COMM_WORLD);
double end_t = MPI_Wtime();
dSimTime += end_t - start_t;
R["totaltime"] = dSimTime;
// MDL master_iteration_end just writes on disk state_T and
// state_C after every iteration if the cmdline option
// --ignore-results is not given (and thus the R variable
@ -290,10 +352,6 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" +
std::to_string(maxiter));
MSG();
// MPI_Barrier(MPI_COMM_WORLD);
double end_t = MPI_Wtime();
dSimTime += end_t - start_t;
} // END SIMULATION LOOP
Rcpp::List chem_profiling;
@ -384,16 +442,15 @@ int main(int argc, char *argv[]) {
run_params.use_interp,
run_params.interp_bucket_entries,
run_params.interp_size,
run_params.interp_min_entries};
run_params.interp_min_entries,
run_params.use_ai_surrogate};
chemistry.masterEnableSurrogates(surr_setup);
if (MY_RANK > 0) {
chemistry.WorkerLoop();
} else {
init_global_functions(R);
// R.parseEvalQ("mysetup <- setup");
// // if (MY_RANK == 0) { // get timestep vector from
// // grid_init function ... //
@ -404,6 +461,22 @@ int main(int argc, char *argv[]) {
// MDL: store all parameters
// MSG("Calling R Function to store calling parameters");
// R.parseEvalQ("StoreSetup(setup=mysetup)");
if (run_params.use_ai_surrogate) {
/* Incorporate ai surrogate from R */
R.parseEvalQ(ai_surrogate_r_library);
/* Use dht species for model input and output */
R["ai_surrogate_species"] = init_list.getChemistryInit().dht_species.getNames();
R["out_dir"] = run_params.out_dir;
const std::string ai_surrogate_input_script_path = init_list.getChemistryInit().ai_surrogate_input_script;
if (!ai_surrogate_input_script_path.empty()) {
R["ai_surrogate_base_path"] = ai_surrogate_input_script_path.substr(0, ai_surrogate_input_script_path.find_last_of('/') + 1);
R.parseEvalQ("source('" + ai_surrogate_input_script_path + "')");
}
R.parseEval("model <- initiate_model()");
R.parseEval("gpu_info()");
}
MSG("Init done on process with rank " + std::to_string(MY_RANK));

View File

@ -20,7 +20,7 @@
** Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
#pragma once
#pragma once
#include <cstdint>
#include <set>
@ -35,11 +35,11 @@ static const char *poet_version = "@POET_VERSION@";
static const inline std::string kin_r_library = R"(@R_KIN_LIB@)";
static const inline std::string init_r_library = R"(@R_INIT_LIB@)";
static const inline std::string ai_surrogate_r_library = R"(@R_AI_SURROGATE_LIB@)";
static const inline std::string r_runtime_parameters = "mysetup";
const std::set<std::string> flaglist{"ignore-result", "dht", "P", "progress",
"interp"};
"interp", "ai-surrogate"};
const std::set<std::string> paramlist{
"work-package-size", "dht-strategy", "dht-size", "dht-snaps",
"dht-file", "interp-size", "interp-min", "interp-bucket-entries"};
@ -66,6 +66,7 @@ struct RuntimeParameters {
std::uint32_t interp_min_entries;
std::uint32_t interp_bucket_entries;
bool use_ai_surrogate;
struct ChemistryParams {
// std::string database_path;
// std::string input_script;