MDL: some fixes and some more output to make AI run

This commit is contained in:
Marco De Lucia 2024-05-30 11:32:08 +02:00
parent 99d0b8c70d
commit d00369def7
2 changed files with 26 additions and 11 deletions

View File

@ -5,8 +5,8 @@
## in the variable "ai_surrogate_input_script". See the barite_200.R file as an
## example and the general README for more information.
library(keras)
library(tensorflow)
## library(keras3)
## library(tensorflow)
initiate_model <- function() {
hidden_layers <- c(48, 96, 24)
@ -54,6 +54,10 @@ preprocess <- function(df, backtransform = FALSE, outputs = FALSE) {
return(df)
}
postprocess <- function(df, backtransform = TRUE, outputs = TRUE) {
return(df)
}
set_valid_predictions <- function(temp_field, prediction, validity) {
temp_field[validity == 1, ] <- prediction[validity == 1, ]
return(temp_field)

View File

@ -286,25 +286,36 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
std::to_string(chem.getField().GetRequestedVecSize()) + ")), TMP_PROPS)"));
R.parseEval("predictors <- predictors[ai_surrogate_species]");
// Predict
// Apply preprocessing
MSG("AI Preprocessing");
R.parseEval("predictors_scaled <- preprocess(predictors)");
R.parseEval("prediction <- preprocess(prediction_step(model, predictors_scaled),\
backtransform = TRUE,\
outputs = TRUE)");
// Predict
MSG("AI Predict");
R.parseEval("aipreds_scaled <- prediction_step(model, predictors_scaled)");
// Apply postprocessing
MSG("AI Postprocesing");
R.parseEval("aipreds <- postprocess(aipreds_scaled)");
// Validate prediction and write valid predictions to chem field
R.parseEval("validity_vector <- validate_predictions(predictors,\
prediction)");
MSG("AI Validate");
R.parseEval("validity_vector <- validate_predictions(predictors, aipreds)");
MSG("AI Marking accepted");
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
MSG("AI TempField");
std::vector<std::vector<double>> RTempField = R.parseEval("set_valid_predictions(predictors,\
prediction,\
aipreds,\
validity_vector)");
MSG("AI Set Field");
Field predictions_field = Field(R.parseEval("nrow(predictors)"),
RTempField,
R.parseEval("names(predictors)"));
R.parseEval("colnames(predictors)"));
MSG("AI Update");
chem.getField().update(predictions_field);
double ai_end_t = MPI_Wtime();
R["ai_prediction_time"] = ai_end_t - ai_start_t;
@ -323,7 +334,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
R.parseEval("targets <- targets[ai_surrogate_species]");
// TODO: Check how to get the correct columns
R.parseEval("target_scaled <- preprocess(targets, outputs = TRUE)");
R.parseEval("target_scaled <- preprocess(targets)");
R.parseEval("training_step(model, predictors_scaled, target_scaled, validity_vector)");
double ai_end_t = MPI_Wtime();