add copy logic for non-reactive regions

Co-authored-by: Max Lübke <mluebke@uni-potsdam.de>
This commit is contained in:
Hannes Signer 2025-12-08 19:06:00 +01:00
parent 9ab84c3181
commit 655a9b493d
6 changed files with 100 additions and 48 deletions

View File

@ -37,6 +37,7 @@ ai_surrogate_species_input = c("H", "O", "Ba", "Cl", "S", "Sr", "Barite_kin", "C
ai_surrogate_species_output = c("O", "Ba", "S", "Sr", "Barite_kin", "Celestite_kin")
threshold <- list(species = "Cl", value = 1E-10)
preprocess <- function(df) {
if (!is.data.frame(df))

View File

@ -89,6 +89,7 @@ public:
std::uint32_t interp_size_mb;
std::uint32_t interp_min_entries;
bool ai_surrogate_enabled;
bool copy_non_reactive;
};
void masterEnableSurrogates(const SurrogateSetup &setup) {
@ -99,6 +100,7 @@ public:
this->dht_enabled = setup.dht_enabled;
this->interp_enabled = setup.interp_enabled;
this->ai_surrogate_enabled = setup.ai_surrogate_enabled;
this->copy_non_reactive = setup.copy_non_reactive;
this->base_totals = setup.base_totals;
@ -372,6 +374,7 @@ protected:
std::unique_ptr<poet::InterpolationModule> interp;
bool ai_surrogate_enabled{false};
bool copy_non_reactive{false};
static constexpr uint32_t BUFFER_OFFSET = 5;

View File

@ -262,10 +262,11 @@ inline void poet::ChemistryModule::MasterSendPkgs(
// current time of simulation (age) in seconds
send_buffer[end_of_wp + 3] = this->simtime;
// current work package start location in field
uint32_t wp_start_index = std::accumulate(wp_sizes_vector.begin(), std::next(wp_sizes_vector.begin(), count_pkgs), 0);
uint32_t wp_start_index =
std::accumulate(wp_sizes_vector.begin(),
std::next(wp_sizes_vector.begin(), count_pkgs), 0);
send_buffer[end_of_wp + 4] = wp_start_index;
/* ATTENTION Worker p has rank p+1 */
// MPI_Send(send_buffer, end_of_wp + BUFFER_OFFSET, MPI_DOUBLE, p + 1,
// LOOP_WORK, this->group_comm);
@ -373,14 +374,15 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
const std::vector<uint32_t> wp_sizes_vector =
CalculateWPSizesVector(this->n_cells, this->wp_size);
if (this->ai_surrogate_enabled) {
if (this->ai_surrogate_enabled || this->copy_non_reactive) {
ftype = CHEM_AI_BCAST_VALIDITY;
PropagateFunctionType(ftype);
this->ai_surrogate_validity_vector = shuffleVector(this->ai_surrogate_validity_vector,
this->n_cells,
wp_sizes_vector.size());
ChemBCast(&this->ai_surrogate_validity_vector.front(), this->n_cells, MPI_INT);
}
this->ai_surrogate_validity_vector =
shuffleVector(this->ai_surrogate_validity_vector, this->n_cells,
wp_sizes_vector.size());
ChemBCast(&this->ai_surrogate_validity_vector.front(), this->n_cells,
MPI_INT);
}
ftype = CHEM_WORK_LOOP;
PropagateFunctionType(ftype);

View File

@ -46,7 +46,7 @@ void poet::ChemistryModule::WorkerLoop() {
switch (func_type) {
case CHEM_FIELD_INIT: {
ChemBCast(&this->prop_count, 1, MPI_UINT32_T);
if (this->ai_surrogate_enabled) {
if (this->ai_surrogate_enabled || this->copy_non_reactive) {
this->ai_surrogate_validity_vector.resize(
this->n_cells); // resize statt reserve?
}
@ -179,7 +179,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
interp->tryInterpolation(s_curr_wp);
}
if (this->ai_surrogate_enabled) {
if (this->ai_surrogate_enabled || this->copy_non_reactive) {
// Map valid predictions from the ai surrogate in the workpackage
for (int i = 0; i < s_curr_wp.size; i++) {
if (this->ai_surrogate_validity_vector[wp_start_index + i] == 1) {

View File

@ -36,6 +36,7 @@
#include <Rcpp/vector/instantiation.h>
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <memory>
@ -162,6 +163,9 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
->check(CLI::PositiveNumber)
->default_val(RuntimeParameters::AI_BACKEND_DEFAULT);
app.add_flag("-c,--copy-non-reactive", params.copy_non_reactive_regions,
"Copy non-reactive regions instead of computing them");
app.add_flag("--rds", params.as_rds,
"Save output as .rds file instead of default .qs2");
@ -322,11 +326,12 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
// initialzie training backens only if retraining is desired
if (params.ai_backend == PYTHON_BACKEND) {
MSG("AI Surrogate with Python/keras backend enabled.")
// auto model = Python<ai_type_t>();
ai_ctx->training_backend =
std::make_unique<PythonBackend<ai_type_t>>(4 * params.batch_size);
} else if (params.ai_backend == NAA_BACKEND) {
MSG("AI Surrogate with NAA backend enabled.")
ai_ctx->training_backend =
std::make_unique<NAABackend<ai_type_t>>(20 * params.batch_size);
std::make_unique<NAABackend<ai_type_t>>(4 * params.batch_size);
}
if (!params.disable_retraining) {
@ -356,27 +361,49 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
/* run transport */
diffusion.simulate(dt);
chem.getField().update(diffusion.getField());
if (params.ai || params.copy_non_reactive_regions) {
chem.getField().update(diffusion.getField());
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
R.parseEval(
std::string("field <- setNames(data.frame(matrix(TMP, nrow=" +
std::to_string(chem.getField().GetRequestedVecSize()) +
")), TMP_PROPS)"));
R.parseEval("validity_vector <- rep(FALSE, nrow(field))");
if (params.copy_non_reactive_regions) {
R.parseEval("validity_vector <- field$Cl < 1e-14");
}
}
// MSG("Chemistry start");
if (params.ai) {
double ai_start_t = MPI_Wtime();
// Save current values from the tug field as predictor for the ai step
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
R.parseEval(
std::string("predictors <- setNames(data.frame(matrix(TMP, nrow=" +
std::to_string(chem.getField().GetRequestedVecSize()) +
")), TMP_PROPS)"));
// deep copy field
R.parseEval("predictors <- data.frame(field)");
// get only ai related species
R.parseEval("predictors <- predictors[ai_surrogate_species_input]");
// remove already copied values
R.parseEval("predictors <- predictors[!validity_vector,]");
R.parseEval(
"print(paste('Length of predictors:', length(predictors$H)))");
// store row names of predictors
R.parseEval("predictor_idx <- row.names(predictors)");
R.parseEval("print(head(predictors))");
R.parseEval("predictors_scaled <- preprocess(predictors)");
std::vector<std::vector<float>> predictors_scaled =
R["predictors_scaled"];
// FIXME: double/float conversion
std::vector<float> predictions_scaled = ai_ctx->model.predict(
predictors_scaled, params.batch_size, ai_ctx->model_semaphore);
std::vector<float> predictions_scaled =
ai_ctx->model.predict(predictors_scaled, params.batch_size,
ai_ctx->model_semaphore); // features per cell
int n_samples = R.parseEval("nrow(predictors)");
int n_output_features = ai_ctx->model.weight_matrices.back().cols();
@ -396,38 +423,50 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
MSG("AI Validation");
// FIXME: (mass balance plausible?)
R.parseEval("validity_vector <- validate_predictions(predictors, "
R.parseEval("ai_validity_vector <- validate_predictions(predictors, "
"predictions) ");
R.parseEval("print(length(predictor_idx))");
R.parseEval("print(length(ai_validity_vector))");
// get only indices where prediction was valid
R.parseEval("predictor_idx <- predictor_idx[ai_validity_vector]");
// set in global validity vector all elements to true, where prediction
// was possible
R.parseEval("validity_vector[predictor_idx] <- TRUE");
R.parseEval("print(head(validity_vector))");
MSG("AI Marking accepted");
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
MSG("AI TempField");
R.parseEval("print(ai_surrogate_species_output)");
// R.parseEval("print(head(predictors))");
std::vector<std::vector<double>> RTempField = R.parseEval(
"set_valid_predictions(predictors[ai_surrogate_species_output],\
predictions,\
validity_vector)");
// maybe row.names was overwritten by function calls ??
R.parseEval("row.names(predictions) <- row.names(predictors)");
// subset predictions to ai_validity_vector == TRUE
R.parseEval("predictions <- predictions[ai_validity_vector,]");
// merge predicted values into field stored in R
R.parseEval("field[row.names(predictions),ai_surrogate_species_output] "
"<- predictions");
MSG("AI Set Field");
Field predictions_field = Field(
R.parseEval("nrow(predictors)"), RTempField,
R.parseEval(
"colnames(predictors[ai_surrogate_species_output])")); // FIXME:
// is this
// correct?
R.parseEval("nrow(field)"),
Rcpp::as<std::vector<std::vector<double>>>(R.parseEval("field")),
R.parseEval("colnames(field)"));
MSG("AI Update");
chem.getField().update(predictions_field);
double ai_end_t = MPI_Wtime();
R["ai_prediction_time"] = ai_end_t - ai_start_t;
}
if (params.copy_non_reactive_regions || params.ai) {
MSG("Set copied or predicted values for the workers");
R.parseEval(
"print(paste('Number of valid cells:', sum(validity_vector)))");
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
}
chem.simulate(dt);
/* AI surrogate iterative training*/
@ -441,10 +480,15 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
")), TMP_PROPS)"));
R.parseEval("predictors_retraining <- "
"get_invalid_values(predictors_scaled, validity_vector)");
"get_invalid_values(predictors_scaled, ai_validity_vector)");
R.parseEval("print(head(predictors_retraining))");
R.parseEval("targets <- targets[predictor_idx, ]");
R.parseEval("targets_retraining <- "
"get_invalid_values(targets[ai_surrogate_species_output], "
"validity_vector)");
"ai_validity_vector)");
R.parseEval("print(length(predictors_scaled$H))");
R.parseEval("print(length(ai_validity_vector))");
R.parseEval("targets_retraining <- preprocess(targets_retraining)");
std::vector<std::vector<float>> predictors_retraining =
@ -476,15 +520,12 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
std::cout << "results_buffer_size: " << elements_results_buffer
<< std::endl;
if (elements_design_buffer >= 20 * params.batch_size &&
if (elements_design_buffer >=
20 * params.batch_size && // TODO: change to 4 * grid_size
elements_results_buffer >= 20 * params.batch_size &&
ai_ctx->training_is_running == false) {
ai_ctx->data_semaphore_read.release();
} else if (ai_ctx->training_is_running == true) {
MSG("Training is currently running");
ai_ctx->data_semaphore_write.release();
} else {
MSG("Not enough data for retraining");
ai_ctx->data_semaphore_write.release();
}
@ -707,7 +748,8 @@ int main(int argc, char *argv[]) {
run_params.interp_bucket_entries,
run_params.interp_size,
run_params.interp_min_entries,
run_params.ai};
run_params.ai,
run_params.copy_non_reactive_regions};
chemistry.masterEnableSurrogates(surr_setup);

View File

@ -25,6 +25,7 @@
#include <atomic>
#include <cstdint>
#include <string>
#include <type_traits>
#include <vector>
#include <MetaParameter.hpp>
@ -83,9 +84,12 @@ struct RuntimeParameters {
bool ai = false;
bool disable_retraining = false;
static constexpr std::uint8_t AI_BACKEND_DEFAULT = 1;
std::uint8_t ai_backend = 1; // 1 - python, 2 - naa, 3 - cuda
std::uint8_t ai_backend = 1; // 1 - python, 2 - naa
bool train_only_invalid = true;
int batch_size = 1000;
static constexpr bool COPY_NON_REACTIVE_REGIONS = false;
bool copy_non_reactive_regions = COPY_NON_REACTIVE_REGIONS;
};
struct AIContext {