diff --git a/R_lib/ai_surrogate_model.R b/R_lib/ai_surrogate_model.R index 0d8794c8a..860ee0d83 100644 --- a/R_lib/ai_surrogate_model.R +++ b/R_lib/ai_surrogate_model.R @@ -5,8 +5,8 @@ ## in the variable "ai_surrogate_input_script". See the barite_200.R file as an ## example and the general README for more information. -## library(keras3) -## library(tensorflow) +require(keras3) +require(tensorflow) initiate_model <- function() { hidden_layers <- c(48, 96, 24) diff --git a/src/Init/ChemistryInit.cpp b/src/Init/ChemistryInit.cpp index 90631fd01..3c7a9871c 100644 --- a/src/Init/ChemistryInit.cpp +++ b/src/Init/ChemistryInit.cpp @@ -40,14 +40,16 @@ void InitialList::initChemistry(const Rcpp::List &chem) { std::ifstream file(ai_surrogate_input_script_path); if (!file.is_open()) { // print error message and return - Rcpp::Rcerr << "AI surroghate input script was not found at: " << ai_surrogate_input_script_path << std::endl; + Rcpp::Rcerr << "AI surrogate input script was not found at: " << ai_surrogate_input_script_path << std::endl; } std::stringstream buffer; buffer << file.rdbuf(); std::string fileContent = buffer.str(); file.close(); - + + // Get base path + ai_surrogate_input_script_path = ai_surrogate_input_script_path.substr(0, ai_surrogate_input_script_path.find_last_of('/') + 1); // Add the filepath as a global variable in R to enable relative filepaths in the R script fileContent += "\nai_surrogate_base_path <- \"" + ai_surrogate_input_script_path + "\""; diff --git a/src/poet.cpp b/src/poet.cpp index a8cf92e9d..6a16b177a 100644 --- a/src/poet.cpp +++ b/src/poet.cpp @@ -463,7 +463,6 @@ int main(int argc, char *argv[]) { *global_rt_setup = master_init_R.value()(*global_rt_setup, run_params.out_dir, init_list.getInitialGrid().asSEXP()); - // MDL: store all parameters // MSG("Calling R Function to store calling parameters"); // R.parseEvalQ("StoreSetup(setup=mysetup)"); @@ -476,17 +475,14 @@ int main(int argc, char *argv[]) { const std::string ai_surrogate_input_script = init_list.getChemistryInit().ai_surrogate_input_script; - if (!ai_surrogate_input_script_path.empty()) { - R["ai_surrogate_base_path"] = ai_surrogate_input_script_path.substr(0, ai_surrogate_input_script_path.find_last_of('/') + 1); - - MSG("AI: sourcing user-provided script"); - R.parseEvalQ("source('" + ai_surrogate_input_script_path + "')"); - } - MSG("AI: initialize AI model"); - R.parseEval("model <- initiate_model()"); + MSG("AI: sourcing user-provided script"); + R.parseEvalQ(ai_surrogate_input_script); + + MSG("AI: initialize AI model"); + R.parseEval("model <- initiate_model()"); R.parseEval("gpu_info()"); - } - + } + MSG("Init done on process with rank " + std::to_string(MY_RANK)); // MPI_Barrier(MPI_COMM_WORLD);