poet/src/poet.hpp.in
2025-12-10 19:52:00 +01:00

112 lines
3.7 KiB
C++

/*
** Copyright (C) 2018-2021 Alexander Lindemann, Max Luebke (University of
** Potsdam)
**
** Copyright (C) 2018-2023 Marco De Lucia, Max Luebke (GFZ Potsdam)
**
** Copyright (C) 2023-2024 Max Luebke (University of Potsdam)
**
** POET is free software; you can redistribute it and/or modify it under the
** terms of the GNU General Public License as published by the Free Software
** Foundation; either version 2 of the License, or (at your option) any later
** version.
**
** POET is distributed in the hope that it will be useful, but WITHOUT ANY
** WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
** A PARTICULAR PURPOSE. See the GNU General Public License for more details.
**
** You should have received a copy of the GNU General Public License along with
** this program; if not, write to the Free Software Foundation, Inc., 51
** Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
#pragma once
#include <atomic>
#include <cstdint>
#include <memory>
#include <semaphore>
#include <string>
#include <type_traits>
#include <vector>
#include <MetaParameter.hpp>
#include <Model.hpp>
#include <Standardizer.hpp>
#include <TrainingBackend.hpp>
#include <TrainingData.hpp>
#include <Rcpp.h>
// Element type the AI surrogate class templates are instantiated with.
using ai_type_t = float;

// The @...@ placeholders below are substituted when this template header is
// configured.
static const char *poet_version = "@POET_VERSION@";

// The embedded library texts are raw string literals so that quotes and
// backslashes need no escaping. A custom delimiter (poet_r) is used because a
// raw literal with the default delimiter ends at the first `)"` it contains,
// and that sequence can legitimately occur in the substituted text (for
// example in `cat("(x)")`).
static const inline std::string kin_r_library =
    R"poet_r(@R_KIN_LIB@)poet_r";

static const inline std::string init_r_library =
    R"poet_r(@R_INIT_LIB@)poet_r";

static const inline std::string ai_surrogate_r_library =
    R"poet_r(@R_AI_SURROGATE_LIB@)poet_r";

// NOTE(review): presumably the name of the R-side object holding the runtime
// parameters - confirm against the callers.
static const inline std::string r_runtime_parameters = "mysetup";
// Numeric identifiers of the selectable backends; numbering starts at 1.
enum BACKEND_TYPE {
  PYTHON_BACKEND = 1,
  NAA_BACKEND = 2,
  CUDA_BACKEND = 3
};
/**
 * Plain aggregate holding the run-time configuration.
 *
 * Most tunables expose their default value as a `static constexpr` constant
 * declared directly above the member it initializes.
 */
struct RuntimeParameters {
  std::string out_dir;
  std::vector<double> timesteps;

  Rcpp::List init_params;

  // MDL added to accommodate for qs::qsave/qread
  bool as_rds = false;
  bool as_qs = false;
  std::string out_ext;

  bool print_progress = false;

  // work package settings
  static constexpr std::uint32_t WORK_PACKAGE_SIZE_DEFAULT = 32;
  std::uint32_t work_package_size = WORK_PACKAGE_SIZE_DEFAULT;

  // DHT settings
  bool use_dht = false;
  // Spelled as an integer literal so that no implicit double -> integer
  // conversion is involved in initializing the constant.
  static constexpr std::uint32_t DHT_SIZE_DEFAULT = 1500;
  std::uint32_t dht_size = DHT_SIZE_DEFAULT;
  static constexpr std::uint8_t DHT_SNAPS_DEFAULT = 0;
  std::uint8_t dht_snaps = DHT_SNAPS_DEFAULT;

  // interpolation settings
  bool use_interp = false;
  static constexpr std::uint32_t INTERP_SIZE_DEFAULT = 100;
  std::uint32_t interp_size = INTERP_SIZE_DEFAULT;
  static constexpr std::uint32_t INTERP_MIN_ENTRIES_DEFAULT = 5;
  std::uint32_t interp_min_entries = INTERP_MIN_ENTRIES_DEFAULT;
  static constexpr std::uint32_t INTERP_BUCKET_ENTRIES_DEFAULT = 20;
  std::uint32_t interp_bucket_entries = INTERP_BUCKET_ENTRIES_DEFAULT;

  // configuration for ai surrogate approach
  bool ai = false;
  bool disable_retraining = false;
  static constexpr std::uint8_t AI_BACKEND_DEFAULT = 1;
  // Numeric BACKEND_TYPE value: 1 - python, 2 - naa, 3 - cuda
  std::uint8_t ai_backend = AI_BACKEND_DEFAULT;
  bool train_only_invalid = true;
  int batch_size = 200 * 200;

  static constexpr std::uint8_t DEFAULT_FUNCTION_CODE = 0;
  std::uint8_t function_code = DEFAULT_FUNCTION_CODE;

  static constexpr bool COPY_NON_REACTIVE_REGIONS = false;
  bool copy_non_reactive_regions = COPY_NON_REACTIVE_REGIONS;
};
/**
 * Bundles the objects of the AI surrogate in one place: two training data
 * buffers, the model together with its meta parameters and scaler, the
 * synchronization primitives and the training backend handle.
 *
 * All templated members are instantiated with ai_type_t.
 */
struct AIContext {
  // training data buffers
  TrainingData<ai_type_t> design_buffer;
  TrainingData<ai_type_t> results_buffer;

  // model state
  Model<ai_type_t> model;
  MetaParameter<ai_type_t> meta_params;
  Standardizer<ai_type_t> scaler;

  // Binary semaphores; the braced value is the initial counter, so the write
  // and model semaphores start out available and the read semaphore does not.
  // NOTE(review): presumably write/read form a hand-off on the data buffers
  // between the simulation and the training side - confirm with the users of
  // this struct.
  std::binary_semaphore data_semaphore_write{1};
  std::binary_semaphore data_semaphore_read{0};
  std::binary_semaphore model_semaphore{1};

  std::atomic_bool training_is_running{false};

  // Not touched by the constructor, i.e. initially empty (nullptr).
  std::unique_ptr<TrainingBackend<ai_type_t>> training_backend;

  // Forwards weights_path to the model; every other member is default
  // constructed or takes its in-class initializer.
  AIContext(const std::string &weights_path)
      : model(weights_path) {}
};