Merge branch 'naa-naaice' of git.gfz-potsdam.de:naaice/poet into naa-naaice

This commit is contained in:
Marco De Lucia 2025-12-09 14:39:28 +01:00
commit 80c51a14ae
9 changed files with 104 additions and 52 deletions

View File

@ -34,7 +34,7 @@ add_subdirectory(ext/tug EXCLUDE_FROM_ALL)
add_subdirectory(ext/iphreeqc EXCLUDE_FROM_ALL) add_subdirectory(ext/iphreeqc EXCLUDE_FROM_ALL)
# AI/NAA specific includes TODO: add option flags # AI/NAA specific includes TODO: add option flags
add_subdirectory(ext/ai-surrogate EXCLUDE_FROM_ALL) add_subdirectory(ext/ai-surrogate-poet EXCLUDE_FROM_ALL)

View File

@ -37,6 +37,7 @@ ai_surrogate_species_input = c("H", "O", "Ba", "Cl", "S", "Sr", "Barite_kin", "C
ai_surrogate_species_output = c("O", "Ba", "S", "Sr", "Barite_kin", "Celestite_kin") ai_surrogate_species_output = c("O", "Ba", "S", "Sr", "Barite_kin", "Celestite_kin")
threshold <- list(species = "Cl", value = 1E-10)
preprocess <- function(df) { preprocess <- function(df) {
if (!is.data.frame(df)) if (!is.data.frame(df))

@ -1 +1 @@
Subproject commit 112c8ff1a88f47a73909724e31227173fd50126a Subproject commit 2dd2b8881d6fe27b08a259d48ee8bca6188f049a

View File

@ -89,9 +89,9 @@ file(READ "${PROJECT_SOURCE_DIR}/R_lib/init_r_lib.R" R_INIT_LIB)
file(READ "${PROJECT_SOURCE_DIR}/R_lib/ai_surrogate_model.R" R_AI_SURROGATE_LIB) file(READ "${PROJECT_SOURCE_DIR}/R_lib/ai_surrogate_model.R" R_AI_SURROGATE_LIB)
configure_file(poet.hpp.in poet.hpp @ONLY) configure_file(poet.hpp.in poet.hpp @ONLY)
include_directories("${CMAKE_SOURCE_DIR}/ext/ai-surrogate-poet/include")
add_executable(poet poet.cpp) add_executable(poet poet.cpp)
target_link_libraries(poet PRIVATE POETLib MPI::MPI_C RRuntime CLI11::CLI11) target_link_libraries(poet PRIVATE POETLib MPI::MPI_C RRuntime CLI11::CLI11 ai naaice::middleware)
target_include_directories(poet PRIVATE "${CMAKE_CURRENT_BINARY_DIR}") target_include_directories(poet PRIVATE "${CMAKE_CURRENT_BINARY_DIR}")
add_executable(poet_init initializer.cpp) add_executable(poet_init initializer.cpp)

View File

@ -89,6 +89,7 @@ public:
std::uint32_t interp_size_mb; std::uint32_t interp_size_mb;
std::uint32_t interp_min_entries; std::uint32_t interp_min_entries;
bool ai_surrogate_enabled; bool ai_surrogate_enabled;
bool copy_non_reactive;
}; };
void masterEnableSurrogates(const SurrogateSetup &setup) { void masterEnableSurrogates(const SurrogateSetup &setup) {
@ -99,6 +100,7 @@ public:
this->dht_enabled = setup.dht_enabled; this->dht_enabled = setup.dht_enabled;
this->interp_enabled = setup.interp_enabled; this->interp_enabled = setup.interp_enabled;
this->ai_surrogate_enabled = setup.ai_surrogate_enabled; this->ai_surrogate_enabled = setup.ai_surrogate_enabled;
this->copy_non_reactive = setup.copy_non_reactive;
this->base_totals = setup.base_totals; this->base_totals = setup.base_totals;
@ -372,6 +374,7 @@ protected:
std::unique_ptr<poet::InterpolationModule> interp; std::unique_ptr<poet::InterpolationModule> interp;
bool ai_surrogate_enabled{false}; bool ai_surrogate_enabled{false};
bool copy_non_reactive{false};
static constexpr uint32_t BUFFER_OFFSET = 5; static constexpr uint32_t BUFFER_OFFSET = 5;

View File

@ -262,10 +262,11 @@ inline void poet::ChemistryModule::MasterSendPkgs(
// current time of simulation (age) in seconds // current time of simulation (age) in seconds
send_buffer[end_of_wp + 3] = this->simtime; send_buffer[end_of_wp + 3] = this->simtime;
// current work package start location in field // current work package start location in field
uint32_t wp_start_index = std::accumulate(wp_sizes_vector.begin(), std::next(wp_sizes_vector.begin(), count_pkgs), 0); uint32_t wp_start_index =
std::accumulate(wp_sizes_vector.begin(),
std::next(wp_sizes_vector.begin(), count_pkgs), 0);
send_buffer[end_of_wp + 4] = wp_start_index; send_buffer[end_of_wp + 4] = wp_start_index;
/* ATTENTION Worker p has rank p+1 */ /* ATTENTION Worker p has rank p+1 */
// MPI_Send(send_buffer, end_of_wp + BUFFER_OFFSET, MPI_DOUBLE, p + 1, // MPI_Send(send_buffer, end_of_wp + BUFFER_OFFSET, MPI_DOUBLE, p + 1,
// LOOP_WORK, this->group_comm); // LOOP_WORK, this->group_comm);
@ -373,14 +374,15 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
const std::vector<uint32_t> wp_sizes_vector = const std::vector<uint32_t> wp_sizes_vector =
CalculateWPSizesVector(this->n_cells, this->wp_size); CalculateWPSizesVector(this->n_cells, this->wp_size);
if (this->ai_surrogate_enabled) { if (this->ai_surrogate_enabled || this->copy_non_reactive) {
ftype = CHEM_AI_BCAST_VALIDITY; ftype = CHEM_AI_BCAST_VALIDITY;
PropagateFunctionType(ftype); PropagateFunctionType(ftype);
this->ai_surrogate_validity_vector = shuffleVector(this->ai_surrogate_validity_vector, this->ai_surrogate_validity_vector =
this->n_cells, shuffleVector(this->ai_surrogate_validity_vector, this->n_cells,
wp_sizes_vector.size()); wp_sizes_vector.size());
ChemBCast(&this->ai_surrogate_validity_vector.front(), this->n_cells, MPI_INT); ChemBCast(&this->ai_surrogate_validity_vector.front(), this->n_cells,
} MPI_INT);
}
ftype = CHEM_WORK_LOOP; ftype = CHEM_WORK_LOOP;
PropagateFunctionType(ftype); PropagateFunctionType(ftype);

View File

@ -46,7 +46,7 @@ void poet::ChemistryModule::WorkerLoop() {
switch (func_type) { switch (func_type) {
case CHEM_FIELD_INIT: { case CHEM_FIELD_INIT: {
ChemBCast(&this->prop_count, 1, MPI_UINT32_T); ChemBCast(&this->prop_count, 1, MPI_UINT32_T);
if (this->ai_surrogate_enabled) { if (this->ai_surrogate_enabled || this->copy_non_reactive) {
this->ai_surrogate_validity_vector.resize( this->ai_surrogate_validity_vector.resize(
this->n_cells); // resize statt reserve? this->n_cells); // resize statt reserve?
} }
@ -179,7 +179,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
interp->tryInterpolation(s_curr_wp); interp->tryInterpolation(s_curr_wp);
} }
if (this->ai_surrogate_enabled) { if (this->ai_surrogate_enabled || this->copy_non_reactive) {
// Map valid predictions from the ai surrogate in the workpackage // Map valid predictions from the ai surrogate in the workpackage
for (int i = 0; i < s_curr_wp.size; i++) { for (int i = 0; i < s_curr_wp.size; i++) {
if (this->ai_surrogate_validity_vector[wp_start_index + i] == 1) { if (this->ai_surrogate_validity_vector[wp_start_index + i] == 1) {

View File

@ -36,6 +36,7 @@
#include <Rcpp/vector/instantiation.h> #include <Rcpp/vector/instantiation.h>
#include <algorithm> #include <algorithm>
#include <array> #include <array>
#include <cstddef>
#include <cstdint> #include <cstdint>
#include <cstdlib> #include <cstdlib>
#include <memory> #include <memory>
@ -162,6 +163,9 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
->check(CLI::PositiveNumber) ->check(CLI::PositiveNumber)
->default_val(RuntimeParameters::AI_BACKEND_DEFAULT); ->default_val(RuntimeParameters::AI_BACKEND_DEFAULT);
app.add_flag("-c,--copy-non-reactive", params.copy_non_reactive_regions,
"Copy non-reactive regions instead of computing them");
app.add_flag("--rds", params.as_rds, app.add_flag("--rds", params.as_rds,
"Save output as .rds file instead of default .qs2"); "Save output as .rds file instead of default .qs2");
@ -322,11 +326,12 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
// initialzie training backens only if retraining is desired // initialzie training backens only if retraining is desired
if (params.ai_backend == PYTHON_BACKEND) { if (params.ai_backend == PYTHON_BACKEND) {
MSG("AI Surrogate with Python/keras backend enabled.") MSG("AI Surrogate with Python/keras backend enabled.")
// auto model = Python<ai_type_t>(); ai_ctx->training_backend =
std::make_unique<PythonBackend<ai_type_t>>(4 * params.batch_size);
} else if (params.ai_backend == NAA_BACKEND) { } else if (params.ai_backend == NAA_BACKEND) {
MSG("AI Surrogate with NAA backend enabled.") MSG("AI Surrogate with NAA backend enabled.")
ai_ctx->training_backend = ai_ctx->training_backend =
std::make_unique<NAABackend<ai_type_t>>(20 * params.batch_size); std::make_unique<NAABackend<ai_type_t>>(4 * params.batch_size);
} }
if (!params.disable_retraining) { if (!params.disable_retraining) {
@ -356,27 +361,49 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
/* run transport */ /* run transport */
diffusion.simulate(dt); diffusion.simulate(dt);
chem.getField().update(diffusion.getField()); if (params.ai || params.copy_non_reactive_regions) {
chem.getField().update(diffusion.getField());
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
R.parseEval(
std::string("field <- setNames(data.frame(matrix(TMP, nrow=" +
std::to_string(chem.getField().GetRequestedVecSize()) +
")), TMP_PROPS)"));
R.parseEval("validity_vector <- rep(FALSE, nrow(field))");
if (params.copy_non_reactive_regions) {
R.parseEval("validity_vector <- field$Cl < 1e-14");
}
}
// MSG("Chemistry start"); // MSG("Chemistry start");
if (params.ai) { if (params.ai) {
double ai_start_t = MPI_Wtime(); double ai_start_t = MPI_Wtime();
// Save current values from the tug field as predictor for the ai step
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
R.parseEval(
std::string("predictors <- setNames(data.frame(matrix(TMP, nrow=" +
std::to_string(chem.getField().GetRequestedVecSize()) +
")), TMP_PROPS)"));
// deep copy field
R.parseEval("predictors <- data.frame(field)");
// get only ai related species
R.parseEval("predictors <- predictors[ai_surrogate_species_input]"); R.parseEval("predictors <- predictors[ai_surrogate_species_input]");
// remove already copied values
R.parseEval("predictors <- predictors[!validity_vector,]");
R.parseEval(
"print(paste('Length of predictors:', length(predictors$H)))");
// store row names of predictors
R.parseEval("predictor_idx <- row.names(predictors)");
R.parseEval("print(head(predictors))");
R.parseEval("predictors_scaled <- preprocess(predictors)"); R.parseEval("predictors_scaled <- preprocess(predictors)");
std::vector<std::vector<float>> predictors_scaled = std::vector<std::vector<float>> predictors_scaled =
R["predictors_scaled"]; R["predictors_scaled"];
// FIXME: double/float conversion std::vector<float> predictions_scaled =
std::vector<float> predictions_scaled = ai_ctx->model.predict( ai_ctx->model.predict(predictors_scaled, params.batch_size,
predictors_scaled, params.batch_size, ai_ctx->model_semaphore); ai_ctx->model_semaphore); // features per cell
int n_samples = R.parseEval("nrow(predictors)"); int n_samples = R.parseEval("nrow(predictors)");
int n_output_features = ai_ctx->model.weight_matrices.back().cols(); int n_output_features = ai_ctx->model.weight_matrices.back().cols();
@ -396,38 +423,50 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
MSG("AI Validation"); MSG("AI Validation");
// FIXME: (mass balance plausible?) R.parseEval("ai_validity_vector <- validate_predictions(predictors, "
R.parseEval("validity_vector <- validate_predictions(predictors, "
"predictions) "); "predictions) ");
R.parseEval("print(length(predictor_idx))");
R.parseEval("print(length(ai_validity_vector))");
// get only indices where prediction was valid
R.parseEval("predictor_idx <- predictor_idx[ai_validity_vector]");
// set in global validity vector all elements to true, where prediction
// was possible
R.parseEval("validity_vector[predictor_idx] <- TRUE");
R.parseEval("print(head(validity_vector))"); R.parseEval("print(head(validity_vector))");
MSG("AI Marking accepted");
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
MSG("AI TempField"); MSG("AI TempField");
R.parseEval("print(ai_surrogate_species_output)"); // maybe row.names was overwritten by function calls ??
// R.parseEval("print(head(predictors))"); R.parseEval("row.names(predictions) <- row.names(predictors)");
std::vector<std::vector<double>> RTempField = R.parseEval( // subset predictions to ai_validity_vector == TRUE
"set_valid_predictions(predictors[ai_surrogate_species_output],\ R.parseEval("predictions <- predictions[ai_validity_vector,]");
predictions,\ // merge predicted values into field stored in R
validity_vector)"); R.parseEval("field[row.names(predictions),ai_surrogate_species_output] "
"<- predictions");
MSG("AI Set Field"); MSG("AI Set Field");
Field predictions_field = Field( Field predictions_field = Field(
R.parseEval("nrow(predictors)"), RTempField, R.parseEval("nrow(field)"),
R.parseEval( Rcpp::as<std::vector<std::vector<double>>>(R.parseEval("field")),
"colnames(predictors[ai_surrogate_species_output])")); // FIXME: R.parseEval("colnames(field)"));
// is this
// correct?
MSG("AI Update");
chem.getField().update(predictions_field); chem.getField().update(predictions_field);
double ai_end_t = MPI_Wtime(); double ai_end_t = MPI_Wtime();
R["ai_prediction_time"] = ai_end_t - ai_start_t; R["ai_prediction_time"] = ai_end_t - ai_start_t;
} }
if (params.copy_non_reactive_regions || params.ai) {
MSG("Set copied or predicted values for the workers");
R.parseEval(
"print(paste('Number of valid cells:', sum(validity_vector)))");
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
}
chem.simulate(dt); chem.simulate(dt);
/* AI surrogate iterative training*/ /* AI surrogate iterative training*/
@ -441,10 +480,15 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
")), TMP_PROPS)")); ")), TMP_PROPS)"));
R.parseEval("predictors_retraining <- " R.parseEval("predictors_retraining <- "
"get_invalid_values(predictors_scaled, validity_vector)"); "get_invalid_values(predictors_scaled, ai_validity_vector)");
R.parseEval("print(head(predictors_retraining))");
R.parseEval("targets <- targets[predictor_idx, ]");
R.parseEval("targets_retraining <- " R.parseEval("targets_retraining <- "
"get_invalid_values(targets[ai_surrogate_species_output], " "get_invalid_values(targets[ai_surrogate_species_output], "
"validity_vector)"); "ai_validity_vector)");
R.parseEval("print(length(predictors_scaled$H))");
R.parseEval("print(length(ai_validity_vector))");
R.parseEval("targets_retraining <- preprocess(targets_retraining)"); R.parseEval("targets_retraining <- preprocess(targets_retraining)");
std::vector<std::vector<float>> predictors_retraining = std::vector<std::vector<float>> predictors_retraining =
@ -476,15 +520,12 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
std::cout << "results_buffer_size: " << elements_results_buffer std::cout << "results_buffer_size: " << elements_results_buffer
<< std::endl; << std::endl;
if (elements_design_buffer >= 20 * params.batch_size && if (elements_design_buffer >=
20 * params.batch_size && // TODO: change to 4 * grid_size
elements_results_buffer >= 20 * params.batch_size && elements_results_buffer >= 20 * params.batch_size &&
ai_ctx->training_is_running == false) { ai_ctx->training_is_running == false) {
ai_ctx->data_semaphore_read.release(); ai_ctx->data_semaphore_read.release();
} else if (ai_ctx->training_is_running == true) {
MSG("Training is currently running");
ai_ctx->data_semaphore_write.release();
} else { } else {
MSG("Not enough data for retraining");
ai_ctx->data_semaphore_write.release(); ai_ctx->data_semaphore_write.release();
} }
@ -707,7 +748,8 @@ int main(int argc, char *argv[]) {
run_params.interp_bucket_entries, run_params.interp_bucket_entries,
run_params.interp_size, run_params.interp_size,
run_params.interp_min_entries, run_params.interp_min_entries,
run_params.ai}; run_params.ai,
run_params.copy_non_reactive_regions};
chemistry.masterEnableSurrogates(surr_setup); chemistry.masterEnableSurrogates(surr_setup);

View File

@ -25,6 +25,7 @@
#include <atomic> #include <atomic>
#include <cstdint> #include <cstdint>
#include <string> #include <string>
#include <type_traits>
#include <vector> #include <vector>
#include <MetaParameter.hpp> #include <MetaParameter.hpp>
@ -83,9 +84,12 @@ struct RuntimeParameters {
bool ai = false; bool ai = false;
bool disable_retraining = false; bool disable_retraining = false;
static constexpr std::uint8_t AI_BACKEND_DEFAULT = 1; static constexpr std::uint8_t AI_BACKEND_DEFAULT = 1;
std::uint8_t ai_backend = 1; // 1 - python, 2 - naa, 3 - cuda std::uint8_t ai_backend = 1; // 1 - python, 2 - naa
bool train_only_invalid = true; bool train_only_invalid = true;
int batch_size = 1000; int batch_size = 1000;
static constexpr bool COPY_NON_REACTIVE_REGIONS = false;
bool copy_non_reactive_regions = COPY_NON_REACTIVE_REGIONS;
}; };
struct AIContext { struct AIContext {