mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-13 03:18:23 +01:00
Merge branch 'naa-naaice' of git.gfz-potsdam.de:naaice/poet into naa-naaice
This commit is contained in:
commit
80c51a14ae
@ -34,7 +34,7 @@ add_subdirectory(ext/tug EXCLUDE_FROM_ALL)
|
||||
add_subdirectory(ext/iphreeqc EXCLUDE_FROM_ALL)
|
||||
|
||||
# AI/NAA specific includes TODO: add option flags
|
||||
add_subdirectory(ext/ai-surrogate EXCLUDE_FROM_ALL)
|
||||
add_subdirectory(ext/ai-surrogate-poet EXCLUDE_FROM_ALL)
|
||||
|
||||
|
||||
|
||||
|
||||
@ -37,6 +37,7 @@ ai_surrogate_species_input = c("H", "O", "Ba", "Cl", "S", "Sr", "Barite_kin", "C
|
||||
ai_surrogate_species_output = c("O", "Ba", "S", "Sr", "Barite_kin", "Celestite_kin")
|
||||
|
||||
|
||||
threshold <- list(species = "Cl", value = 1E-10)
|
||||
|
||||
preprocess <- function(df) {
|
||||
if (!is.data.frame(df))
|
||||
|
||||
@ -1 +1 @@
|
||||
Subproject commit 112c8ff1a88f47a73909724e31227173fd50126a
|
||||
Subproject commit 2dd2b8881d6fe27b08a259d48ee8bca6188f049a
|
||||
@ -89,9 +89,9 @@ file(READ "${PROJECT_SOURCE_DIR}/R_lib/init_r_lib.R" R_INIT_LIB)
|
||||
file(READ "${PROJECT_SOURCE_DIR}/R_lib/ai_surrogate_model.R" R_AI_SURROGATE_LIB)
|
||||
|
||||
configure_file(poet.hpp.in poet.hpp @ONLY)
|
||||
|
||||
include_directories("${CMAKE_SOURCE_DIR}/ext/ai-surrogate-poet/include")
|
||||
add_executable(poet poet.cpp)
|
||||
target_link_libraries(poet PRIVATE POETLib MPI::MPI_C RRuntime CLI11::CLI11)
|
||||
target_link_libraries(poet PRIVATE POETLib MPI::MPI_C RRuntime CLI11::CLI11 ai naaice::middleware)
|
||||
target_include_directories(poet PRIVATE "${CMAKE_CURRENT_BINARY_DIR}")
|
||||
|
||||
add_executable(poet_init initializer.cpp)
|
||||
|
||||
@ -89,6 +89,7 @@ public:
|
||||
std::uint32_t interp_size_mb;
|
||||
std::uint32_t interp_min_entries;
|
||||
bool ai_surrogate_enabled;
|
||||
bool copy_non_reactive;
|
||||
};
|
||||
|
||||
void masterEnableSurrogates(const SurrogateSetup &setup) {
|
||||
@ -99,6 +100,7 @@ public:
|
||||
this->dht_enabled = setup.dht_enabled;
|
||||
this->interp_enabled = setup.interp_enabled;
|
||||
this->ai_surrogate_enabled = setup.ai_surrogate_enabled;
|
||||
this->copy_non_reactive = setup.copy_non_reactive;
|
||||
|
||||
this->base_totals = setup.base_totals;
|
||||
|
||||
@ -372,6 +374,7 @@ protected:
|
||||
std::unique_ptr<poet::InterpolationModule> interp;
|
||||
|
||||
bool ai_surrogate_enabled{false};
|
||||
bool copy_non_reactive{false};
|
||||
|
||||
static constexpr uint32_t BUFFER_OFFSET = 5;
|
||||
|
||||
|
||||
@ -262,10 +262,11 @@ inline void poet::ChemistryModule::MasterSendPkgs(
|
||||
// current time of simulation (age) in seconds
|
||||
send_buffer[end_of_wp + 3] = this->simtime;
|
||||
// current work package start location in field
|
||||
uint32_t wp_start_index = std::accumulate(wp_sizes_vector.begin(), std::next(wp_sizes_vector.begin(), count_pkgs), 0);
|
||||
uint32_t wp_start_index =
|
||||
std::accumulate(wp_sizes_vector.begin(),
|
||||
std::next(wp_sizes_vector.begin(), count_pkgs), 0);
|
||||
send_buffer[end_of_wp + 4] = wp_start_index;
|
||||
|
||||
|
||||
/* ATTENTION Worker p has rank p+1 */
|
||||
// MPI_Send(send_buffer, end_of_wp + BUFFER_OFFSET, MPI_DOUBLE, p + 1,
|
||||
// LOOP_WORK, this->group_comm);
|
||||
@ -373,14 +374,15 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
||||
const std::vector<uint32_t> wp_sizes_vector =
|
||||
CalculateWPSizesVector(this->n_cells, this->wp_size);
|
||||
|
||||
if (this->ai_surrogate_enabled) {
|
||||
if (this->ai_surrogate_enabled || this->copy_non_reactive) {
|
||||
ftype = CHEM_AI_BCAST_VALIDITY;
|
||||
PropagateFunctionType(ftype);
|
||||
this->ai_surrogate_validity_vector = shuffleVector(this->ai_surrogate_validity_vector,
|
||||
this->n_cells,
|
||||
wp_sizes_vector.size());
|
||||
ChemBCast(&this->ai_surrogate_validity_vector.front(), this->n_cells, MPI_INT);
|
||||
}
|
||||
this->ai_surrogate_validity_vector =
|
||||
shuffleVector(this->ai_surrogate_validity_vector, this->n_cells,
|
||||
wp_sizes_vector.size());
|
||||
ChemBCast(&this->ai_surrogate_validity_vector.front(), this->n_cells,
|
||||
MPI_INT);
|
||||
}
|
||||
|
||||
ftype = CHEM_WORK_LOOP;
|
||||
PropagateFunctionType(ftype);
|
||||
|
||||
@ -46,7 +46,7 @@ void poet::ChemistryModule::WorkerLoop() {
|
||||
switch (func_type) {
|
||||
case CHEM_FIELD_INIT: {
|
||||
ChemBCast(&this->prop_count, 1, MPI_UINT32_T);
|
||||
if (this->ai_surrogate_enabled) {
|
||||
if (this->ai_surrogate_enabled || this->copy_non_reactive) {
|
||||
this->ai_surrogate_validity_vector.resize(
|
||||
this->n_cells); // resize statt reserve?
|
||||
}
|
||||
@ -179,7 +179,7 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||
interp->tryInterpolation(s_curr_wp);
|
||||
}
|
||||
|
||||
if (this->ai_surrogate_enabled) {
|
||||
if (this->ai_surrogate_enabled || this->copy_non_reactive) {
|
||||
// Map valid predictions from the ai surrogate in the workpackage
|
||||
for (int i = 0; i < s_curr_wp.size; i++) {
|
||||
if (this->ai_surrogate_validity_vector[wp_start_index + i] == 1) {
|
||||
|
||||
116
src/poet.cpp
116
src/poet.cpp
@ -36,6 +36,7 @@
|
||||
#include <Rcpp/vector/instantiation.h>
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
@ -162,6 +163,9 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
|
||||
->check(CLI::PositiveNumber)
|
||||
->default_val(RuntimeParameters::AI_BACKEND_DEFAULT);
|
||||
|
||||
app.add_flag("-c,--copy-non-reactive", params.copy_non_reactive_regions,
|
||||
"Copy non-reactive regions instead of computing them");
|
||||
|
||||
app.add_flag("--rds", params.as_rds,
|
||||
"Save output as .rds file instead of default .qs2");
|
||||
|
||||
@ -322,11 +326,12 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
|
||||
// initialize training backends only if retraining is desired
|
||||
if (params.ai_backend == PYTHON_BACKEND) {
|
||||
MSG("AI Surrogate with Python/keras backend enabled.")
|
||||
// auto model = Python<ai_type_t>();
|
||||
ai_ctx->training_backend =
|
||||
std::make_unique<PythonBackend<ai_type_t>>(4 * params.batch_size);
|
||||
} else if (params.ai_backend == NAA_BACKEND) {
|
||||
MSG("AI Surrogate with NAA backend enabled.")
|
||||
ai_ctx->training_backend =
|
||||
std::make_unique<NAABackend<ai_type_t>>(20 * params.batch_size);
|
||||
std::make_unique<NAABackend<ai_type_t>>(4 * params.batch_size);
|
||||
}
|
||||
|
||||
if (!params.disable_retraining) {
|
||||
@ -356,27 +361,49 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
|
||||
/* run transport */
|
||||
diffusion.simulate(dt);
|
||||
|
||||
chem.getField().update(diffusion.getField());
|
||||
if (params.ai || params.copy_non_reactive_regions) {
|
||||
|
||||
chem.getField().update(diffusion.getField());
|
||||
|
||||
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
|
||||
R.parseEval(
|
||||
std::string("field <- setNames(data.frame(matrix(TMP, nrow=" +
|
||||
std::to_string(chem.getField().GetRequestedVecSize()) +
|
||||
")), TMP_PROPS)"));
|
||||
|
||||
R.parseEval("validity_vector <- rep(FALSE, nrow(field))");
|
||||
|
||||
if (params.copy_non_reactive_regions) {
|
||||
R.parseEval("validity_vector <- field$Cl < 1e-14");
|
||||
}
|
||||
}
|
||||
|
||||
// MSG("Chemistry start");
|
||||
if (params.ai) {
|
||||
double ai_start_t = MPI_Wtime();
|
||||
// Save current values from the tug field as predictor for the ai step
|
||||
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
|
||||
R.parseEval(
|
||||
std::string("predictors <- setNames(data.frame(matrix(TMP, nrow=" +
|
||||
std::to_string(chem.getField().GetRequestedVecSize()) +
|
||||
")), TMP_PROPS)"));
|
||||
|
||||
// deep copy field
|
||||
R.parseEval("predictors <- data.frame(field)");
|
||||
// get only ai related species
|
||||
R.parseEval("predictors <- predictors[ai_surrogate_species_input]");
|
||||
|
||||
// remove already copied values
|
||||
R.parseEval("predictors <- predictors[!validity_vector,]");
|
||||
|
||||
R.parseEval(
|
||||
"print(paste('Length of predictors:', length(predictors$H)))");
|
||||
|
||||
// store row names of predictors
|
||||
R.parseEval("predictor_idx <- row.names(predictors)");
|
||||
|
||||
R.parseEval("print(head(predictors))");
|
||||
R.parseEval("predictors_scaled <- preprocess(predictors)");
|
||||
std::vector<std::vector<float>> predictors_scaled =
|
||||
R["predictors_scaled"];
|
||||
|
||||
// FIXME: double/float conversion
|
||||
std::vector<float> predictions_scaled = ai_ctx->model.predict(
|
||||
predictors_scaled, params.batch_size, ai_ctx->model_semaphore);
|
||||
std::vector<float> predictions_scaled =
|
||||
ai_ctx->model.predict(predictors_scaled, params.batch_size,
|
||||
ai_ctx->model_semaphore); // features per cell
|
||||
|
||||
int n_samples = R.parseEval("nrow(predictors)");
|
||||
int n_output_features = ai_ctx->model.weight_matrices.back().cols();
|
||||
@ -396,38 +423,50 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
|
||||
|
||||
MSG("AI Validation");
|
||||
|
||||
// FIXME: (mass balance plausible?)
|
||||
R.parseEval("validity_vector <- validate_predictions(predictors, "
|
||||
R.parseEval("ai_validity_vector <- validate_predictions(predictors, "
|
||||
"predictions) ");
|
||||
|
||||
R.parseEval("print(length(predictor_idx))");
|
||||
R.parseEval("print(length(ai_validity_vector))");
|
||||
|
||||
// get only indices where prediction was valid
|
||||
R.parseEval("predictor_idx <- predictor_idx[ai_validity_vector]");
|
||||
|
||||
// set in global validity vector all elements to true, where prediction
|
||||
// was possible
|
||||
R.parseEval("validity_vector[predictor_idx] <- TRUE");
|
||||
|
||||
R.parseEval("print(head(validity_vector))");
|
||||
|
||||
MSG("AI Marking accepted");
|
||||
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
|
||||
|
||||
MSG("AI TempField");
|
||||
R.parseEval("print(ai_surrogate_species_output)");
|
||||
// R.parseEval("print(head(predictors))");
|
||||
std::vector<std::vector<double>> RTempField = R.parseEval(
|
||||
"set_valid_predictions(predictors[ai_surrogate_species_output],\
|
||||
predictions,\
|
||||
validity_vector)");
|
||||
// maybe row.names was overwritten by function calls ??
|
||||
R.parseEval("row.names(predictions) <- row.names(predictors)");
|
||||
// subset predictions to ai_validity_vector == TRUE
|
||||
R.parseEval("predictions <- predictions[ai_validity_vector,]");
|
||||
// merge predicted values into field stored in R
|
||||
R.parseEval("field[row.names(predictions),ai_surrogate_species_output] "
|
||||
"<- predictions");
|
||||
|
||||
MSG("AI Set Field");
|
||||
Field predictions_field = Field(
|
||||
R.parseEval("nrow(predictors)"), RTempField,
|
||||
R.parseEval(
|
||||
"colnames(predictors[ai_surrogate_species_output])")); // FIXME:
|
||||
// is this
|
||||
// correct?
|
||||
R.parseEval("nrow(field)"),
|
||||
Rcpp::as<std::vector<std::vector<double>>>(R.parseEval("field")),
|
||||
R.parseEval("colnames(field)"));
|
||||
|
||||
MSG("AI Update");
|
||||
chem.getField().update(predictions_field);
|
||||
|
||||
double ai_end_t = MPI_Wtime();
|
||||
R["ai_prediction_time"] = ai_end_t - ai_start_t;
|
||||
}
|
||||
|
||||
if (params.copy_non_reactive_regions || params.ai) {
|
||||
MSG("Set copied or predicted values for the workers");
|
||||
|
||||
R.parseEval(
|
||||
"print(paste('Number of valid cells:', sum(validity_vector)))");
|
||||
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
|
||||
}
|
||||
|
||||
chem.simulate(dt);
|
||||
|
||||
/* AI surrogate iterative training*/
|
||||
@ -441,10 +480,15 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
|
||||
")), TMP_PROPS)"));
|
||||
|
||||
R.parseEval("predictors_retraining <- "
|
||||
"get_invalid_values(predictors_scaled, validity_vector)");
|
||||
"get_invalid_values(predictors_scaled, ai_validity_vector)");
|
||||
R.parseEval("print(head(predictors_retraining))");
|
||||
R.parseEval("targets <- targets[predictor_idx, ]");
|
||||
R.parseEval("targets_retraining <- "
|
||||
"get_invalid_values(targets[ai_surrogate_species_output], "
|
||||
"validity_vector)");
|
||||
"ai_validity_vector)");
|
||||
R.parseEval("print(length(predictors_scaled$H))");
|
||||
R.parseEval("print(length(ai_validity_vector))");
|
||||
|
||||
R.parseEval("targets_retraining <- preprocess(targets_retraining)");
|
||||
|
||||
std::vector<std::vector<float>> predictors_retraining =
|
||||
@ -476,15 +520,12 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
|
||||
std::cout << "results_buffer_size: " << elements_results_buffer
|
||||
<< std::endl;
|
||||
|
||||
if (elements_design_buffer >= 20 * params.batch_size &&
|
||||
if (elements_design_buffer >=
|
||||
20 * params.batch_size && // TODO: change to 4 * grid_size
|
||||
elements_results_buffer >= 20 * params.batch_size &&
|
||||
ai_ctx->training_is_running == false) {
|
||||
ai_ctx->data_semaphore_read.release();
|
||||
} else if (ai_ctx->training_is_running == true) {
|
||||
MSG("Training is currently running");
|
||||
ai_ctx->data_semaphore_write.release();
|
||||
} else {
|
||||
MSG("Not enough data for retraining");
|
||||
ai_ctx->data_semaphore_write.release();
|
||||
}
|
||||
|
||||
@ -707,7 +748,8 @@ int main(int argc, char *argv[]) {
|
||||
run_params.interp_bucket_entries,
|
||||
run_params.interp_size,
|
||||
run_params.interp_min_entries,
|
||||
run_params.ai};
|
||||
run_params.ai,
|
||||
run_params.copy_non_reactive_regions};
|
||||
|
||||
chemistry.masterEnableSurrogates(surr_setup);
|
||||
|
||||
|
||||
@ -25,6 +25,7 @@
|
||||
#include <atomic>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include <MetaParameter.hpp>
|
||||
@ -83,9 +84,12 @@ struct RuntimeParameters {
|
||||
bool ai = false;
|
||||
bool disable_retraining = false;
|
||||
static constexpr std::uint8_t AI_BACKEND_DEFAULT = 1;
|
||||
std::uint8_t ai_backend = 1; // 1 - python, 2 - naa, 3 - cuda
|
||||
std::uint8_t ai_backend = 1; // 1 - python, 2 - naa
|
||||
bool train_only_invalid = true;
|
||||
int batch_size = 1000;
|
||||
|
||||
static constexpr bool COPY_NON_REACTIVE_REGIONS = false;
|
||||
bool copy_non_reactive_regions = COPY_NON_REACTIVE_REGIONS;
|
||||
};
|
||||
|
||||
struct AIContext {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user