add naa-communication header

This commit is contained in:
Hannes Signer 2024-12-06 10:55:35 +01:00
parent 5db25d63ba
commit 13ad41d302

View File

@ -44,6 +44,11 @@
#include <CLI/CLI.hpp> #include <CLI/CLI.hpp>
#include <poet.hpp> #include <poet.hpp>
#include <vector> #include <vector>
#include <stdio.h>
extern "C"{
#include <naaice.h>
}
using namespace std; using namespace std;
using namespace poet; using namespace poet;
@ -311,6 +316,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
params.use_clustering); params.use_clustering);
if (!params.disable_training) { if (!params.disable_training) {
MSG("AI: Initialize training thread"); MSG("AI: Initialize training thread");
// TODO add naa_handle as optional parameter which is NULL per default
Python_Keras_training_thread(&Eigen_model, &Eigen_model_reactive, Python_Keras_training_thread(&Eigen_model, &Eigen_model_reactive,
&Eigen_model_mutex, &training_data_buffer, &Eigen_model_mutex, &training_data_buffer,
&training_data_buffer_mutex, &training_data_buffer_mutex,
@ -598,7 +604,7 @@ int main(int argc, char *argv[]) {
// Threadsafe MPI is necessary for the AI surrogate // Threadsafe MPI is necessary for the AI surrogate
// training thread // training thread
int provided; int provided;
int required = MPI_THREAD_FUNNELED; int required = MPI_THREAD_FUNNELED; // the application is multithreaded but MPI calls are only made from the main thread
MPI_Init_thread(&argc, &argv, required, &provided); MPI_Init_thread(&argc, &argv, required, &provided);
{ {
@ -651,6 +657,7 @@ int main(int argc, char *argv[]) {
run_params.interp_size, run_params.interp_size,
run_params.interp_min_entries, run_params.interp_min_entries,
run_params.use_ai_surrogate}; run_params.use_ai_surrogate};
// TODO add option for naa training
chemistry.masterEnableSurrogates(surr_setup); chemistry.masterEnableSurrogates(surr_setup);
@ -697,6 +704,7 @@ int main(int argc, char *argv[]) {
variable of the same name in one of the the R input scripts)*/ variable of the same name in one of the the R input scripts)*/
run_params.training_data_size = init_list.getDiffusionInit().n_rows * run_params.training_data_size = init_list.getDiffusionInit().n_rows *
init_list.getDiffusionInit().n_cols; // Default value is number of cells in field init_list.getDiffusionInit().n_cols; // Default value is number of cells in field
// ? error handling for all if statements
if (Rcpp::as<bool>(R.parseEval("exists(\"batch_size\")"))) { if (Rcpp::as<bool>(R.parseEval("exists(\"batch_size\")"))) {
run_params.batch_size = R["batch_size"]; run_params.batch_size = R["batch_size"];
} }