diff --git a/R_lib/ai_surrogate_model.R b/R_lib/ai_surrogate_model.R index 268fb3098..0d8794c8a 100644 --- a/R_lib/ai_surrogate_model.R +++ b/R_lib/ai_surrogate_model.R @@ -5,8 +5,8 @@ ## in the variable "ai_surrogate_input_script". See the barite_200.R file as an ## example and the general README for more information. -library(keras) -library(tensorflow) +## library(keras3) +## library(tensorflow) initiate_model <- function() { hidden_layers <- c(48, 96, 24) @@ -54,6 +54,10 @@ preprocess <- function(df, backtransform = FALSE, outputs = FALSE) { return(df) } +postprocess <- function(df, backtransform = TRUE, outputs = TRUE) { + return(df) +} + set_valid_predictions <- function(temp_field, prediction, validity) { temp_field[validity == 1, ] <- prediction[validity == 1, ] return(temp_field) diff --git a/src/poet.cpp b/src/poet.cpp index 5ee6260fc..5175732a5 100644 --- a/src/poet.cpp +++ b/src/poet.cpp @@ -286,25 +286,36 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms, std::to_string(chem.getField().GetRequestedVecSize()) + ")), TMP_PROPS)")); R.parseEval("predictors <- predictors[ai_surrogate_species]"); - // Predict + // Apply preprocessing + MSG("AI Preprocessing"); R.parseEval("predictors_scaled <- preprocess(predictors)"); - R.parseEval("prediction <- preprocess(prediction_step(model, predictors_scaled),\ - backtransform = TRUE,\ - outputs = TRUE)"); + // Predict + MSG("AI Predict"); + R.parseEval("aipreds_scaled <- prediction_step(model, predictors_scaled)"); + + // Apply postprocessing + MSG("AI Postprocesing"); + R.parseEval("aipreds <- postprocess(aipreds_scaled)"); // Validate prediction and write valid predictions to chem field - R.parseEval("validity_vector <- validate_predictions(predictors,\ - prediction)"); + MSG("AI Validate"); + R.parseEval("validity_vector <- validate_predictions(predictors, aipreds)"); + + MSG("AI Marking accepted"); chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector")); + MSG("AI TempField"); std::vector> RTempField = R.parseEval("set_valid_predictions(predictors,\ - prediction,\ + aipreds,\ validity_vector)"); + MSG("AI Set Field"); Field predictions_field = Field(R.parseEval("nrow(predictors)"), RTempField, - R.parseEval("names(predictors)")); + R.parseEval("colnames(predictors)")); + + MSG("AI Update"); chem.getField().update(predictions_field); double ai_end_t = MPI_Wtime(); R["ai_prediction_time"] = ai_end_t - ai_start_t; @@ -323,7 +334,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms, R.parseEval("targets <- targets[ai_surrogate_species]"); // TODO: Check how to get the correct columns - R.parseEval("target_scaled <- preprocess(targets, outputs = TRUE)"); + R.parseEval("target_scaled <- preprocess(targets)"); R.parseEval("training_step(model, predictors_scaled, target_scaled, validity_vector)"); double ai_end_t = MPI_Wtime();