mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-16 04:48:23 +01:00
feat: Add AI Surrogate functions to V.03
This commit is contained in:
parent
f5f2cb4b9c
commit
95cb95998e
71
R_lib/ai_surrogate_model.R
Normal file
71
R_lib/ai_surrogate_model.R
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
## This file contains default function implementations for the ai surrogate.
|
||||||
|
## To load pretrained models, use pre-/postprocessing or change hyperparameters
|
||||||
|
## it is recommended to override these functions with custom implementations via
|
||||||
|
## the input script. The path to the R-file containing the functions must be set
|
||||||
|
## in the variable "ai_surrogate_input_script". See the barite_200.R file as an
|
||||||
|
## example and the general README for more information.
|
||||||
|
|
||||||
|
library(keras)
|
||||||
|
library(tensorflow)
|
||||||
|
|
||||||
|
## Default model factory: builds and compiles a fresh dense network.
## The number of input and output units is taken from the length of the
## global variable "ai_surrogate_species". All layers use the "relu"
## activation and dtype "float32"; the model is compiled with the "adam"
## optimizer and the "mean_squared_error" loss.
## Returns the compiled keras model.
initiate_model <- function() {
  act_fun <- "relu"
  n_species <- length(ai_surrogate_species)

  ## Creates a new sequential model from scratch
  net <- keras_model_sequential()

  ## Input layer defined by input data shape
  layer_dense(net,
              units = n_species,
              activation = act_fun,
              input_shape = n_species,
              dtype = "float32")

  ## Hidden layers (48, 96, 24 units) followed by the output layer, whose
  ## width is defined by the output data shape
  for (width in c(48, 96, 24, n_species)) {
    layer_dense(net,
                units = width,
                activation = act_fun,
                dtype = "float32")
  }

  compile(net,
          loss = "mean_squared_error",
          optimizer = "adam")
  return(net)
}
|
||||||
|
|
||||||
|
## Queries tf_gpu_configured() and hands its result to msgm().
gpu_info <- function() {
  gpu_status <- tf_gpu_configured()
  msgm(gpu_status)
}
|
||||||
|
|
||||||
|
## Runs predict() on the model with the predictors converted to a matrix.
## The prediction receives the column names of the predictors and is
## returned as a data.frame.
prediction_step <- function(model, predictors) {
  feature_names <- colnames(predictors)
  net_output <- predict(model, as.matrix(predictors))
  colnames(net_output) <- feature_names
  as.data.frame(net_output)
}
|
||||||
|
|
||||||
|
## Default pre-/postprocessing: the identity transformation.
## "backtransform" and "outputs" are accepted but not used by this default
## implementation; df is returned unchanged.
preprocess <- function(df, backtransform = FALSE, outputs = FALSE) {
  df
}
|
||||||
|
|
||||||
|
## Overwrites the rows of temp_field whose validity entry equals 1 with the
## corresponding rows of prediction and returns the updated temp_field.
set_valid_predictions <- function(temp_field, prediction, validity) {
  accepted <- validity == 1
  temp_field[accepted, ] <- prediction[accepted, ]
  temp_field
}
|
||||||
|
|
||||||
|
## Fits the model on the given predictor/target pair and stores it
## afterwards under "<out_dir>/current_model.keras" ("out_dir" is read from
## the calling environment). The target columns are selected and ordered by
## the predictor column names. "validity" is accepted but not used by this
## default implementation.
training_step <- function(model, predictor, target, validity) {
  msgm("Training:")

  inputs <- as.matrix(predictor)
  labels <- as.matrix(target[colnames(inputs)])

  fit(model, inputs, labels)

  save_model_tf(model, paste0(out_dir, "/current_model.keras"))
}
|
||||||
@ -70,14 +70,19 @@ master_iteration_end <- function(setup, state_T, state_C) {
|
|||||||
if (setup$store_result) {
|
if (setup$store_result) {
|
||||||
if (iter %in% setup$out_save) {
|
if (iter %in% setup$out_save) {
|
||||||
nameout <- paste0(setup$out_dir, "/iter_", sprintf(fmt = fmt, iter), ".rds")
|
nameout <- paste0(setup$out_dir, "/iter_", sprintf(fmt = fmt, iter), ".rds")
|
||||||
|
|
||||||
state_T <- data.frame(state_T, check.names = FALSE)
|
state_T <- data.frame(state_T, check.names = FALSE)
|
||||||
state_C <- data.frame(state_C, check.names = FALSE)
|
state_C <- data.frame(state_C, check.names = FALSE)
|
||||||
|
|
||||||
|
ai_surrogate_info <- list(
|
||||||
|
prediction_time = if(exists("ai_prediction_time")) as.integer(ai_prediction_time) else NULL,
|
||||||
|
training_time = if(exists("ai_training_time")) as.integer(ai_training_time) else NULL,
|
||||||
|
valid_predictions = if(exists("validity_vector")) validity_vector else NULL)
|
||||||
saveRDS(list(
|
saveRDS(list(
|
||||||
T = state_T,
|
T = state_T,
|
||||||
C = state_C,
|
C = state_C,
|
||||||
simtime = as.integer(setup$simulation_time)
|
simtime = as.integer(setup$simulation_time),
|
||||||
|
totaltime = as.integer(totaltime),
|
||||||
|
ai_surrogate_info = ai_surrogate_info
|
||||||
), file = nameout)
|
), file = nameout)
|
||||||
msgm("results stored in <", nameout, ">")
|
msgm("results stored in <", nameout, ">")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
function(ADD_BENCH_TARGET TARGET POET_BENCH_LIST RT_FILES OUT_PATH)
|
function(ADD_BENCH_TARGET TARGET POET_BENCH_LIST RT_FILES OUT_PATH)
|
||||||
set(bench_install_dir share/poet/${OUT_PATH})
|
set(bench_install_dir share/poet/${OUT_PATH})
|
||||||
|
|
||||||
@ -41,4 +40,3 @@ add_custom_target(${BENCHTARGET} ALL)
|
|||||||
add_subdirectory(barite)
|
add_subdirectory(barite)
|
||||||
add_subdirectory(dolo)
|
add_subdirectory(dolo)
|
||||||
add_subdirectory(surfex)
|
add_subdirectory(surfex)
|
||||||
|
|
||||||
|
|||||||
@ -47,7 +47,8 @@ dht_species <- c(
|
|||||||
)
|
)
|
||||||
|
|
||||||
chemistry_setup <- list(
|
chemistry_setup <- list(
|
||||||
dht_species = dht_species
|
dht_species = dht_species,
|
||||||
|
ai_surrogate_input_script = "./barite_200ai_surrogate_input_script.R"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Define a setup list for simulation configuration
|
# Define a setup list for simulation configuration
|
||||||
|
|||||||
48
bench/barite/barite_200ai_surrogate_input_script.R
Normal file
48
bench/barite/barite_200ai_surrogate_input_script.R
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
## load a pretrained model from tensorflow file
|
||||||
|
## Use the global variable "ai_surrogate_base_path" when using file paths
|
||||||
|
## relative to the input script
|
||||||
|
## Loads the pretrained model "model_min_max_float64.keras" located next to
## this input script (path prefix taken from "ai_surrogate_base_path").
initiate_model <- function() {
  model_file <- paste0(ai_surrogate_base_path, "model_min_max_float64.keras")
  load_model_tf(normalizePath(model_file))
}
|
||||||
|
|
||||||
|
## Min-max scaling of x with the given bounds.
## backtransform = FALSE: (x - min) / (max - min)
## backtransform = TRUE:  (x * (max - min)) + min
scale_min_max <- function(x, min, max, backtransform) {
  span <- max - min
  if (backtransform) {
    return((x * span) + min)
  }
  (x - min) / span
}
|
||||||
|
|
||||||
|
## Scales (or back-transforms) every column of df with the global min/max
## bounds stored in "min_max_bounds.rds" next to this input script (path
## prefix taken from "ai_surrogate_base_path").
##   df            data.frame whose columns are scaled one by one
##   backtransform FALSE: scale to the min/max range, TRUE: invert the scaling
##   outputs       accepted for interface compatibility, not used here
## Returns df with all columns transformed.
## Stops with an error if a column of df has no entry in the stored bounds;
## such a column would otherwise be turned into NA values without any notice.
preprocess <- function(df, backtransform = FALSE, outputs = FALSE) {
  minmax_file <- normalizePath(paste0(ai_surrogate_base_path,
                                      "min_max_bounds.rds"))
  global_minmax <- readRDS(minmax_file)

  ## Bounds are looked up by column name, so every column needs both a
  ## minimum and a maximum entry
  known_columns <- intersect(names(global_minmax$min),
                             names(global_minmax$max))
  unknown_columns <- setdiff(colnames(df), known_columns)
  if (length(unknown_columns) > 0) {
    stop("No min/max bounds available for column(s): ",
         paste(unknown_columns, collapse = ", "),
         call. = FALSE)
  }

  for (column in colnames(df)) {
    df[column] <- lapply(df[column],
                         scale_min_max,
                         global_minmax$min[column],
                         global_minmax$max[column],
                         backtransform)
  }
  return(df)
}
|
||||||
|
|
||||||
|
## Row-wise mass balance error between predictors and prediction:
## absolute change of (Ba + Barite) plus absolute change of (Sr + Celestite).
mass_balance <- function(predictors, prediction) {
  ba_drift <- prediction$Ba + prediction$Barite -
    predictors$Ba - predictors$Barite
  sr_drift <- prediction$Sr + prediction$Celestite -
    predictors$Sr - predictors$Celestite
  abs(ba_drift) + abs(sr_drift)
}
|
||||||
|
|
||||||
|
## Marks a prediction row as valid when its mass balance error is below the
## threshold epsilon. Mean and variance of the error and the number of valid
## rows are reported via msgm().
## Returns a logical vector with one entry per row.
validate_predictions <- function(predictors, prediction) {
  epsilon <- 0.00003
  mb <- mass_balance(predictors, prediction)
  within_tolerance <- mb < epsilon
  msgm("Mass balance mean:", mean(mb))
  msgm("Mass balance variance:", var(mb))
  msgm("Rows where mass balance meets threshold", epsilon, ":",
       sum(within_tolerance))
  within_tolerance
}
|
||||||
BIN
bench/barite/min_max_bounds.rds
Normal file
BIN
bench/barite/min_max_bounds.rds
Normal file
Binary file not shown.
BIN
bench/barite/model_min_max_float64.keras
Normal file
BIN
bench/barite/model_min_max_float64.keras
Normal file
Binary file not shown.
@ -70,6 +70,7 @@ target_compile_definitions(POETLib PUBLIC STRICT_R_HEADERS OMPI_SKIP_MPICXX)
|
|||||||
|
|
||||||
file(READ "${PROJECT_SOURCE_DIR}/R_lib/kin_r_library.R" R_KIN_LIB )
|
file(READ "${PROJECT_SOURCE_DIR}/R_lib/kin_r_library.R" R_KIN_LIB )
|
||||||
file(READ "${PROJECT_SOURCE_DIR}/R_lib/init_r_lib.R" R_INIT_LIB)
|
file(READ "${PROJECT_SOURCE_DIR}/R_lib/init_r_lib.R" R_INIT_LIB)
|
||||||
|
file(READ "${PROJECT_SOURCE_DIR}/R_lib/ai_surrogate_model.R" R_AI_SURROGATE_LIB)
|
||||||
|
|
||||||
configure_file(poet.hpp.in poet.hpp @ONLY)
|
configure_file(poet.hpp.in poet.hpp @ONLY)
|
||||||
|
|
||||||
|
|||||||
@ -6,7 +6,7 @@
|
|||||||
namespace poet {
|
namespace poet {
|
||||||
|
|
||||||
enum DHT_PROP_TYPES { DHT_TYPE_DEFAULT, DHT_TYPE_CHARGE, DHT_TYPE_TOTAL };
|
enum DHT_PROP_TYPES { DHT_TYPE_DEFAULT, DHT_TYPE_CHARGE, DHT_TYPE_TOTAL };
|
||||||
enum CHEMISTRY_OUT_SOURCE { CHEM_PQC, CHEM_DHT, CHEM_INTERP };
|
enum CHEMISTRY_OUT_SOURCE { CHEM_PQC, CHEM_DHT, CHEM_INTERP, CHEM_AISURR };
|
||||||
|
|
||||||
struct WorkPackage {
|
struct WorkPackage {
|
||||||
std::size_t size;
|
std::size_t size;
|
||||||
|
|||||||
@ -371,3 +371,7 @@ void poet::ChemistryModule::unshuffleField(const std::vector<double> &in_buffer,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void poet::ChemistryModule::set_ai_surrogate_validity_vector(std::vector<int> r_vector) {
|
||||||
|
this->ai_surrogate_validity_vector = r_vector;
|
||||||
|
}
|
||||||
|
|||||||
@ -83,6 +83,7 @@ public:
|
|||||||
std::uint32_t interp_bucket_size;
|
std::uint32_t interp_bucket_size;
|
||||||
std::uint32_t interp_size_mb;
|
std::uint32_t interp_size_mb;
|
||||||
std::uint32_t interp_min_entries;
|
std::uint32_t interp_min_entries;
|
||||||
|
bool ai_surrogate_enabled;
|
||||||
};
|
};
|
||||||
|
|
||||||
void masterEnableSurrogates(const SurrogateSetup &setup) {
|
void masterEnableSurrogates(const SurrogateSetup &setup) {
|
||||||
@ -92,6 +93,7 @@ public:
|
|||||||
|
|
||||||
this->dht_enabled = setup.dht_enabled;
|
this->dht_enabled = setup.dht_enabled;
|
||||||
this->interp_enabled = setup.interp_enabled;
|
this->interp_enabled = setup.interp_enabled;
|
||||||
|
this->ai_surrogate_enabled = setup.ai_surrogate_enabled;
|
||||||
|
|
||||||
if (this->dht_enabled || this->interp_enabled) {
|
if (this->dht_enabled || this->interp_enabled) {
|
||||||
this->initializeDHT(setup.dht_size_mb, this->params.dht_species);
|
this->initializeDHT(setup.dht_size_mb, this->params.dht_species);
|
||||||
@ -219,6 +221,11 @@ public:
|
|||||||
this->print_progessbar = enabled;
|
this->print_progessbar = enabled;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* **Master only** Set the ai surrogate validity vector from R
|
||||||
|
*/
|
||||||
|
void set_ai_surrogate_validity_vector(std::vector<int> r_vector);
|
||||||
|
|
||||||
std::vector<uint32_t> GetWorkerInterpolationCalls() const;
|
std::vector<uint32_t> GetWorkerInterpolationCalls() const;
|
||||||
|
|
||||||
std::vector<double> GetWorkerInterpolationWriteTimings() const;
|
std::vector<double> GetWorkerInterpolationWriteTimings() const;
|
||||||
@ -228,6 +235,8 @@ public:
|
|||||||
|
|
||||||
std::vector<uint32_t> GetWorkerPHTCacheHits() const;
|
std::vector<uint32_t> GetWorkerPHTCacheHits() const;
|
||||||
|
|
||||||
|
std::vector<int> ai_surrogate_validity_vector;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void initializeDHT(uint32_t size_mb,
|
void initializeDHT(uint32_t size_mb,
|
||||||
const NamedVector<std::uint32_t> &key_species);
|
const NamedVector<std::uint32_t> &key_species);
|
||||||
@ -249,7 +258,8 @@ protected:
|
|||||||
CHEM_IP_SIGNIF_VEC,
|
CHEM_IP_SIGNIF_VEC,
|
||||||
CHEM_WORK_LOOP,
|
CHEM_WORK_LOOP,
|
||||||
CHEM_PERF,
|
CHEM_PERF,
|
||||||
CHEM_BREAK_MAIN_LOOP
|
CHEM_BREAK_MAIN_LOOP,
|
||||||
|
CHEM_AI_BCAST_VALIDITY
|
||||||
};
|
};
|
||||||
|
|
||||||
enum { LOOP_WORK, LOOP_END };
|
enum { LOOP_WORK, LOOP_END };
|
||||||
@ -348,6 +358,8 @@ protected:
|
|||||||
bool interp_enabled{false};
|
bool interp_enabled{false};
|
||||||
std::unique_ptr<poet::InterpolationModule> interp;
|
std::unique_ptr<poet::InterpolationModule> interp;
|
||||||
|
|
||||||
|
bool ai_surrogate_enabled{false};
|
||||||
|
|
||||||
static constexpr uint32_t BUFFER_OFFSET = 5;
|
static constexpr uint32_t BUFFER_OFFSET = 5;
|
||||||
|
|
||||||
inline void ChemBCast(void *buf, int count, MPI_Datatype datatype) const {
|
inline void ChemBCast(void *buf, int count, MPI_Datatype datatype) const {
|
||||||
|
|||||||
@ -159,6 +159,20 @@ std::vector<uint32_t> poet::ChemistryModule::GetWorkerPHTCacheHits() const {
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reorders the first size_per_prop entries of in_vector round-robin over
// wp_count groups: all entries with index i, i + wp_count, i + 2 * wp_count,
// ... are written consecutively for i = 0 .. wp_count - 1.
// The returned vector has the same size as in_vector; positions that are not
// written keep their value-initialized content (0).
inline std::vector<int> shuffleVector(const std::vector<int> &in_vector,
                                      uint32_t size_per_prop,
                                      uint32_t wp_count) {
  std::vector<int> shuffled(in_vector.size());
  auto dest = shuffled.begin();
  for (uint32_t offset = 0; offset < wp_count; ++offset) {
    for (uint32_t src = offset; src < size_per_prop; src += wp_count) {
      *dest = in_vector[src];
      ++dest;
    }
  }
  return shuffled;
}
|
||||||
|
|
||||||
inline std::vector<double> shuffleField(const std::vector<double> &in_field,
|
inline std::vector<double> shuffleField(const std::vector<double> &in_field,
|
||||||
uint32_t size_per_prop,
|
uint32_t size_per_prop,
|
||||||
uint32_t prop_count,
|
uint32_t prop_count,
|
||||||
@ -247,8 +261,10 @@ inline void poet::ChemistryModule::MasterSendPkgs(
|
|||||||
send_buffer[end_of_wp + 2] = dt;
|
send_buffer[end_of_wp + 2] = dt;
|
||||||
// current time of simulation (age) in seconds
|
// current time of simulation (age) in seconds
|
||||||
send_buffer[end_of_wp + 3] = this->simtime;
|
send_buffer[end_of_wp + 3] = this->simtime;
|
||||||
// placeholder for work_package_count
|
// current work package start location in field
|
||||||
send_buffer[end_of_wp + 4] = 0.;
|
uint32_t wp_start_index = std::accumulate(wp_sizes_vector.begin(), std::next(wp_sizes_vector.begin(), count_pkgs), 0);
|
||||||
|
send_buffer[end_of_wp + 4] = wp_start_index;
|
||||||
|
|
||||||
|
|
||||||
/* ATTENTION Worker p has rank p+1 */
|
/* ATTENTION Worker p has rank p+1 */
|
||||||
// MPI_Send(send_buffer, end_of_wp + BUFFER_OFFSET, MPI_DOUBLE, p + 1,
|
// MPI_Send(send_buffer, end_of_wp + BUFFER_OFFSET, MPI_DOUBLE, p + 1,
|
||||||
@ -352,8 +368,21 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
|||||||
int pkg_to_send, pkg_to_recv;
|
int pkg_to_send, pkg_to_recv;
|
||||||
int free_workers;
|
int free_workers;
|
||||||
int i_pkgs;
|
int i_pkgs;
|
||||||
|
int ftype;
|
||||||
|
|
||||||
int ftype = CHEM_WORK_LOOP;
|
const std::vector<uint32_t> wp_sizes_vector =
|
||||||
|
CalculateWPSizesVector(this->n_cells, this->wp_size);
|
||||||
|
|
||||||
|
if (this->ai_surrogate_enabled) {
|
||||||
|
ftype = CHEM_AI_BCAST_VALIDITY;
|
||||||
|
PropagateFunctionType(ftype);
|
||||||
|
this->ai_surrogate_validity_vector = shuffleVector(this->ai_surrogate_validity_vector,
|
||||||
|
this->n_cells,
|
||||||
|
wp_sizes_vector.size());
|
||||||
|
ChemBCast(&this->ai_surrogate_validity_vector.front(), this->n_cells, MPI_INT);
|
||||||
|
}
|
||||||
|
|
||||||
|
ftype = CHEM_WORK_LOOP;
|
||||||
PropagateFunctionType(ftype);
|
PropagateFunctionType(ftype);
|
||||||
|
|
||||||
MPI_Barrier(this->group_comm);
|
MPI_Barrier(this->group_comm);
|
||||||
@ -363,9 +392,6 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
|||||||
/* start time measurement of sequential part */
|
/* start time measurement of sequential part */
|
||||||
seq_a = MPI_Wtime();
|
seq_a = MPI_Wtime();
|
||||||
|
|
||||||
const std::vector<uint32_t> wp_sizes_vector =
|
|
||||||
CalculateWPSizesVector(this->n_cells, this->wp_size);
|
|
||||||
|
|
||||||
/* shuffle grid */
|
/* shuffle grid */
|
||||||
// grid.shuffleAndExport(mpi_buffer);
|
// grid.shuffleAndExport(mpi_buffer);
|
||||||
std::vector<double> mpi_buffer =
|
std::vector<double> mpi_buffer =
|
||||||
|
|||||||
@ -47,6 +47,14 @@ void poet::ChemistryModule::WorkerLoop() {
|
|||||||
switch (func_type) {
|
switch (func_type) {
|
||||||
case CHEM_FIELD_INIT: {
|
case CHEM_FIELD_INIT: {
|
||||||
ChemBCast(&this->prop_count, 1, MPI_UINT32_T);
|
ChemBCast(&this->prop_count, 1, MPI_UINT32_T);
|
||||||
|
if (this->ai_surrogate_enabled) {
|
||||||
|
this->ai_surrogate_validity_vector.reserve(this->n_cells);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case CHEM_AI_BCAST_VALIDITY: {
|
||||||
|
// Receive the index vector of valid ai surrogate predictions
|
||||||
|
MPI_Bcast(&this->ai_surrogate_validity_vector.front(), this->n_cells, MPI_INT, 0, this->group_comm);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case CHEM_WORK_LOOP: {
|
case CHEM_WORK_LOOP: {
|
||||||
@ -118,7 +126,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
|||||||
uint32_t iteration;
|
uint32_t iteration;
|
||||||
double dt;
|
double dt;
|
||||||
double current_sim_time;
|
double current_sim_time;
|
||||||
|
uint32_t wp_start_index;
|
||||||
int count = double_count;
|
int count = double_count;
|
||||||
std::vector<double> mpi_buffer(count);
|
std::vector<double> mpi_buffer(count);
|
||||||
|
|
||||||
@ -170,6 +178,16 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
|||||||
interp->tryInterpolation(s_curr_wp);
|
interp->tryInterpolation(s_curr_wp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (this->ai_surrogate_enabled) {
|
||||||
|
// Map valid predictions from the ai surrogate in the workpackage
|
||||||
|
for (int i = 0; i < s_curr_wp.size; i++) {
|
||||||
|
if (this->ai_surrogate_validity_vector[wp_start_index + i] == 1) {
|
||||||
|
s_curr_wp.mapping[i] = CHEM_AISURR;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
phreeqc_time_start = MPI_Wtime();
|
phreeqc_time_start = MPI_Wtime();
|
||||||
|
|
||||||
WorkerRunWorkPackage(s_curr_wp, current_sim_time, dt);
|
WorkerRunWorkPackage(s_curr_wp, current_sim_time, dt);
|
||||||
|
|||||||
@ -32,6 +32,11 @@ void InitialList::initChemistry(const Rcpp::List &chem) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (chem.containsElementNamed("ai_surrogate_input_script")) {
|
||||||
|
std::string ai_surrogate_input_script_path = chem["ai_surrogate_input_script"];
|
||||||
|
this->ai_surrogate_input_script = Rcpp::as<std::string>(Rcpp::Function("normalizePath")(Rcpp::wrap(ai_surrogate_input_script_path)));
|
||||||
|
}
|
||||||
|
|
||||||
this->field_header =
|
this->field_header =
|
||||||
Rcpp::as<std::vector<std::string>>(this->initial_grid.names());
|
Rcpp::as<std::vector<std::string>>(this->initial_grid.names());
|
||||||
this->field_header.erase(this->field_header.begin());
|
this->field_header.erase(this->field_header.begin());
|
||||||
@ -65,6 +70,7 @@ InitialList::ChemistryInit InitialList::getChemistryInit() const {
|
|||||||
|
|
||||||
chem_init.dht_species = dht_species;
|
chem_init.dht_species = dht_species;
|
||||||
chem_init.interp_species = interp_species;
|
chem_init.interp_species = interp_species;
|
||||||
|
chem_init.ai_surrogate_input_script = ai_surrogate_input_script;
|
||||||
|
|
||||||
if (this->chem_hooks.size() > 0) {
|
if (this->chem_hooks.size() > 0) {
|
||||||
if (this->chem_hooks.containsElementNamed("dht_fill")) {
|
if (this->chem_hooks.containsElementNamed("dht_fill")) {
|
||||||
|
|||||||
@ -82,6 +82,8 @@ void InitialList::importList(const Rcpp::List &setup, bool minimal) {
|
|||||||
|
|
||||||
this->chem_hooks =
|
this->chem_hooks =
|
||||||
Rcpp::as<Rcpp::List>(setup[static_cast<int>(ExportList::CHEM_HOOKS)]);
|
Rcpp::as<Rcpp::List>(setup[static_cast<int>(ExportList::CHEM_HOOKS)]);
|
||||||
|
|
||||||
|
this->ai_surrogate_input_script = Rcpp::as<std::string>(setup[static_cast<int>(ExportList::AI_SURROGATE_INPUT_SCRIPT)]);
|
||||||
}
|
}
|
||||||
|
|
||||||
Rcpp::List InitialList::exportList() {
|
Rcpp::List InitialList::exportList() {
|
||||||
@ -129,6 +131,7 @@ Rcpp::List InitialList::exportList() {
|
|||||||
out[static_cast<int>(ExportList::CHEM_INTERP_SPECIES)] =
|
out[static_cast<int>(ExportList::CHEM_INTERP_SPECIES)] =
|
||||||
Rcpp::wrap(this->interp_species);
|
Rcpp::wrap(this->interp_species);
|
||||||
out[static_cast<int>(ExportList::CHEM_HOOKS)] = this->chem_hooks;
|
out[static_cast<int>(ExportList::CHEM_HOOKS)] = this->chem_hooks;
|
||||||
|
out[static_cast<int>(ExportList::AI_SURROGATE_INPUT_SCRIPT)] = this->ai_surrogate_input_script;
|
||||||
|
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -66,8 +66,9 @@ private:
|
|||||||
CHEM_DHT_SPECIES,
|
CHEM_DHT_SPECIES,
|
||||||
CHEM_INTERP_SPECIES,
|
CHEM_INTERP_SPECIES,
|
||||||
CHEM_HOOKS,
|
CHEM_HOOKS,
|
||||||
ENUM_SIZE
|
AI_SURROGATE_INPUT_SCRIPT,
|
||||||
};
|
ENUM_SIZE // Hack: Last element of the enum to show enum size
|
||||||
|
};
|
||||||
|
|
||||||
// Grid members
|
// Grid members
|
||||||
static constexpr const char *grid_key = "Grid";
|
static constexpr const char *grid_key = "Grid";
|
||||||
@ -204,6 +205,9 @@ private:
|
|||||||
|
|
||||||
NamedVector<std::uint32_t> interp_species;
|
NamedVector<std::uint32_t> interp_species;
|
||||||
|
|
||||||
|
// Path to R script that the user defines in the input file
|
||||||
|
std::string ai_surrogate_input_script;
|
||||||
|
|
||||||
Rcpp::List chem_hooks;
|
Rcpp::List chem_hooks;
|
||||||
|
|
||||||
const std::set<std::string> hook_name_list{"dht_fill", "dht_fuzz",
|
const std::set<std::string> hook_name_list{"dht_fill", "dht_fuzz",
|
||||||
@ -233,6 +237,8 @@ public:
|
|||||||
NamedVector<std::uint32_t> dht_species;
|
NamedVector<std::uint32_t> dht_species;
|
||||||
NamedVector<std::uint32_t> interp_species;
|
NamedVector<std::uint32_t> interp_species;
|
||||||
ChemistryHookFunctions hooks;
|
ChemistryHookFunctions hooks;
|
||||||
|
|
||||||
|
std::string ai_surrogate_input_script;
|
||||||
};
|
};
|
||||||
|
|
||||||
ChemistryInit getChemistryInit() const;
|
ChemistryInit getChemistryInit() const;
|
||||||
|
|||||||
93
src/poet.cpp
93
src/poet.cpp
@ -146,10 +146,13 @@ ParseRet parseInitValues(char **argv, RuntimeParameters ¶ms) {
|
|||||||
cmdl("interp-min", 5) >> params.interp_min_entries;
|
cmdl("interp-min", 5) >> params.interp_min_entries;
|
||||||
cmdl("interp-bucket-entries", 20) >> params.interp_bucket_entries;
|
cmdl("interp-bucket-entries", 20) >> params.interp_bucket_entries;
|
||||||
|
|
||||||
|
params.use_ai_surrogate = cmdl["ai-surrogate"];
|
||||||
|
|
||||||
if (MY_RANK == 0) {
|
if (MY_RANK == 0) {
|
||||||
// MSG("Complete results storage is " + BOOL_PRINT(simparams.store_result));
|
// MSG("Complete results storage is " + BOOL_PRINT(simparams.store_result));
|
||||||
MSG("Work Package Size: " + std::to_string(params.work_package_size));
|
MSG("Work Package Size: " + std::to_string(params.work_package_size));
|
||||||
MSG("DHT is " + BOOL_PRINT(params.use_dht));
|
MSG("DHT is " + BOOL_PRINT(params.use_dht));
|
||||||
|
MSG("AI Surrogate is " + BOOL_PRINT(params.use_ai_surrogate));
|
||||||
|
|
||||||
if (params.use_dht) {
|
if (params.use_dht) {
|
||||||
// MSG("DHT strategy is " + std::to_string(simparams.dht_strategy));
|
// MSG("DHT strategy is " + std::to_string(simparams.dht_strategy));
|
||||||
@ -253,9 +256,9 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
|||||||
if (params.print_progressbar) {
|
if (params.print_progressbar) {
|
||||||
chem.setProgressBarPrintout(true);
|
chem.setProgressBarPrintout(true);
|
||||||
}
|
}
|
||||||
|
R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps());
|
||||||
|
|
||||||
/* SIMULATION LOOP */
|
/* SIMULATION LOOP */
|
||||||
|
|
||||||
double dSimTime{0};
|
double dSimTime{0};
|
||||||
for (uint32_t iter = 1; iter < maxiter + 1; iter++) {
|
for (uint32_t iter = 1; iter < maxiter + 1; iter++) {
|
||||||
double start_t = MPI_Wtime();
|
double start_t = MPI_Wtime();
|
||||||
@ -273,12 +276,71 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
|||||||
|
|
||||||
chem.getField().update(diffusion.getField());
|
chem.getField().update(diffusion.getField());
|
||||||
|
|
||||||
// chem.getfield().update(diffusion.getfield());
|
|
||||||
|
|
||||||
MSG("Chemistry step");
|
MSG("Chemistry step");
|
||||||
|
if (params.use_ai_surrogate) {
|
||||||
|
double ai_start_t = MPI_Wtime();
|
||||||
|
// Save current values from the tug field as predictor for the ai step
|
||||||
|
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
|
||||||
|
R.parseEval(std::string(
|
||||||
|
"predictors <- setNames(data.frame(matrix(TMP, nrow=" +
|
||||||
|
std::to_string(chem.getField().GetRequestedVecSize()) + ")), TMP_PROPS)"));
|
||||||
|
R.parseEval("predictors <- predictors[ai_surrogate_species]");
|
||||||
|
|
||||||
|
// Predict
|
||||||
|
R.parseEval("predictors_scaled <- preprocess(predictors)");
|
||||||
|
|
||||||
|
R.parseEval("print('PREDICTORS:')");
|
||||||
|
R.parseEval("print(head(predictors))");
|
||||||
|
|
||||||
|
R.parseEval("prediction <- preprocess(prediction_step(model, predictors_scaled),\
|
||||||
|
backtransform = TRUE,\
|
||||||
|
outputs = TRUE)");
|
||||||
|
|
||||||
|
// Validate prediction and write valid predictions to chem field
|
||||||
|
R.parseEval("validity_vector <- validate_predictions(predictors,\
|
||||||
|
prediction)");
|
||||||
|
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
|
||||||
|
|
||||||
|
std::vector<std::vector<double>> RTempField = R.parseEval("set_valid_predictions(predictors,\
|
||||||
|
prediction,\
|
||||||
|
validity_vector)");
|
||||||
|
|
||||||
|
Field predictions_field = Field(R.parseEval("nrow(predictors)"),
|
||||||
|
RTempField,
|
||||||
|
R.parseEval("names(predictors)"));
|
||||||
|
chem.getField().update(predictions_field);
|
||||||
|
double ai_end_t = MPI_Wtime();
|
||||||
|
R["ai_prediction_time"] = ai_end_t - ai_start_t;
|
||||||
|
}
|
||||||
|
|
||||||
chem.simulate(dt);
|
chem.simulate(dt);
|
||||||
|
|
||||||
|
/* AI surrogate iterative training*/
|
||||||
|
if (params.use_ai_surrogate) {
|
||||||
|
double ai_start_t = MPI_Wtime();
|
||||||
|
|
||||||
|
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
|
||||||
|
R.parseEval(std::string(
|
||||||
|
"targets <- setNames(data.frame(matrix(TMP, nrow=" +
|
||||||
|
std::to_string(chem.getField().GetRequestedVecSize()) + ")), TMP_PROPS)"));
|
||||||
|
R.parseEval("targets <- targets[ai_surrogate_species]");
|
||||||
|
|
||||||
|
// TODO: Check how to get the correct columns
|
||||||
|
R.parseEval("target_scaled <- preprocess(targets, outputs = TRUE)");
|
||||||
|
|
||||||
|
R.parseEval("print('TARGET:')");
|
||||||
|
R.parseEval("print(head(target_scaled))");
|
||||||
|
|
||||||
|
R.parseEval("training_step(model, predictors_scaled, target_scaled, validity_vector)");
|
||||||
|
double ai_end_t = MPI_Wtime();
|
||||||
|
R["ai_training_time"] = ai_end_t - ai_start_t;
|
||||||
|
}
|
||||||
|
|
||||||
|
// MPI_Barrier(MPI_COMM_WORLD);
|
||||||
|
double end_t = MPI_Wtime();
|
||||||
|
dSimTime += end_t - start_t;
|
||||||
|
R["totaltime"] = dSimTime;
|
||||||
|
|
||||||
// MDL master_iteration_end just writes on disk state_T and
|
// MDL master_iteration_end just writes on disk state_T and
|
||||||
// state_C after every iteration if the cmdline option
|
// state_C after every iteration if the cmdline option
|
||||||
// --ignore-results is not given (and thus the R variable
|
// --ignore-results is not given (and thus the R variable
|
||||||
@ -290,10 +352,6 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
|||||||
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" +
|
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" +
|
||||||
std::to_string(maxiter));
|
std::to_string(maxiter));
|
||||||
MSG();
|
MSG();
|
||||||
|
|
||||||
// MPI_Barrier(MPI_COMM_WORLD);
|
|
||||||
double end_t = MPI_Wtime();
|
|
||||||
dSimTime += end_t - start_t;
|
|
||||||
} // END SIMULATION LOOP
|
} // END SIMULATION LOOP
|
||||||
|
|
||||||
Rcpp::List chem_profiling;
|
Rcpp::List chem_profiling;
|
||||||
@ -384,16 +442,15 @@ int main(int argc, char *argv[]) {
|
|||||||
run_params.use_interp,
|
run_params.use_interp,
|
||||||
run_params.interp_bucket_entries,
|
run_params.interp_bucket_entries,
|
||||||
run_params.interp_size,
|
run_params.interp_size,
|
||||||
run_params.interp_min_entries};
|
run_params.interp_min_entries,
|
||||||
|
run_params.use_ai_surrogate};
|
||||||
|
|
||||||
chemistry.masterEnableSurrogates(surr_setup);
|
chemistry.masterEnableSurrogates(surr_setup);
|
||||||
|
|
||||||
if (MY_RANK > 0) {
|
if (MY_RANK > 0) {
|
||||||
chemistry.WorkerLoop();
|
chemistry.WorkerLoop();
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
init_global_functions(R);
|
init_global_functions(R);
|
||||||
|
|
||||||
// R.parseEvalQ("mysetup <- setup");
|
// R.parseEvalQ("mysetup <- setup");
|
||||||
// // if (MY_RANK == 0) { // get timestep vector from
|
// // if (MY_RANK == 0) { // get timestep vector from
|
||||||
// // grid_init function ... //
|
// // grid_init function ... //
|
||||||
@ -404,6 +461,22 @@ int main(int argc, char *argv[]) {
|
|||||||
// MDL: store all parameters
|
// MDL: store all parameters
|
||||||
// MSG("Calling R Function to store calling parameters");
|
// MSG("Calling R Function to store calling parameters");
|
||||||
// R.parseEvalQ("StoreSetup(setup=mysetup)");
|
// R.parseEvalQ("StoreSetup(setup=mysetup)");
|
||||||
|
if (run_params.use_ai_surrogate) {
|
||||||
|
/* Incorporate ai surrogate from R */
|
||||||
|
R.parseEvalQ(ai_surrogate_r_library);
|
||||||
|
/* Use dht species for model input and output */
|
||||||
|
R["ai_surrogate_species"] = init_list.getChemistryInit().dht_species.getNames();
|
||||||
|
R["out_dir"] = run_params.out_dir;
|
||||||
|
|
||||||
|
const std::string ai_surrogate_input_script_path = init_list.getChemistryInit().ai_surrogate_input_script;
|
||||||
|
|
||||||
|
if (!ai_surrogate_input_script_path.empty()) {
|
||||||
|
R["ai_surrogate_base_path"] = ai_surrogate_input_script_path.substr(0, ai_surrogate_input_script_path.find_last_of('/') + 1);
|
||||||
|
R.parseEvalQ("source('" + ai_surrogate_input_script_path + "')");
|
||||||
|
}
|
||||||
|
R.parseEval("model <- initiate_model()");
|
||||||
|
R.parseEval("gpu_info()");
|
||||||
|
}
|
||||||
|
|
||||||
MSG("Init done on process with rank " + std::to_string(MY_RANK));
|
MSG("Init done on process with rank " + std::to_string(MY_RANK));
|
||||||
|
|
||||||
|
|||||||
@ -35,11 +35,11 @@ static const char *poet_version = "@POET_VERSION@";
|
|||||||
static const inline std::string kin_r_library = R"(@R_KIN_LIB@)";
|
static const inline std::string kin_r_library = R"(@R_KIN_LIB@)";
|
||||||
|
|
||||||
static const inline std::string init_r_library = R"(@R_INIT_LIB@)";
|
static const inline std::string init_r_library = R"(@R_INIT_LIB@)";
|
||||||
|
static const inline std::string ai_surrogate_r_library = R"(@R_AI_SURROGATE_LIB@)";
|
||||||
static const inline std::string r_runtime_parameters = "mysetup";
|
static const inline std::string r_runtime_parameters = "mysetup";
|
||||||
|
|
||||||
const std::set<std::string> flaglist{"ignore-result", "dht", "P", "progress",
|
const std::set<std::string> flaglist{"ignore-result", "dht", "P", "progress",
|
||||||
"interp"};
|
"interp", "ai-surrogate"};
|
||||||
const std::set<std::string> paramlist{
|
const std::set<std::string> paramlist{
|
||||||
"work-package-size", "dht-strategy", "dht-size", "dht-snaps",
|
"work-package-size", "dht-strategy", "dht-size", "dht-snaps",
|
||||||
"dht-file", "interp-size", "interp-min", "interp-bucket-entries"};
|
"dht-file", "interp-size", "interp-min", "interp-bucket-entries"};
|
||||||
@ -66,6 +66,7 @@ struct RuntimeParameters {
|
|||||||
std::uint32_t interp_min_entries;
|
std::uint32_t interp_min_entries;
|
||||||
std::uint32_t interp_bucket_entries;
|
std::uint32_t interp_bucket_entries;
|
||||||
|
|
||||||
|
bool use_ai_surrogate;
|
||||||
struct ChemistryParams {
|
struct ChemistryParams {
|
||||||
// std::string database_path;
|
// std::string database_path;
|
||||||
// std::string input_script;
|
// std::string input_script;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user