mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-16 04:48:23 +01:00
112 lines
3.7 KiB
C++
112 lines
3.7 KiB
C++
/*
|
|
** Copyright (C) 2018-2021 Alexander Lindemann, Max Luebke (University of
|
|
** Potsdam)
|
|
**
|
|
** Copyright (C) 2018-2023 Marco De Lucia, Max Luebke (GFZ Potsdam)
|
|
**
|
|
** Copyright (C) 2023-2024 Max Luebke (University of Potsdam)
|
|
**
|
|
** POET is free software; you can redistribute it and/or modify it under the
|
|
** terms of the GNU General Public License as published by the Free Software
|
|
** Foundation; either version 2 of the License, or (at your option) any later
|
|
** version.
|
|
**
|
|
** POET is distributed in the hope that it will be useful, but WITHOUT ANY
|
|
** WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
|
|
** A PARTICULAR PURPOSE. See the GNU General Public License for more details.
|
|
**
|
|
** You should have received a copy of the GNU General Public License along with
|
|
** this program; if not, write to the Free Software Foundation, Inc., 51
|
|
** Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
|
*/
|
|
|
|
#pragma once
|
|
|
|
#include <atomic>
#include <cstdint>
#include <memory>
#include <semaphore>
#include <string>
#include <type_traits>
#include <vector>

#include <MetaParameter.hpp>
#include <Model.hpp>
#include <Standardizer.hpp>
#include <TrainingBackend.hpp>
#include <TrainingData.hpp>

#include <Rcpp.h>
|
|
|
|
/// Element type of all AI surrogate containers (training buffers, model,
/// meta parameters and standardizer used by AIContext).
using ai_type_t = float;

// The `@NAME@` tokens below are placeholders that are presumably filled in
// by CMake `configure_file` when this header is generated — confirm in the
// build scripts.
//
// All globals are C++17 `inline` variables: one definition shared by every
// translation unit, instead of one private copy (plus dynamic initializer)
// per TU as `static` would give. This matters because the R libraries are
// embedded verbatim and may be large.

/// POET version string.
inline constexpr const char *poet_version = "@POET_VERSION@";

// The R libraries are embedded as raw string literals so that their quotes
// need no escaping. A custom delimiter (`poet_R`) is used because the default
// `)"` terminator can legitimately occur inside R source (e.g. `c(")")`),
// which would end the literal early and break the generated header.
inline const std::string kin_r_library = R"poet_R(@R_KIN_LIB@)poet_R";

inline const std::string init_r_library = R"poet_R(@R_INIT_LIB@)poet_R";
inline const std::string ai_surrogate_r_library =
    R"poet_R(@R_AI_SURROGATE_LIB@)poet_R";

// NOTE(review): presumably the name of the R-side setup object holding the
// runtime parameters — confirm at the use sites.
inline const std::string r_runtime_parameters = "mysetup";
|
|
|
|
/// Backend kinds, numbered from 1.
///
/// NOTE(review): the numeric values appear to be the codes stored in
/// RuntimeParameters::ai_backend ("1 - python, 2 - naa") — confirm.
enum BACKEND_TYPE {
  PYTHON_BACKEND = 1,
  NAA_BACKEND = 2,
  CUDA_BACKEND = 3,
};
|
|
|
|
struct RuntimeParameters {
|
|
std::string out_dir;
|
|
std::vector<double> timesteps;
|
|
|
|
Rcpp::List init_params;
|
|
|
|
// MDL added to accomodate for qs::qsave/qread
|
|
bool as_rds = false;
|
|
bool as_qs = false;
|
|
std::string out_ext;
|
|
|
|
bool print_progress = false;
|
|
|
|
static constexpr std::uint32_t WORK_PACKAGE_SIZE_DEFAULT = 32;
|
|
std::uint32_t work_package_size = WORK_PACKAGE_SIZE_DEFAULT;
|
|
|
|
bool use_dht = false;
|
|
static constexpr std::uint32_t DHT_SIZE_DEFAULT = 1.5E3;
|
|
std::uint32_t dht_size = DHT_SIZE_DEFAULT;
|
|
static constexpr std::uint8_t DHT_SNAPS_DEFAULT = 0;
|
|
std::uint8_t dht_snaps = DHT_SNAPS_DEFAULT;
|
|
|
|
bool use_interp = false;
|
|
static constexpr std::uint32_t INTERP_SIZE_DEFAULT = 100;
|
|
std::uint32_t interp_size = INTERP_SIZE_DEFAULT;
|
|
static constexpr std::uint32_t INTERP_MIN_ENTRIES_DEFAULT = 5;
|
|
std::uint32_t interp_min_entries = INTERP_MIN_ENTRIES_DEFAULT;
|
|
static constexpr std::uint32_t INTERP_BUCKET_ENTRIES_DEFAULT = 20;
|
|
std::uint32_t interp_bucket_entries = INTERP_BUCKET_ENTRIES_DEFAULT;
|
|
|
|
// configuration for ai surrogate approach
|
|
bool ai = false;
|
|
bool disable_retraining = false;
|
|
static constexpr std::uint8_t AI_BACKEND_DEFAULT = 1;
|
|
std::uint8_t ai_backend = AI_BACKEND_DEFAULT; // 1 - python, 2 - naa
|
|
bool train_only_invalid = true;
|
|
int batch_size = 200 * 200;
|
|
static constexpr std::uint8_t DEFAULT_FUNCTION_CODE = 0;
|
|
std::uint8_t function_code = DEFAULT_FUNCTION_CODE;
|
|
|
|
static constexpr bool COPY_NON_REACTIVE_REGIONS = false;
|
|
bool copy_non_reactive_regions = COPY_NON_REACTIVE_REGIONS;
|
|
};
|
|
|
|
struct AIContext {
|
|
TrainingData<ai_type_t> design_buffer;
|
|
TrainingData<ai_type_t> results_buffer;
|
|
Model<ai_type_t> model;
|
|
MetaParameter<ai_type_t> meta_params;
|
|
Standardizer<ai_type_t> scaler;
|
|
|
|
std::binary_semaphore data_semaphore_write{1};
|
|
std::binary_semaphore data_semaphore_read{0};
|
|
std::binary_semaphore model_semaphore{1};
|
|
std::atomic_bool training_is_running = false;
|
|
std::unique_ptr<TrainingBackend<ai_type_t>> training_backend;
|
|
|
|
AIContext(const std::string &weights_path) : model(weights_path) {}
|
|
};
|