mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-16 04:48:23 +01:00
MDL: some fixes and some more output to make AI run
This commit is contained in:
parent
a526543cb7
commit
6238335b4d
@ -5,8 +5,8 @@
|
||||
## in the variable "ai_surrogate_input_script". See the barite_200.R file as an
|
||||
## example and the general README for more information.
|
||||
|
||||
library(keras)
|
||||
library(tensorflow)
|
||||
## library(keras3)
|
||||
## library(tensorflow)
|
||||
|
||||
initiate_model <- function() {
|
||||
hidden_layers <- c(48, 96, 24)
|
||||
@ -54,6 +54,10 @@ preprocess <- function(df, backtransform = FALSE, outputs = FALSE) {
|
||||
return(df)
|
||||
}
|
||||
|
||||
postprocess <- function(df, backtransform = TRUE, outputs = TRUE) {
|
||||
return(df)
|
||||
}
|
||||
|
||||
set_valid_predictions <- function(temp_field, prediction, validity) {
|
||||
temp_field[validity == 1, ] <- prediction[validity == 1, ]
|
||||
return(temp_field)
|
||||
|
||||
29
src/poet.cpp
29
src/poet.cpp
@ -286,25 +286,36 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
std::to_string(chem.getField().GetRequestedVecSize()) + ")), TMP_PROPS)"));
|
||||
R.parseEval("predictors <- predictors[ai_surrogate_species]");
|
||||
|
||||
// Predict
|
||||
// Apply preprocessing
|
||||
MSG("AI Preprocessing");
|
||||
R.parseEval("predictors_scaled <- preprocess(predictors)");
|
||||
|
||||
R.parseEval("prediction <- preprocess(prediction_step(model, predictors_scaled),\
|
||||
backtransform = TRUE,\
|
||||
outputs = TRUE)");
|
||||
// Predict
|
||||
MSG("AI Predict");
|
||||
R.parseEval("aipreds_scaled <- prediction_step(model, predictors_scaled)");
|
||||
|
||||
// Apply postprocessing
|
||||
MSG("AI Postprocesing");
|
||||
R.parseEval("aipreds <- postprocess(aipreds_scaled)");
|
||||
|
||||
// Validate prediction and write valid predictions to chem field
|
||||
R.parseEval("validity_vector <- validate_predictions(predictors,\
|
||||
prediction)");
|
||||
MSG("AI Validate");
|
||||
R.parseEval("validity_vector <- validate_predictions(predictors, aipreds)");
|
||||
|
||||
MSG("AI Marking accepted");
|
||||
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
|
||||
|
||||
MSG("AI TempField");
|
||||
std::vector<std::vector<double>> RTempField = R.parseEval("set_valid_predictions(predictors,\
|
||||
prediction,\
|
||||
aipreds,\
|
||||
validity_vector)");
|
||||
|
||||
MSG("AI Set Field");
|
||||
Field predictions_field = Field(R.parseEval("nrow(predictors)"),
|
||||
RTempField,
|
||||
R.parseEval("names(predictors)"));
|
||||
R.parseEval("colnames(predictors)"));
|
||||
|
||||
MSG("AI Update");
|
||||
chem.getField().update(predictions_field);
|
||||
double ai_end_t = MPI_Wtime();
|
||||
R["ai_prediction_time"] = ai_end_t - ai_start_t;
|
||||
@ -323,7 +334,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
||||
R.parseEval("targets <- targets[ai_surrogate_species]");
|
||||
|
||||
// TODO: Check how to get the correct columns
|
||||
R.parseEval("target_scaled <- preprocess(targets, outputs = TRUE)");
|
||||
R.parseEval("target_scaled <- preprocess(targets)");
|
||||
|
||||
R.parseEval("training_step(model, predictors_scaled, target_scaled, validity_vector)");
|
||||
double ai_end_t = MPI_Wtime();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user