diff --git a/R_lib/ai_surrogate_model.R b/R_lib/ai_surrogate_model.R new file mode 100644 index 000000000..268fb3098 --- /dev/null +++ b/R_lib/ai_surrogate_model.R @@ -0,0 +1,71 @@ +## This file contains default function implementations for the ai surrogate. +## To load pretrained models, use pre-/postprocessing or change hyperparameters +## it is recommended to override these functions with custom implementations via +## the input script. The path to the R-file containing the functions must be set +## in the variable "ai_surrogate_input_script". See the barite_200.R file as an +## example and the general README for more information. + +library(keras) +library(tensorflow) + +initiate_model <- function() { + hidden_layers <- c(48, 96, 24) + activation <- "relu" + loss <- "mean_squared_error" + + input_length <- length(ai_surrogate_species) + output_length <- length(ai_surrogate_species) + ## Creates a new sequential model from scratch + model <- keras_model_sequential() + + ## Input layer defined by input data shape + model %>% layer_dense(units = input_length, + activation = activation, + input_shape = input_length, + dtype = "float32") + + for (layer_size in hidden_layers) { + model %>% layer_dense(units = layer_size, + activation = activation, + dtype = "float32") + } + + ## Output data defined by output data shape + model %>% layer_dense(units = output_length, + activation = activation, + dtype = "float32") + + model %>% compile(loss = loss, + optimizer = "adam") + return(model) +} + +gpu_info <- function() { + msgm(tf_gpu_configured()) +} + +prediction_step <- function(model, predictors) { + prediction <- predict(model, as.matrix(predictors)) + colnames(prediction) <- colnames(predictors) + return(as.data.frame(prediction)) +} + +preprocess <- function(df, backtransform = FALSE, outputs = FALSE) { + return(df) +} + +set_valid_predictions <- function(temp_field, prediction, validity) { + temp_field[validity == 1, ] <- prediction[validity == 1, ] + 
return(temp_field) +} + +training_step <- function(model, predictor, target, validity) { + msgm("Training:") + + x <- as.matrix(predictor) + y <- as.matrix(target[colnames(x)]) + + model %>% fit(x, y) + + model %>% save_model_tf(paste0(out_dir, "/current_model.keras")) +} diff --git a/R_lib/init_r_lib.R b/R_lib/init_r_lib.R index 640c0e07b..d0971e9da 100644 --- a/R_lib/init_r_lib.R +++ b/R_lib/init_r_lib.R @@ -53,4 +53,4 @@ add_missing_transport_species <- function(init_grid, new_names) { new_grid <- cbind(new_grid, append_df) return(new_grid) -} \ No newline at end of file +} diff --git a/R_lib/kin_r_library.R b/R_lib/kin_r_library.R index 58941a032..cb8eaecd3 100644 --- a/R_lib/kin_r_library.R +++ b/R_lib/kin_r_library.R @@ -70,14 +70,19 @@ master_iteration_end <- function(setup, state_T, state_C) { if (setup$store_result) { if (iter %in% setup$out_save) { nameout <- paste0(setup$out_dir, "/iter_", sprintf(fmt = fmt, iter), ".rds") - state_T <- data.frame(state_T, check.names = FALSE) state_C <- data.frame(state_C, check.names = FALSE) - + + ai_surrogate_info <- list( + prediction_time = if(exists("ai_prediction_time")) as.integer(ai_prediction_time) else NULL, + training_time = if(exists("ai_training_time")) as.integer(ai_training_time) else NULL, + valid_predictions = if(exists("validity_vector")) validity_vector else NULL) saveRDS(list( T = state_T, C = state_C, - simtime = as.integer(setup$simulation_time) + simtime = as.integer(setup$simulation_time), + totaltime = as.integer(totaltime), + ai_surrogate_info = ai_surrogate_info ), file = nameout) msgm("results stored in <", nameout, ">") } diff --git a/bench/CMakeLists.txt b/bench/CMakeLists.txt index 17045c31b..d407713ba 100644 --- a/bench/CMakeLists.txt +++ b/bench/CMakeLists.txt @@ -1,4 +1,3 @@ - function(ADD_BENCH_TARGET TARGET POET_BENCH_LIST RT_FILES OUT_PATH) set(bench_install_dir share/poet/${OUT_PATH}) @@ -41,4 +40,3 @@ add_custom_target(${BENCHTARGET} ALL) add_subdirectory(barite) 
add_subdirectory(dolo) add_subdirectory(surfex) - diff --git a/bench/barite/barite_200.R b/bench/barite/barite_200.R index 337f903eb..a6c87c749 100644 --- a/bench/barite/barite_200.R +++ b/bench/barite/barite_200.R @@ -47,7 +47,8 @@ dht_species <- c( ) chemistry_setup <- list( - dht_species = dht_species + dht_species = dht_species, + ai_surrogate_input_script = "./barite_200ai_surrogate_input_script.R" ) # Define a setup list for simulation configuration diff --git a/bench/barite/barite_200ai_surrogate_input_script.R b/bench/barite/barite_200ai_surrogate_input_script.R new file mode 100644 index 000000000..63b8f66ad --- /dev/null +++ b/bench/barite/barite_200ai_surrogate_input_script.R @@ -0,0 +1,48 @@ +## load a pretrained model from tensorflow file +## Use the global variable "ai_surrogate_base_path" when using file paths +## relative to the input script +initiate_model <- function() { + init_model <- normalizePath(paste0(ai_surrogate_base_path, + "model_min_max_float64.keras")) + return(load_model_tf(init_model)) +} + +scale_min_max <- function(x, min, max, backtransform) { + if (backtransform) { + return((x * (max - min)) + min) + } else { + return((x - min) / (max - min)) + } +} + +preprocess <- function(df, backtransform = FALSE, outputs = FALSE) { + minmax_file <- normalizePath(paste0(ai_surrogate_base_path, + "min_max_bounds.rds")) + global_minmax <- readRDS(minmax_file) + for (column in colnames(df)) { + df[column] <- lapply(df[column], + scale_min_max, + global_minmax$min[column], + global_minmax$max[column], + backtransform) + } + return(df) +} + +mass_balance <- function(predictors, prediction) { + dBa <- abs(prediction$Ba + prediction$Barite - + predictors$Ba - predictors$Barite) + dSr <- abs(prediction$Sr + prediction$Celestite - + predictors$Sr - predictors$Celestite) + return(dBa + dSr) +} + +validate_predictions <- function(predictors, prediction) { + epsilon <- 0.00003 + mb <- mass_balance(predictors, prediction) + msgm("Mass balance mean:", 
mean(mb)) + msgm("Mass balance variance:", var(mb)) + msgm("Rows where mass balance meets threshold", epsilon, ":", + sum(mb < epsilon)) + return(mb < epsilon) +} diff --git a/bench/barite/min_max_bounds.rds b/bench/barite/min_max_bounds.rds new file mode 100644 index 000000000..9760387d4 Binary files /dev/null and b/bench/barite/min_max_bounds.rds differ diff --git a/bench/barite/model_min_max_float64.keras b/bench/barite/model_min_max_float64.keras new file mode 100644 index 000000000..7f5b1fa5c Binary files /dev/null and b/bench/barite/model_min_max_float64.keras differ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5d77d60e0..66dc48a50 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -70,6 +70,7 @@ target_compile_definitions(POETLib PUBLIC STRICT_R_HEADERS OMPI_SKIP_MPICXX) file(READ "${PROJECT_SOURCE_DIR}/R_lib/kin_r_library.R" R_KIN_LIB ) file(READ "${PROJECT_SOURCE_DIR}/R_lib/init_r_lib.R" R_INIT_LIB) +file(READ "${PROJECT_SOURCE_DIR}/R_lib/ai_surrogate_model.R" R_AI_SURROGATE_LIB) configure_file(poet.hpp.in poet.hpp @ONLY) diff --git a/src/Chemistry/ChemistryDefs.hpp b/src/Chemistry/ChemistryDefs.hpp index 71c82edf6..cc0aa5232 100644 --- a/src/Chemistry/ChemistryDefs.hpp +++ b/src/Chemistry/ChemistryDefs.hpp @@ -6,7 +6,7 @@ namespace poet { enum DHT_PROP_TYPES { DHT_TYPE_DEFAULT, DHT_TYPE_CHARGE, DHT_TYPE_TOTAL }; -enum CHEMISTRY_OUT_SOURCE { CHEM_PQC, CHEM_DHT, CHEM_INTERP }; +enum CHEMISTRY_OUT_SOURCE { CHEM_PQC, CHEM_DHT, CHEM_INTERP, CHEM_AISURR }; struct WorkPackage { std::size_t size; diff --git a/src/Chemistry/ChemistryModule.cpp b/src/Chemistry/ChemistryModule.cpp index 155b90569..f2edf5c28 100644 --- a/src/Chemistry/ChemistryModule.cpp +++ b/src/Chemistry/ChemistryModule.cpp @@ -371,3 +371,7 @@ void poet::ChemistryModule::unshuffleField(const std::vector &in_buffer, } } } + +void poet::ChemistryModule::set_ai_surrogate_validity_vector(std::vector r_vector) { + this->ai_surrogate_validity_vector = r_vector; +} diff --git 
a/src/Chemistry/ChemistryModule.hpp b/src/Chemistry/ChemistryModule.hpp index 7260422fb..c06293572 100644 --- a/src/Chemistry/ChemistryModule.hpp +++ b/src/Chemistry/ChemistryModule.hpp @@ -83,6 +83,7 @@ public: std::uint32_t interp_bucket_size; std::uint32_t interp_size_mb; std::uint32_t interp_min_entries; + bool ai_surrogate_enabled; }; void masterEnableSurrogates(const SurrogateSetup &setup) { @@ -92,6 +93,7 @@ public: this->dht_enabled = setup.dht_enabled; this->interp_enabled = setup.interp_enabled; + this->ai_surrogate_enabled = setup.ai_surrogate_enabled; if (this->dht_enabled || this->interp_enabled) { this->initializeDHT(setup.dht_size_mb, this->params.dht_species); @@ -219,6 +221,11 @@ public: this->print_progessbar = enabled; }; + /** + * **Master only** Set the ai surrogate validity vector from R + */ + void set_ai_surrogate_validity_vector(std::vector r_vector); + std::vector GetWorkerInterpolationCalls() const; std::vector GetWorkerInterpolationWriteTimings() const; @@ -228,6 +235,8 @@ public: std::vector GetWorkerPHTCacheHits() const; + std::vector ai_surrogate_validity_vector; + protected: void initializeDHT(uint32_t size_mb, const NamedVector &key_species); @@ -249,7 +258,8 @@ protected: CHEM_IP_SIGNIF_VEC, CHEM_WORK_LOOP, CHEM_PERF, - CHEM_BREAK_MAIN_LOOP + CHEM_BREAK_MAIN_LOOP, + CHEM_AI_BCAST_VALIDITY }; enum { LOOP_WORK, LOOP_END }; @@ -348,6 +358,8 @@ protected: bool interp_enabled{false}; std::unique_ptr interp; + bool ai_surrogate_enabled{false}; + static constexpr uint32_t BUFFER_OFFSET = 5; inline void ChemBCast(void *buf, int count, MPI_Datatype datatype) const { diff --git a/src/Chemistry/MasterFunctions.cpp b/src/Chemistry/MasterFunctions.cpp index 9ee52be82..fce7b4139 100644 --- a/src/Chemistry/MasterFunctions.cpp +++ b/src/Chemistry/MasterFunctions.cpp @@ -159,6 +159,20 @@ std::vector poet::ChemistryModule::GetWorkerPHTCacheHits() const { return ret; } +inline std::vector shuffleVector(const std::vector &in_vector, + uint32_t 
size_per_prop, + uint32_t wp_count) { + std::vector out_buffer(in_vector.size()); + uint32_t write_i = 0; + for (uint32_t i = 0; i < wp_count; i++) { + for (uint32_t j = i; j < size_per_prop; j += wp_count) { + out_buffer[write_i] = in_vector[j]; + write_i++; + } + } + return out_buffer; +} + inline std::vector shuffleField(const std::vector &in_field, uint32_t size_per_prop, uint32_t prop_count, @@ -247,8 +261,10 @@ inline void poet::ChemistryModule::MasterSendPkgs( send_buffer[end_of_wp + 2] = dt; // current time of simulation (age) in seconds send_buffer[end_of_wp + 3] = this->simtime; - // placeholder for work_package_count - send_buffer[end_of_wp + 4] = 0.; + // current work package start location in field + uint32_t wp_start_index = std::accumulate(wp_sizes_vector.begin(), std::next(wp_sizes_vector.begin(), count_pkgs), 0); + send_buffer[end_of_wp + 4] = wp_start_index; + /* ATTENTION Worker p has rank p+1 */ // MPI_Send(send_buffer, end_of_wp + BUFFER_OFFSET, MPI_DOUBLE, p + 1, @@ -352,8 +368,21 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { int pkg_to_send, pkg_to_recv; int free_workers; int i_pkgs; + int ftype; - int ftype = CHEM_WORK_LOOP; + const std::vector wp_sizes_vector = + CalculateWPSizesVector(this->n_cells, this->wp_size); + + if (this->ai_surrogate_enabled) { + ftype = CHEM_AI_BCAST_VALIDITY; + PropagateFunctionType(ftype); + this->ai_surrogate_validity_vector = shuffleVector(this->ai_surrogate_validity_vector, + this->n_cells, + wp_sizes_vector.size()); + ChemBCast(&this->ai_surrogate_validity_vector.front(), this->n_cells, MPI_INT); + } + + ftype = CHEM_WORK_LOOP; PropagateFunctionType(ftype); MPI_Barrier(this->group_comm); @@ -363,9 +392,6 @@ void poet::ChemistryModule::MasterRunParallel(double dt) { /* start time measurement of sequential part */ seq_a = MPI_Wtime(); - const std::vector wp_sizes_vector = - CalculateWPSizesVector(this->n_cells, this->wp_size); - /* shuffle grid */ // grid.shuffleAndExport(mpi_buffer); 
std::vector mpi_buffer = diff --git a/src/Chemistry/WorkerFunctions.cpp b/src/Chemistry/WorkerFunctions.cpp index e436b6f13..0e74f0875 100644 --- a/src/Chemistry/WorkerFunctions.cpp +++ b/src/Chemistry/WorkerFunctions.cpp @@ -47,6 +47,14 @@ void poet::ChemistryModule::WorkerLoop() { switch (func_type) { case CHEM_FIELD_INIT: { ChemBCast(&this->prop_count, 1, MPI_UINT32_T); + if (this->ai_surrogate_enabled) { + this->ai_surrogate_validity_vector.resize(this->n_cells); + } + break; + } + case CHEM_AI_BCAST_VALIDITY: { + // Receive the index vector of valid ai surrogate predictions + MPI_Bcast(&this->ai_surrogate_validity_vector.front(), this->n_cells, MPI_INT, 0, this->group_comm); break; } case CHEM_WORK_LOOP: { @@ -118,7 +126,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, uint32_t iteration; double dt; double current_sim_time; - + uint32_t wp_start_index; int count = double_count; std::vector mpi_buffer(count); @@ -170,6 +178,16 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, interp->tryInterpolation(s_curr_wp); } + if (this->ai_surrogate_enabled) { + // Map valid predictions from the ai surrogate in the workpackage + for (int i = 0; i < s_curr_wp.size; i++) { + if (this->ai_surrogate_validity_vector[wp_start_index + i] == 1) { + s_curr_wp.mapping[i] = CHEM_AISURR; + } + } + } + + phreeqc_time_start = MPI_Wtime(); WorkerRunWorkPackage(s_curr_wp, current_sim_time, dt); diff --git a/src/Init/ChemistryInit.cpp b/src/Init/ChemistryInit.cpp index 5eb0d01f4..0e512023e 100644 --- a/src/Init/ChemistryInit.cpp +++ b/src/Init/ChemistryInit.cpp @@ -32,6 +32,11 @@ void InitialList::initChemistry(const Rcpp::List &chem) { } } + if (chem.containsElementNamed("ai_surrogate_input_script")) { + std::string ai_surrogate_input_script_path = chem["ai_surrogate_input_script"]; + this->ai_surrogate_input_script = Rcpp::as(Rcpp::Function("normalizePath")(Rcpp::wrap(ai_surrogate_input_script_path))); + } + this->field_header = 
Rcpp::as>(this->initial_grid.names()); this->field_header.erase(this->field_header.begin()); @@ -65,6 +70,7 @@ InitialList::ChemistryInit InitialList::getChemistryInit() const { chem_init.dht_species = dht_species; chem_init.interp_species = interp_species; + chem_init.ai_surrogate_input_script = ai_surrogate_input_script; if (this->chem_hooks.size() > 0) { if (this->chem_hooks.containsElementNamed("dht_fill")) { diff --git a/src/Init/InitialList.cpp b/src/Init/InitialList.cpp index e4bae7093..1aced099e 100644 --- a/src/Init/InitialList.cpp +++ b/src/Init/InitialList.cpp @@ -82,6 +82,8 @@ void InitialList::importList(const Rcpp::List &setup, bool minimal) { this->chem_hooks = Rcpp::as(setup[static_cast(ExportList::CHEM_HOOKS)]); + + this->ai_surrogate_input_script = Rcpp::as(setup[static_cast(ExportList::AI_SURROGATE_INPUT_SCRIPT)]); } Rcpp::List InitialList::exportList() { @@ -129,6 +131,7 @@ Rcpp::List InitialList::exportList() { out[static_cast(ExportList::CHEM_INTERP_SPECIES)] = Rcpp::wrap(this->interp_species); out[static_cast(ExportList::CHEM_HOOKS)] = this->chem_hooks; + out[static_cast(ExportList::AI_SURROGATE_INPUT_SCRIPT)] = this->ai_surrogate_input_script; return out; } diff --git a/src/Init/InitialList.hpp b/src/Init/InitialList.hpp index bc9555329..3e6ae7654 100644 --- a/src/Init/InitialList.hpp +++ b/src/Init/InitialList.hpp @@ -35,7 +35,7 @@ public: void importList(const Rcpp::List &setup, bool minimal = false); Rcpp::List exportList(); - Field getInitialGrid() const { return Field(this->initial_grid); } + Field getInitialGrid() const { return Field(this->initial_grid); } private: RInside &R; @@ -66,8 +66,9 @@ private: CHEM_DHT_SPECIES, CHEM_INTERP_SPECIES, CHEM_HOOKS, - ENUM_SIZE - }; + AI_SURROGATE_INPUT_SCRIPT, + ENUM_SIZE // Hack: Last element of the enum to show enum size + }; // Grid members static constexpr const char *grid_key = "Grid"; @@ -203,6 +204,9 @@ private: NamedVector dht_species; NamedVector interp_species; + + // Path to R script 
that the user defines in the input file + std::string ai_surrogate_input_script; Rcpp::List chem_hooks; @@ -233,6 +237,8 @@ public: NamedVector dht_species; NamedVector interp_species; ChemistryHookFunctions hooks; + + std::string ai_surrogate_input_script; }; ChemistryInit getChemistryInit() const; diff --git a/src/poet.cpp b/src/poet.cpp index 48142cbf8..65264c746 100644 --- a/src/poet.cpp +++ b/src/poet.cpp @@ -146,10 +146,13 @@ ParseRet parseInitValues(char **argv, RuntimeParameters ¶ms) { cmdl("interp-min", 5) >> params.interp_min_entries; cmdl("interp-bucket-entries", 20) >> params.interp_bucket_entries; + params.use_ai_surrogate = cmdl["ai-surrogate"]; + if (MY_RANK == 0) { // MSG("Complete results storage is " + BOOL_PRINT(simparams.store_result)); MSG("Work Package Size: " + std::to_string(params.work_package_size)); MSG("DHT is " + BOOL_PRINT(params.use_dht)); + MSG("AI Surrogate is " + BOOL_PRINT(params.use_ai_surrogate)); if (params.use_dht) { // MSG("DHT strategy is " + std::to_string(simparams.dht_strategy)); @@ -253,9 +256,9 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms, if (params.print_progressbar) { chem.setProgressBarPrintout(true); } - + R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps()); + /* SIMULATION LOOP */ - double dSimTime{0}; for (uint32_t iter = 1; iter < maxiter + 1; iter++) { double start_t = MPI_Wtime(); @@ -273,12 +276,71 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms, chem.getField().update(diffusion.getField()); - // chem.getfield().update(diffusion.getfield()); - MSG("Chemistry step"); + if (params.use_ai_surrogate) { + double ai_start_t = MPI_Wtime(); + // Save current values from the tug field as predictor for the ai step + R["TMP"] = Rcpp::wrap(chem.getField().AsVector()); + R.parseEval(std::string( + "predictors <- setNames(data.frame(matrix(TMP, nrow=" + + std::to_string(chem.getField().GetRequestedVecSize()) + ")), TMP_PROPS)")); + R.parseEval("predictors 
<- predictors[ai_surrogate_species]"); + + // Predict + R.parseEval("predictors_scaled <- preprocess(predictors)"); + + R.parseEval("print('PREDICTORS:')"); + R.parseEval("print(head(predictors))"); + + R.parseEval("prediction <- preprocess(prediction_step(model, predictors_scaled),\ + backtransform = TRUE,\ + outputs = TRUE)"); + + // Validate prediction and write valid predictions to chem field + R.parseEval("validity_vector <- validate_predictions(predictors,\ + prediction)"); + chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector")); + + std::vector> RTempField = R.parseEval("set_valid_predictions(predictors,\ + prediction,\ + validity_vector)"); + + Field predictions_field = Field(R.parseEval("nrow(predictors)"), + RTempField, + R.parseEval("names(predictors)")); + chem.getField().update(predictions_field); + double ai_end_t = MPI_Wtime(); + R["ai_prediction_time"] = ai_end_t - ai_start_t; + } chem.simulate(dt); + /* AI surrogate iterative training*/ + if (params.use_ai_surrogate) { + double ai_start_t = MPI_Wtime(); + + R["TMP"] = Rcpp::wrap(chem.getField().AsVector()); + R.parseEval(std::string( + "targets <- setNames(data.frame(matrix(TMP, nrow=" + + std::to_string(chem.getField().GetRequestedVecSize()) + ")), TMP_PROPS)")); + R.parseEval("targets <- targets[ai_surrogate_species]"); + + // TODO: Check how to get the correct columns + R.parseEval("target_scaled <- preprocess(targets, outputs = TRUE)"); + + R.parseEval("print('TARGET:')"); + R.parseEval("print(head(target_scaled))"); + + R.parseEval("training_step(model, predictors_scaled, target_scaled, validity_vector)"); + double ai_end_t = MPI_Wtime(); + R["ai_training_time"] = ai_end_t - ai_start_t; + } + + // MPI_Barrier(MPI_COMM_WORLD); + double end_t = MPI_Wtime(); + dSimTime += end_t - start_t; + R["totaltime"] = dSimTime; + // MDL master_iteration_end just writes on disk state_T and // state_C after every iteration if the cmdline option // --ignore-results is not given (and thus the R 
variable @@ -290,10 +352,6 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms, MSG("End of *coupling* iteration " + std::to_string(iter) + "/" + std::to_string(maxiter)); MSG(); - - // MPI_Barrier(MPI_COMM_WORLD); - double end_t = MPI_Wtime(); - dSimTime += end_t - start_t; } // END SIMULATION LOOP Rcpp::List chem_profiling; @@ -384,16 +442,15 @@ int main(int argc, char *argv[]) { run_params.use_interp, run_params.interp_bucket_entries, run_params.interp_size, - run_params.interp_min_entries}; + run_params.interp_min_entries, + run_params.use_ai_surrogate}; chemistry.masterEnableSurrogates(surr_setup); if (MY_RANK > 0) { chemistry.WorkerLoop(); } else { - init_global_functions(R); - // R.parseEvalQ("mysetup <- setup"); // // if (MY_RANK == 0) { // get timestep vector from // // grid_init function ... // @@ -404,6 +461,22 @@ int main(int argc, char *argv[]) { // MDL: store all parameters // MSG("Calling R Function to store calling parameters"); // R.parseEvalQ("StoreSetup(setup=mysetup)"); + if (run_params.use_ai_surrogate) { + /* Incorporate ai surrogate from R */ + R.parseEvalQ(ai_surrogate_r_library); + /* Use dht species for model input and output */ + R["ai_surrogate_species"] = init_list.getChemistryInit().dht_species.getNames(); + R["out_dir"] = run_params.out_dir; + + const std::string ai_surrogate_input_script_path = init_list.getChemistryInit().ai_surrogate_input_script; + + if (!ai_surrogate_input_script_path.empty()) { + R["ai_surrogate_base_path"] = ai_surrogate_input_script_path.substr(0, ai_surrogate_input_script_path.find_last_of('/') + 1); + R.parseEvalQ("source('" + ai_surrogate_input_script_path + "')"); + } + R.parseEval("model <- initiate_model()"); + R.parseEval("gpu_info()"); + } MSG("Init done on process with rank " + std::to_string(MY_RANK)); diff --git a/src/poet.hpp.in b/src/poet.hpp.in index f9a86be77..cca89e264 100644 --- a/src/poet.hpp.in +++ b/src/poet.hpp.in @@ -20,7 +20,7 @@ ** Franklin Street, Fifth 
Floor, Boston, MA 02110-1301, USA. */ -#pragma once +#pragma once #include #include @@ -35,11 +35,11 @@ static const char *poet_version = "@POET_VERSION@"; static const inline std::string kin_r_library = R"(@R_KIN_LIB@)"; static const inline std::string init_r_library = R"(@R_INIT_LIB@)"; - +static const inline std::string ai_surrogate_r_library = R"(@R_AI_SURROGATE_LIB@)"; static const inline std::string r_runtime_parameters = "mysetup"; const std::set flaglist{"ignore-result", "dht", "P", "progress", - "interp"}; + "interp", "ai-surrogate"}; const std::set paramlist{ "work-package-size", "dht-strategy", "dht-size", "dht-snaps", "dht-file", "interp-size", "interp-min", "interp-bucket-entries"}; @@ -66,6 +66,7 @@ struct RuntimeParameters { std::uint32_t interp_min_entries; std::uint32_t interp_bucket_entries; + bool use_ai_surrogate; struct ChemistryParams { // std::string database_path; // std::string input_script;