refactor: pre/postprocess as separate functions

This commit is contained in:
straile 2024-10-19 18:44:44 +02:00
parent f746a566cc
commit 997ae32092
3 changed files with 11 additions and 75 deletions

View File

@ -252,13 +252,15 @@ R input script.
The following variables and functions must be declared:
- `model_file_path` [*string*]: Path to the Keras model file with which
the AI surrogate model is initialized.
- `validate_predictions(predictors, prediction)` [*function*]: Returns a boolean
vector of length `nrow(predictions)`. The output of this function defines
which predictions are considered valid and which are rejected. Regular
simulation will only be done for the rejected values, and the results
will be added to the training data buffer of the AI surrogate model.
Can eg. be implemented as a mass balance threshold between the predictors
and the prediction.
- `validate_predictions(predictors, prediction)` [*function*]: Returns a
boolean vector of length `nrow(predictions)`. The output of this function
defines which predictions are considered valid and which are rejected.
the predictors and predictions are passed in their original original (not
transformed) scale. Regular simulation will only be done for the rejected
values. The input data of the rejected rows and the respective true results
from simulation will be added to the training data buffer of the AI surrogate
model. Can eg. be implemented as a mass balance threshold between the
predictors and the prediction.
The following variables and functions can be declared:

View File

@ -79,16 +79,7 @@ master_iteration_end <- function(setup, state_T, state_C) {
prediction_time = if (exists("ai_prediction_time")) ai_prediction_time else NULL,
predictions_validity = if (exists("validity_vector")) validity_vector else NULL,
predictions = if (exists("predictions")) predictions else NULL,
n_training_runs = if(exists("n_training_runs")) n_training_runs else NULL,
diff_to_R = diff_to_R,
R_preprocessing = R_preprocessing,
R_preprocessed_to_cxx = R_preprocessed_to_cxx,
cxx_inference = cxx_inference,
cxx_predictions_to_R = cxx_predictions_to_R,
R_postprocessing = R_postprocessing,
R_validate = R_validate,
validity_to_cxx = validity_to_cxx,
append_to_training_buffer = append_to_training_buffer
n_training_runs = if(exists("n_training_runs")) n_training_runs else NULL
)
SaveRObj(x = list(

View File

@ -352,31 +352,15 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
if (params.use_ai_surrogate) {
double ai_start_t = MPI_Wtime();
double ai_start_steps = MPI_Wtime();
// Get current values from the tug field for the ai predictions
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
R.parseEval(std::string("predictors <- ") +
"set_field(TMP, TMP_PROPS, field_nrow, ai_surrogate_species)");
double ai_end_t = MPI_Wtime();
R["diff_to_R"] = ai_end_t - ai_start_steps;
ai_start_steps = MPI_Wtime();
// Apply preprocessing
MSG("AI Preprocessing");
R.parseEval("predictors_scaled <- preprocess(predictors)");
ai_end_t = MPI_Wtime();
R["R_preprocessing"] = ai_end_t - ai_start_steps;
ai_start_steps = MPI_Wtime();
std::vector<std::vector<double>> x = R["predictors_scaled"];
ai_end_t = MPI_Wtime();
R["R_preprocessed_to_cxx"] = ai_end_t - ai_start_steps;
ai_start_steps = MPI_Wtime();
MSG("AI: Predict");
if (params.use_Keras_predictions) { // Predict with Keras default function
R["TMP"] = Python_Keras_predict(R["predictors_scaled"], params.batch_size);
@ -385,50 +369,19 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
R["TMP"] = Eigen_predict(Eigen_model, R["predictors_scaled"], params.batch_size, &Eigen_model_mutex);
}
ai_end_t = MPI_Wtime();
R["cxx_inference"] = ai_end_t - ai_start_steps;
ai_start_steps = MPI_Wtime();
R["xyz_THROWAWAY"] = x;
ai_end_t = MPI_Wtime();
std::cout << "C++ predictions back to R: " << ai_end_t - ai_start_steps << std::endl;
R["cxx_predictions_to_R"] = ai_end_t - ai_start_steps;
ai_start_steps = MPI_Wtime();
// Apply postprocessing
MSG("AI: Postprocesing");
R.parseEval(std::string("predictions_scaled <- ") +
"set_field(TMP, ai_surrogate_species, field_nrow, ai_surrogate_species, byrow = TRUE)");
R.parseEval("predictions <- postprocess(predictions_scaled)");
ai_end_t = MPI_Wtime();
R["R_postprocessing"] = ai_end_t - ai_start_steps;
ai_start_steps = MPI_Wtime();
// Validate prediction and write valid predictions to chem field
MSG("AI: Validate");
R.parseEval("validity_vector <- validate_predictions(predictors, predictions)");
ai_end_t = MPI_Wtime();
R["R_validate"] = ai_end_t - ai_start_steps;
ai_start_steps = MPI_Wtime();
MSG("AI: Marking valid");
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
ai_end_t = MPI_Wtime();
R["validity_to_cxx"] = ai_end_t - ai_start_steps;
ai_start_steps = MPI_Wtime();
std::vector<std::vector<double>> RTempField =
R.parseEval("set_valid_predictions(predictors, predictions, validity_vector)");
@ -439,14 +392,8 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
MSG("AI: Update field with AI predictions");
chem.getField().update(predictions_field);
ai_end_t = MPI_Wtime();
R["update_field"] = ai_end_t - ai_start_steps;
ai_start_steps = MPI_Wtime();
// store time for output file
// double ai_end_t = MPI_Wtime();
ai_end_t = MPI_Wtime();
double ai_end_t = MPI_Wtime();
R["ai_prediction_time"] = ai_end_t - ai_start_t;
if (!params.disable_training) {
@ -459,9 +406,6 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
training_data_buffer_append(training_data_buffer.x, invalid_x);
training_data_buffer_mutex.unlock();
}
ai_end_t = MPI_Wtime();
R["append_to_training_buffer"] = ai_end_t - ai_start_steps;
}
// Run simulation step
@ -499,7 +443,6 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
R["n_training_runs"] = training_data_buffer.n_training_runs;
}
diffusion.getField().update(chem.getField());
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" +