mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-16 04:48:23 +01:00
refactor: pre/postprocess as separate functions
This commit is contained in:
parent
f746a566cc
commit
997ae32092
16
README.md
16
README.md
@ -252,13 +252,15 @@ R input script.
|
||||
The following variables and functions must be declared:
|
||||
- `model_file_path` [*string*]: Path to the Keras model file with which
|
||||
the AI surrogate model is initialized.
|
||||
- `validate_predictions(predictors, prediction)` [*function*]: Returns a boolean
|
||||
vector of length `nrow(predictions)`. The output of this function defines
|
||||
which predictions are considered valid and which are rejected. Regular
|
||||
simulation will only be done for the rejected values, and the results
|
||||
will be added to the training data buffer of the AI surrogate model.
|
||||
Can eg. be implemented as a mass balance threshold between the predictors
|
||||
and the prediction.
|
||||
- `validate_predictions(predictors, prediction)` [*function*]: Returns a
|
||||
boolean vector of length `nrow(predictions)`. The output of this function
|
||||
defines which predictions are considered valid and which are rejected.
|
||||
the predictors and predictions are passed in their original original (not
|
||||
transformed) scale. Regular simulation will only be done for the rejected
|
||||
values. The input data of the rejected rows and the respective true results
|
||||
from simulation will be added to the training data buffer of the AI surrogate
|
||||
model. Can eg. be implemented as a mass balance threshold between the
|
||||
predictors and the prediction.
|
||||
|
||||
|
||||
The following variables and functions can be declared:
|
||||
|
||||
@ -79,16 +79,7 @@ master_iteration_end <- function(setup, state_T, state_C) {
|
||||
prediction_time = if (exists("ai_prediction_time")) ai_prediction_time else NULL,
|
||||
predictions_validity = if (exists("validity_vector")) validity_vector else NULL,
|
||||
predictions = if (exists("predictions")) predictions else NULL,
|
||||
n_training_runs = if(exists("n_training_runs")) n_training_runs else NULL,
|
||||
diff_to_R = diff_to_R,
|
||||
R_preprocessing = R_preprocessing,
|
||||
R_preprocessed_to_cxx = R_preprocessed_to_cxx,
|
||||
cxx_inference = cxx_inference,
|
||||
cxx_predictions_to_R = cxx_predictions_to_R,
|
||||
R_postprocessing = R_postprocessing,
|
||||
R_validate = R_validate,
|
||||
validity_to_cxx = validity_to_cxx,
|
||||
append_to_training_buffer = append_to_training_buffer
|
||||
n_training_runs = if(exists("n_training_runs")) n_training_runs else NULL
|
||||
)
|
||||
|
||||
SaveRObj(x = list(
|
||||
|
||||
59
src/poet.cpp
59
src/poet.cpp
@ -352,31 +352,15 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
|
||||
if (params.use_ai_surrogate) {
|
||||
double ai_start_t = MPI_Wtime();
|
||||
double ai_start_steps = MPI_Wtime();
|
||||
// Get current values from the tug field for the ai predictions
|
||||
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
|
||||
R.parseEval(std::string("predictors <- ") +
|
||||
"set_field(TMP, TMP_PROPS, field_nrow, ai_surrogate_species)");
|
||||
|
||||
|
||||
double ai_end_t = MPI_Wtime();
|
||||
R["diff_to_R"] = ai_end_t - ai_start_steps;
|
||||
ai_start_steps = MPI_Wtime();
|
||||
|
||||
// Apply preprocessing
|
||||
MSG("AI Preprocessing");
|
||||
R.parseEval("predictors_scaled <- preprocess(predictors)");
|
||||
|
||||
ai_end_t = MPI_Wtime();
|
||||
R["R_preprocessing"] = ai_end_t - ai_start_steps;
|
||||
ai_start_steps = MPI_Wtime();
|
||||
|
||||
|
||||
std::vector<std::vector<double>> x = R["predictors_scaled"];
|
||||
ai_end_t = MPI_Wtime();
|
||||
R["R_preprocessed_to_cxx"] = ai_end_t - ai_start_steps;
|
||||
ai_start_steps = MPI_Wtime();
|
||||
|
||||
MSG("AI: Predict");
|
||||
if (params.use_Keras_predictions) { // Predict with Keras default function
|
||||
R["TMP"] = Python_Keras_predict(R["predictors_scaled"], params.batch_size);
|
||||
@ -385,50 +369,19 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
R["TMP"] = Eigen_predict(Eigen_model, R["predictors_scaled"], params.batch_size, &Eigen_model_mutex);
|
||||
}
|
||||
|
||||
ai_end_t = MPI_Wtime();
|
||||
R["cxx_inference"] = ai_end_t - ai_start_steps;
|
||||
ai_start_steps = MPI_Wtime();
|
||||
|
||||
R["xyz_THROWAWAY"] = x;
|
||||
ai_end_t = MPI_Wtime();
|
||||
std::cout << "C++ predictions back to R: " << ai_end_t - ai_start_steps << std::endl;
|
||||
R["cxx_predictions_to_R"] = ai_end_t - ai_start_steps;
|
||||
ai_start_steps = MPI_Wtime();
|
||||
|
||||
// Apply postprocessing
|
||||
MSG("AI: Postprocesing");
|
||||
R.parseEval(std::string("predictions_scaled <- ") +
|
||||
"set_field(TMP, ai_surrogate_species, field_nrow, ai_surrogate_species, byrow = TRUE)");
|
||||
R.parseEval("predictions <- postprocess(predictions_scaled)");
|
||||
|
||||
ai_end_t = MPI_Wtime();
|
||||
R["R_postprocessing"] = ai_end_t - ai_start_steps;
|
||||
ai_start_steps = MPI_Wtime();
|
||||
|
||||
// Validate prediction and write valid predictions to chem field
|
||||
MSG("AI: Validate");
|
||||
R.parseEval("validity_vector <- validate_predictions(predictors, predictions)");
|
||||
|
||||
|
||||
|
||||
|
||||
ai_end_t = MPI_Wtime();
|
||||
R["R_validate"] = ai_end_t - ai_start_steps;
|
||||
ai_start_steps = MPI_Wtime();
|
||||
|
||||
|
||||
|
||||
|
||||
MSG("AI: Marking valid");
|
||||
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
|
||||
|
||||
|
||||
|
||||
ai_end_t = MPI_Wtime();
|
||||
R["validity_to_cxx"] = ai_end_t - ai_start_steps;
|
||||
ai_start_steps = MPI_Wtime();
|
||||
|
||||
|
||||
std::vector<std::vector<double>> RTempField =
|
||||
R.parseEval("set_valid_predictions(predictors, predictions, validity_vector)");
|
||||
|
||||
@ -439,14 +392,8 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
MSG("AI: Update field with AI predictions");
|
||||
chem.getField().update(predictions_field);
|
||||
|
||||
|
||||
ai_end_t = MPI_Wtime();
|
||||
R["update_field"] = ai_end_t - ai_start_steps;
|
||||
ai_start_steps = MPI_Wtime();
|
||||
|
||||
// store time for output file
|
||||
// double ai_end_t = MPI_Wtime();
|
||||
ai_end_t = MPI_Wtime();
|
||||
double ai_end_t = MPI_Wtime();
|
||||
R["ai_prediction_time"] = ai_end_t - ai_start_t;
|
||||
|
||||
if (!params.disable_training) {
|
||||
@ -459,9 +406,6 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
training_data_buffer_append(training_data_buffer.x, invalid_x);
|
||||
training_data_buffer_mutex.unlock();
|
||||
}
|
||||
|
||||
ai_end_t = MPI_Wtime();
|
||||
R["append_to_training_buffer"] = ai_end_t - ai_start_steps;
|
||||
}
|
||||
|
||||
// Run simulation step
|
||||
@ -499,7 +443,6 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
R["n_training_runs"] = training_data_buffer.n_training_runs;
|
||||
}
|
||||
|
||||
|
||||
diffusion.getField().update(chem.getField());
|
||||
|
||||
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" +
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user