mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-15 12:28:22 +01:00
Merge branch 'naa-naaice' of git.gfz-potsdam.de:naaice/poet into naa-naaice
This commit is contained in:
commit
66ef8f9a34
6
.gitmodules
vendored
6
.gitmodules
vendored
@ -5,9 +5,9 @@
|
|||||||
[submodule "ext/iphreeqc"]
|
[submodule "ext/iphreeqc"]
|
||||||
path = ext/iphreeqc
|
path = ext/iphreeqc
|
||||||
url = ../iphreeqc.git
|
url = ../iphreeqc.git
|
||||||
[submodule "ext/ai-surroagate"]
|
|
||||||
path = ext/ai-surroagate
|
|
||||||
url = git@git.gfz-potsdam.de:naaice/ai-surrogate-poet.git
|
|
||||||
[submodule "ext/ai-benchmarks-utils"]
|
[submodule "ext/ai-benchmarks-utils"]
|
||||||
path = ext/ai-benchmarks-utils
|
path = ext/ai-benchmarks-utils
|
||||||
url = git@git.gfz-potsdam.de:naaice/ai-benchmarks-utils.git
|
url = git@git.gfz-potsdam.de:naaice/ai-benchmarks-utils.git
|
||||||
|
[submodule "ext/ai-surrogate-poet"]
|
||||||
|
path = ext/ai-surrogate-poet
|
||||||
|
url = git@git.gfz-potsdam.de:naaice/ai-surrogate-poet.git
|
||||||
|
|||||||
@ -33,6 +33,11 @@ set(TUG_ENABLE_TESTING OFF CACHE BOOL "" FORCE)
|
|||||||
add_subdirectory(ext/tug EXCLUDE_FROM_ALL)
|
add_subdirectory(ext/tug EXCLUDE_FROM_ALL)
|
||||||
add_subdirectory(ext/iphreeqc EXCLUDE_FROM_ALL)
|
add_subdirectory(ext/iphreeqc EXCLUDE_FROM_ALL)
|
||||||
|
|
||||||
|
# AI/NAA specific includes TODO: add option flags
|
||||||
|
add_subdirectory(ext/ai-surrogate EXCLUDE_FROM_ALL)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
option(POET_ENABLE_TESTING "Build test suite for POET" OFF)
|
option(POET_ENABLE_TESTING "Build test suite for POET" OFF)
|
||||||
|
|
||||||
if (POET_ENABLE_TESTING)
|
if (POET_ENABLE_TESTING)
|
||||||
|
|||||||
@ -63,6 +63,11 @@ set_valid_predictions <- function(temp_field, prediction, validity) {
|
|||||||
return(temp_field)
|
return(temp_field)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
## Select the rows of `df` whose prediction was rejected.
##
## df:       data.frame (or matrix) with one row per sample.
## validity: vector with one entry per row of `df`; an entry equal to 0
##           (or FALSE) marks the corresponding row as invalid.
##
## Returns the invalid rows with the same column layout as `df`.
## `drop = FALSE` keeps the result two-dimensional even when `df` has a
## single column; without it R would collapse the result to a bare vector
## and the column name would be lost.
## NA entries in `validity` produce all-NA rows (standard R logical
## indexing); this is unchanged from the previous behaviour.
get_invalid_values <- function(df, validity) {
  invalid_rows <- validity == 0
  df[invalid_rows, , drop = FALSE]
}
|
||||||
|
|
||||||
|
|
||||||
training_step <- function(model, predictor, target, validity) {
|
training_step <- function(model, predictor, target, validity) {
|
||||||
msgm("Training:")
|
msgm("Training:")
|
||||||
|
|
||||||
|
|||||||
@ -1,18 +1,5 @@
|
|||||||
## Time-stamp: "Last modified 2024-05-30 13:27:06 delucia"
|
|
||||||
|
|
||||||
## load a pretrained model from tensorflow file
|
|
||||||
## Use the global variable "ai_surrogate_base_path" when using file paths
|
|
||||||
## relative to the input script
|
|
||||||
initiate_model <- function() {
|
|
||||||
require(keras3)
|
|
||||||
require(tensorflow)
|
|
||||||
init_model <- normalizePath(paste0(ai_surrogate_base_path,
|
|
||||||
"barite_50ai_all.keras"))
|
|
||||||
Model <- keras3::load_model(init_model)
|
|
||||||
msgm("Loaded model:")
|
|
||||||
print(str(Model))
|
|
||||||
return(Model)
|
|
||||||
}
|
|
||||||
|
|
||||||
scale_min_max <- function(x, min, max, backtransform) {
|
scale_min_max <- function(x, min, max, backtransform) {
|
||||||
if (backtransform) {
|
if (backtransform) {
|
||||||
@ -22,6 +9,22 @@ scale_min_max <- function(x, min, max, backtransform) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
## Standardize values or undo a standardization.
##
## x:             numeric vector or data.frame of values to transform.
## mean, scale:   offset and divisor of the standardization (note that the
##                argument `mean` shadows base::mean inside this function).
## backtransform: if TRUE, map standardized values back to the original
##                scale (x * scale + mean); otherwise compute
##                (x - mean) / scale.
##
## Returns the transformed values with the same shape as `x`.
scale_standardizer <- function(x, mean, scale, backtransform) {
  ## Inverse direction first: undo a previous standardization.
  if (backtransform) {
    return(x * scale + mean)
  }

  ## Forward direction: centre, then divide by the scale.
  (x - mean) / scale
}
|
||||||
|
|
||||||
|
standard <- list(mean = c(H = 111.01243361730982, O= 55.50673140754027, Ba= 0.0016161137065825058,
|
||||||
|
Cl= 0.0534503766678322, S=0.00012864849674669584, Sr=0.0252377348949622,
|
||||||
|
Barite_kin=0.05292312117000998, Celestite_kin=0.9475491659328229),
|
||||||
|
scale = c(H=1.0, O=0.00048139729680698453, Ba=0.008945717576237102, Cl=0.03587363709464328,
|
||||||
|
S=0.00012035100591827131, Sr=0.01523052668095922, Barite_kin=0.21668648247230615,
|
||||||
|
Celestite_kin=0.21639449682671968))
|
||||||
|
|
||||||
minmax <- list(min = c(H = 111.012433592824, O = 55.5062185549492, Charge = -3.1028354471876e-08,
|
minmax <- list(min = c(H = 111.012433592824, O = 55.5062185549492, Charge = -3.1028354471876e-08,
|
||||||
Ba = 1.87312878574393e-141, Cl = 0, `S(6)` = 4.24227510643685e-07,
|
Ba = 1.87312878574393e-141, Cl = 0, `S(6)` = 4.24227510643685e-07,
|
||||||
Sr = 0.00049382996130541, Barite = 0.000999542409828586, Celestite = 0.244801877115968),
|
Sr = 0.00049382996130541, Barite = 0.000999542409828586, Celestite = 0.244801877115968),
|
||||||
@ -30,14 +33,19 @@ minmax <- list(min = c(H = 111.012433592824, O = 55.5062185549492, Charge = -3.1
|
|||||||
Sr = 0.0558680070692722, Barite = 0.756779139057097, Celestite = 1.00075422160624
|
Sr = 0.0558680070692722, Barite = 0.756779139057097, Celestite = 1.00075422160624
|
||||||
))
|
))
|
||||||
|
|
||||||
|
ai_surrogate_species_input = c("H", "O", "Ba", "Cl", "S", "Sr", "Barite_kin", "Celestite_kin")
|
||||||
|
ai_surrogate_species_output = c("O", "Ba", "S", "Sr", "Barite_kin", "Celestite_kin")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
preprocess <- function(df) {
|
preprocess <- function(df) {
|
||||||
if (!is.data.frame(df))
|
if (!is.data.frame(df))
|
||||||
df <- as.data.frame(df, check.names = FALSE)
|
df <- as.data.frame(df, check.names = FALSE)
|
||||||
|
|
||||||
as.data.frame(lapply(colnames(df),
|
as.data.frame(lapply(colnames(df),
|
||||||
function(x) scale_min_max(x=df[x],
|
function(x) scale_standardizer(x=df[x],
|
||||||
min=minmax$min[x],
|
mean=standard$mean[x],
|
||||||
max=minmax$max[x],
|
scale=standard$scale[x],
|
||||||
backtransform=FALSE)),
|
backtransform=FALSE)),
|
||||||
check.names = FALSE)
|
check.names = FALSE)
|
||||||
}
|
}
|
||||||
@ -47,23 +55,25 @@ postprocess <- function(df) {
|
|||||||
df <- as.data.frame(df, check.names = FALSE)
|
df <- as.data.frame(df, check.names = FALSE)
|
||||||
|
|
||||||
as.data.frame(lapply(colnames(df),
|
as.data.frame(lapply(colnames(df),
|
||||||
function(x) scale_min_max(x=df[x],
|
function(x) scale_standardizer(x=df[x],
|
||||||
min=minmax$min[x],
|
mean=standard$mean[x],
|
||||||
max=minmax$max[x],
|
scale=standard$scale[x],
|
||||||
backtransform=TRUE)),
|
backtransform=TRUE)),
|
||||||
check.names = FALSE)
|
check.names = FALSE)
|
||||||
}
|
}
|
||||||
|
|
||||||
mass_balance <- function(predictors, prediction) {
|
mass_balance <- function(predictors, prediction) {
|
||||||
dBa <- abs(prediction$Ba + prediction$Barite -
|
dBa <- abs(prediction$Ba + prediction$Barite_kin -
|
||||||
predictors$Ba - predictors$Barite)
|
predictors$Ba - predictors$Barite_kin)
|
||||||
dSr <- abs(prediction$Sr + prediction$Celestite -
|
dSr <- abs(prediction$Sr + prediction$Celestite_kin -
|
||||||
predictors$Sr - predictors$Celestite)
|
predictors$Sr - predictors$Celestite_kin)
|
||||||
return(dBa + dSr)
|
dS <- abs(prediction$S + prediction$Celestite_kin + prediction$Barite_kin -
|
||||||
|
predictors$S - predictors$Celestite_kin - predictors$Barite_kin)
|
||||||
|
return(dBa + dSr + dS)
|
||||||
}
|
}
|
||||||
|
|
||||||
validate_predictions <- function(predictors, prediction) {
|
validate_predictions <- function(predictors, prediction) {
|
||||||
epsilon <- 1E-7
|
epsilon <- 1E-5
|
||||||
mb <- mass_balance(predictors, prediction)
|
mb <- mass_balance(predictors, prediction)
|
||||||
msgm("Mass balance mean:", mean(mb))
|
msgm("Mass balance mean:", mean(mb))
|
||||||
msgm("Mass balance variance:", var(mb))
|
msgm("Mass balance variance:", var(mb))
|
||||||
@ -72,19 +82,3 @@ validate_predictions <- function(predictors, prediction) {
|
|||||||
sum(ret))
|
sum(ret))
|
||||||
return(ret)
|
return(ret)
|
||||||
}
|
}
|
||||||
|
|
||||||
training_step <- function(model, predictor, target, validity) {
|
|
||||||
msgm("Starting incremental training:")
|
|
||||||
|
|
||||||
## x <- as.matrix(predictor)
|
|
||||||
## y <- as.matrix(target[colnames(x)])
|
|
||||||
|
|
||||||
history <- model %>% keras3::fit(x = data.matrix(predictor),
|
|
||||||
y = data.matrix(target),
|
|
||||||
epochs = 10, verbose=1)
|
|
||||||
|
|
||||||
keras3::save_model(model,
|
|
||||||
filepath = paste0(out_dir, "/current_model.keras"),
|
|
||||||
overwrite=TRUE)
|
|
||||||
return(model)
|
|
||||||
}
|
|
||||||
|
|||||||
@ -1 +0,0 @@
|
|||||||
Subproject commit 1a2dfc6a48fb82b86d142c297db539315d135797
|
|
||||||
1
ext/ai-surrogate-poet
Submodule
1
ext/ai-surrogate-poet
Submodule
@ -0,0 +1 @@
|
|||||||
|
Subproject commit 112c8ff1a88f47a73909724e31227173fd50126a
|
||||||
150
poet.yml
Normal file
150
poet.yml
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
name: poet
|
||||||
|
channels:
|
||||||
|
- defaults
|
||||||
|
- conda-forge
|
||||||
|
- https://repo.anaconda.com/pkgs/main
|
||||||
|
- https://repo.anaconda.com/pkgs/r
|
||||||
|
dependencies:
|
||||||
|
- _libgcc_mutex=0.1=conda_forge
|
||||||
|
- _openmp_mutex=4.5=2_gnu
|
||||||
|
- attr=2.5.2=h39aace5_0
|
||||||
|
- binutils_impl_linux-64=2.43=h4bf12b8_5
|
||||||
|
- binutils_linux-64=2.43=h4852527_5
|
||||||
|
- bzip2=1.0.8=h5eee18b_6
|
||||||
|
- c-ares=1.34.5=hb9d3cd8_0
|
||||||
|
- ca-certificates=2025.11.12=hbd8a1cb_0
|
||||||
|
- cached-property=1.5.2=py_0
|
||||||
|
- cmake=4.1.2=hc946e07_0
|
||||||
|
- eigen=3.4.0=h171cf75_1
|
||||||
|
- expat=2.7.1=h6a678d5_0
|
||||||
|
- gcc_impl_linux-64=13.3.0=h1e990d8_2
|
||||||
|
- gcc_linux-64=13.3.0=h6f18a23_11
|
||||||
|
- gettext=0.21.0=hedfda30_2
|
||||||
|
- gxx_impl_linux-64=13.3.0=hae580e1_2
|
||||||
|
- gxx_linux-64=13.3.0=hb14504d_11
|
||||||
|
- h5py=3.14.0=nompi_py313hfaf8fd4_100
|
||||||
|
- hdf5=1.14.6=nompi_h6e4c0c1_103
|
||||||
|
- highfive=2.10.1=he6560a2_2
|
||||||
|
- icu=73.1=h6a678d5_0
|
||||||
|
- kernel-headers_linux-64=3.10.0=he073ed8_18
|
||||||
|
- krb5=1.21.3=h143b758_0
|
||||||
|
- ld_impl_linux-64=2.43=h712a8e2_5
|
||||||
|
- libaec=1.1.4=h3f801dc_0
|
||||||
|
- libblas=3.9.0=32_h59b9bed_openblas
|
||||||
|
- libcap=2.77=h3ff7636_0
|
||||||
|
- libcblas=3.9.0=32_he106b2a_openblas
|
||||||
|
- libcurl=8.16.0=heebcbe5_0
|
||||||
|
- libedit=3.1.20230828=h5eee18b_0
|
||||||
|
- libev=4.33=h7f8727e_1
|
||||||
|
- libevent=2.1.12=hf998b51_1
|
||||||
|
- libexpat=2.7.1=hecca717_0
|
||||||
|
- libfabric=2.3.1=ha770c72_1
|
||||||
|
- libfabric1=2.3.1=h6c8fc0a_1
|
||||||
|
- libffi=3.4.4=h6a678d5_1
|
||||||
|
- libgcc=15.1.0=h767d61c_2
|
||||||
|
- libgcc-devel_linux-64=13.3.0=hc03c837_102
|
||||||
|
- libgcc-ng=15.1.0=h69a702a_2
|
||||||
|
- libgfortran=15.1.0=h69a702a_2
|
||||||
|
- libgfortran-ng=15.1.0=h69a702a_2
|
||||||
|
- libgfortran5=15.1.0=hcea5267_2
|
||||||
|
- libgomp=15.1.0=h767d61c_2
|
||||||
|
- libhwloc=2.12.1=default_h3d81e11_1000
|
||||||
|
- libiconv=1.16=h5eee18b_3
|
||||||
|
- libidn2=2.3.8=hf80d704_0
|
||||||
|
- liblapack=3.9.0=32_h7ac8fdf_openblas
|
||||||
|
- liblzma=5.8.1=hb9d3cd8_2
|
||||||
|
- liblzma-devel=5.8.1=hb9d3cd8_2
|
||||||
|
- libmpdec=4.0.0=h5eee18b_0
|
||||||
|
- libnghttp2=1.64.0=h161d5f1_0
|
||||||
|
- libnl=3.11.0=hb9d3cd8_0
|
||||||
|
- libopenblas=0.3.30=pthreads_h94d23a6_0
|
||||||
|
- libpmix=5.0.8=h4bd6b51_2
|
||||||
|
- libsanitizer=13.3.0=he8ea267_2
|
||||||
|
- libsqlite=3.50.2=h6cd9bfd_0
|
||||||
|
- libssh2=1.11.1=hcf80075_0
|
||||||
|
- libstdcxx=15.1.0=h8f9b012_2
|
||||||
|
- libstdcxx-devel_linux-64=13.3.0=hc03c837_102
|
||||||
|
- libstdcxx-ng=15.1.0=h4852527_2
|
||||||
|
- libsystemd0=257.10=hd0affe5_2
|
||||||
|
- libudev1=257.10=hd0affe5_2
|
||||||
|
- libunistring=1.3=hb25bd0a_0
|
||||||
|
- libuuid=2.38.1=h0b41bf4_0
|
||||||
|
- libuv=1.48.0=h5eee18b_0
|
||||||
|
- libxcb=1.17.0=h9b100fa_0
|
||||||
|
- libxml2=2.13.9=h2c43086_0
|
||||||
|
- libzlib=1.3.1=hb9d3cd8_2
|
||||||
|
- mpi=1.0.1=openmpi
|
||||||
|
- ncurses=6.5=h2d0b736_3
|
||||||
|
- numpy=2.3.0=py313h17eae1a_0
|
||||||
|
- openmpi=5.0.8=h2fe1745_108
|
||||||
|
- openssl=3.6.0=h26f9b46_0
|
||||||
|
- pip=25.1=pyhc872135_2
|
||||||
|
- pthread-stubs=0.3=h0ce48e5_1
|
||||||
|
- pybind11=2.13.6=py313hdb19cb5_1
|
||||||
|
- pybind11-global=2.13.6=py313hdb19cb5_1
|
||||||
|
- python=3.13.2=hf636f53_101_cp313
|
||||||
|
- python_abi=3.13=0_cp313
|
||||||
|
- rdma-core=60.0=hecca717_0
|
||||||
|
- readline=8.2=h5eee18b_0
|
||||||
|
- rhash=1.4.6=ha914fed_0
|
||||||
|
- setuptools=78.1.1=py313h06a4308_0
|
||||||
|
- sqlite=3.31.1=h7b6447c_0
|
||||||
|
- sysroot_linux-64=2.17=h0157908_18
|
||||||
|
- tk=8.6.13=noxft_hd72426e_102
|
||||||
|
- ucc=1.6.0=hb729f83_0
|
||||||
|
- ucx=1.19.0=h63b5c0b_5
|
||||||
|
- wheel=0.45.1=py313h06a4308_0
|
||||||
|
- xorg-libx11=1.8.12=h9b100fa_1
|
||||||
|
- xorg-libxau=1.0.12=h9b100fa_0
|
||||||
|
- xorg-libxdmcp=1.1.5=h9b100fa_0
|
||||||
|
- xorg-xorgproto=2024.1=h5eee18b_1
|
||||||
|
- xz=5.8.1=hbcc6ac9_2
|
||||||
|
- xz-gpl-tools=5.8.1=hbcc6ac9_2
|
||||||
|
- xz-tools=5.8.1=hb9d3cd8_2
|
||||||
|
- zlib=1.3.1=hb9d3cd8_2
|
||||||
|
- zstd=1.5.7=hb8e6e7a_2
|
||||||
|
- pip:
|
||||||
|
- absl-py==2.3.1
|
||||||
|
- ai-benchmarks-utils==0.1.0
|
||||||
|
- astunparse==1.6.3
|
||||||
|
- certifi==2025.11.12
|
||||||
|
- charset-normalizer==3.4.4
|
||||||
|
- flatbuffers==25.9.23
|
||||||
|
- gast==0.6.0
|
||||||
|
- google-pasta==0.2.0
|
||||||
|
- grpcio==1.76.0
|
||||||
|
- idna==3.11
|
||||||
|
- joblib==1.5.2
|
||||||
|
- keras==3.12.0
|
||||||
|
- libclang==18.1.1
|
||||||
|
- markdown==3.10
|
||||||
|
- markdown-it-py==4.0.0
|
||||||
|
- markupsafe==3.0.3
|
||||||
|
- mdurl==0.1.2
|
||||||
|
- ml-dtypes==0.5.4
|
||||||
|
- namex==0.1.0
|
||||||
|
- opt-einsum==3.4.0
|
||||||
|
- optree==0.18.0
|
||||||
|
- packaging==25.0
|
||||||
|
- pandas==2.3.3
|
||||||
|
- pillow==12.0.0
|
||||||
|
- protobuf==6.33.1
|
||||||
|
- pygments==2.19.2
|
||||||
|
- python-dateutil==2.9.0.post0
|
||||||
|
- pytz==2025.2
|
||||||
|
- requests==2.32.5
|
||||||
|
- rich==14.2.0
|
||||||
|
- scikit-learn==1.7.2
|
||||||
|
- scipy==1.16.3
|
||||||
|
- six==1.17.0
|
||||||
|
- tensorboard==2.20.0
|
||||||
|
- tensorboard-data-server==0.7.2
|
||||||
|
- tensorflow==2.20.0
|
||||||
|
- termcolor==3.2.0
|
||||||
|
- threadpoolctl==3.6.0
|
||||||
|
- typing-extensions==4.15.0
|
||||||
|
- tzdata==2025.2
|
||||||
|
- urllib3==2.5.0
|
||||||
|
- werkzeug==3.1.3
|
||||||
|
- wrapt==2.0.1
|
||||||
|
prefix: /mnt/scratch/miniconda3/envs/poet-dummy
|
||||||
190
src/poet.cpp
190
src/poet.cpp
@ -43,6 +43,12 @@
|
|||||||
#include <set>
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include <Model.hpp>
|
||||||
|
#include <NAABackend.hpp>
|
||||||
|
#include <PythonBackend.hpp>
|
||||||
|
#include <TrainingBackend.hpp>
|
||||||
|
#include <TrainingData.hpp>
|
||||||
|
|
||||||
#include <CLI/CLI.hpp>
|
#include <CLI/CLI.hpp>
|
||||||
|
|
||||||
#include <poet.hpp>
|
#include <poet.hpp>
|
||||||
@ -145,8 +151,16 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) {
|
|||||||
->check(CLI::PositiveNumber)
|
->check(CLI::PositiveNumber)
|
||||||
->default_val(RuntimeParameters::INTERP_BUCKET_ENTRIES_DEFAULT);
|
->default_val(RuntimeParameters::INTERP_BUCKET_ENTRIES_DEFAULT);
|
||||||
|
|
||||||
app.add_flag("--ai-surrogate", params.use_ai_surrogate,
|
auto *ai_option_group =
|
||||||
"Enable AI surrogate for chemistry module");
|
app.add_option_group("ai_surrogate", "AI Surrogate related options");
|
||||||
|
|
||||||
|
ai_option_group->add_flag("--ai", params.ai,
|
||||||
|
"Enable AI surrogate for chemistry module");
|
||||||
|
ai_option_group
|
||||||
|
->add_option("--ai-backend", params.ai_backend,
|
||||||
|
"Desired ai backend (0: python (keras), 1: naa, 2: cuda)")
|
||||||
|
->check(CLI::PositiveNumber)
|
||||||
|
->default_val(RuntimeParameters::AI_BACKEND_DEFAULT);
|
||||||
|
|
||||||
app.add_flag("--rds", params.as_rds,
|
app.add_flag("--rds", params.as_rds,
|
||||||
"Save output as .rds file instead of default .qs2");
|
"Save output as .rds file instead of default .qs2");
|
||||||
@ -191,7 +205,7 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) {
|
|||||||
MSG("Output format/extension is " + params.out_ext);
|
MSG("Output format/extension is " + params.out_ext);
|
||||||
MSG("Work Package Size: " + std::to_string(params.work_package_size));
|
MSG("Work Package Size: " + std::to_string(params.work_package_size));
|
||||||
MSG("DHT is " + BOOL_PRINT(params.use_dht));
|
MSG("DHT is " + BOOL_PRINT(params.use_dht));
|
||||||
MSG("AI Surrogate is " + BOOL_PRINT(params.use_ai_surrogate));
|
MSG("AI Surrogate is " + BOOL_PRINT(params.ai));
|
||||||
|
|
||||||
if (params.use_dht) {
|
if (params.use_dht) {
|
||||||
// MSG("DHT strategy is " + std::to_string(simparams.dht_strategy));
|
// MSG("DHT strategy is " + std::to_string(simparams.dht_strategy));
|
||||||
@ -236,7 +250,6 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) {
|
|||||||
// R["dht_log"] = simparams.dht_log;
|
// R["dht_log"] = simparams.dht_log;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
|
||||||
Rcpp::List init_params_(ReadRObj_R(init_file));
|
Rcpp::List init_params_(ReadRObj_R(init_file));
|
||||||
params.init_params = init_params_;
|
params.init_params = init_params_;
|
||||||
|
|
||||||
@ -290,6 +303,41 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
|||||||
}
|
}
|
||||||
R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps());
|
R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps());
|
||||||
|
|
||||||
|
std::unique_ptr<AIContext> ai_ctx = nullptr;
|
||||||
|
|
||||||
|
if (params.ai) {
|
||||||
|
|
||||||
|
ai_ctx = std::make_unique<AIContext>(
|
||||||
|
"/mnt/scratch/signer/poet/bench/barite/barite_trained.weights.h5");
|
||||||
|
R.parseEval(
|
||||||
|
"mean <- as.numeric(standard$mean[ai_surrogate_species_output])");
|
||||||
|
R.parseEval(
|
||||||
|
"scale <- as.numeric(standard$scale[ai_surrogate_species_output])");
|
||||||
|
|
||||||
|
std::vector<float> mean = R["mean"];
|
||||||
|
std::vector<float> scale = R["scale"];
|
||||||
|
|
||||||
|
ai_ctx->scaler.set_scaler(mean, scale);
|
||||||
|
|
||||||
|
// initialize training backends only if retraining is desired
|
||||||
|
if (params.ai_backend == PYTHON_BACKEND) {
|
||||||
|
MSG("AI Surrogate with Python/keras backend enabled.")
|
||||||
|
// auto model = Python<ai_type_t>();
|
||||||
|
} else if (params.ai_backend == NAA_BACKEND) {
|
||||||
|
MSG("AI Surrogate with NAA backend enabled.")
|
||||||
|
ai_ctx->training_backend =
|
||||||
|
std::make_unique<NAABackend<ai_type_t>>(20 * params.batch_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!params.disable_retraining) {
|
||||||
|
ai_ctx->training_backend->training_thread(
|
||||||
|
ai_ctx->design_buffer, ai_ctx->results_buffer, ai_ctx->model,
|
||||||
|
ai_ctx->meta_params, ai_ctx->scaler, ai_ctx->data_semaphore_write,
|
||||||
|
ai_ctx->data_semaphore_read, ai_ctx->model_semaphore,
|
||||||
|
ai_ctx->training_is_running, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/* SIMULATION LOOP */
|
/* SIMULATION LOOP */
|
||||||
double dSimTime{0};
|
double dSimTime{0};
|
||||||
for (uint32_t iter = 1; iter < maxiter + 1; iter++) {
|
for (uint32_t iter = 1; iter < maxiter + 1; iter++) {
|
||||||
@ -311,7 +359,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
|||||||
chem.getField().update(diffusion.getField());
|
chem.getField().update(diffusion.getField());
|
||||||
|
|
||||||
// MSG("Chemistry start");
|
// MSG("Chemistry start");
|
||||||
if (params.use_ai_surrogate) {
|
if (params.ai) {
|
||||||
double ai_start_t = MPI_Wtime();
|
double ai_start_t = MPI_Wtime();
|
||||||
// Save current values from the tug field as predictor for the ai step
|
// Save current values from the tug field as predictor for the ai step
|
||||||
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
|
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
|
||||||
@ -319,42 +367,63 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
|||||||
std::string("predictors <- setNames(data.frame(matrix(TMP, nrow=" +
|
std::string("predictors <- setNames(data.frame(matrix(TMP, nrow=" +
|
||||||
std::to_string(chem.getField().GetRequestedVecSize()) +
|
std::to_string(chem.getField().GetRequestedVecSize()) +
|
||||||
")), TMP_PROPS)"));
|
")), TMP_PROPS)"));
|
||||||
R.parseEval("predictors <- predictors[ai_surrogate_species]");
|
|
||||||
|
|
||||||
// Apply preprocessing
|
R.parseEval("predictors <- predictors[ai_surrogate_species_input]");
|
||||||
MSG("AI Preprocessing");
|
|
||||||
R.parseEval("predictors_scaled <- preprocess(predictors)");
|
R.parseEval("predictors_scaled <- preprocess(predictors)");
|
||||||
|
std::vector<std::vector<float>> predictors_scaled =
|
||||||
|
R["predictors_scaled"];
|
||||||
|
|
||||||
// Predict
|
// FIXME: double/float conversion
|
||||||
MSG("AI Prediction");
|
std::vector<float> predictions_scaled = ai_ctx->model.predict(
|
||||||
R.parseEval(
|
predictors_scaled, params.batch_size, ai_ctx->model_semaphore);
|
||||||
"aipreds_scaled <- prediction_step(model, predictors_scaled)");
|
|
||||||
|
|
||||||
// Apply postprocessing
|
int n_samples = R.parseEval("nrow(predictors)");
|
||||||
MSG("AI Postprocessing");
|
int n_output_features = ai_ctx->model.weight_matrices.back().cols();
|
||||||
R.parseEval("aipreds <- postprocess(aipreds_scaled)");
|
std::cout << "n_output_features: " << n_output_features << std::endl;
|
||||||
|
std::vector<double> predictions_scaled_double(predictions_scaled.begin(),
|
||||||
|
predictions_scaled.end());
|
||||||
|
R["TMP"] = predictions_scaled_double;
|
||||||
|
R["n_samples"] = n_samples;
|
||||||
|
R["n_output"] = n_output_features;
|
||||||
|
|
||||||
|
R.parseEval("predictions_scaled <- setNames(data.frame(matrix(TMP, "
|
||||||
|
"nrow=n_samples, ncol=n_output, byrow=TRUE)), "
|
||||||
|
"ai_surrogate_species_output)");
|
||||||
|
// R.parseEval("print(head(predictions_scaled))");
|
||||||
|
R.parseEval("predictions <- postprocess(predictions_scaled)");
|
||||||
|
// R.parseEval("print(head(predictions))");
|
||||||
|
|
||||||
// Validate prediction and write valid predictions to chem field
|
|
||||||
MSG("AI Validation");
|
MSG("AI Validation");
|
||||||
R.parseEval(
|
|
||||||
"validity_vector <- validate_predictions(predictors, aipreds)");
|
// FIXME: (mass balance plausible?)
|
||||||
|
R.parseEval("validity_vector <- validate_predictions(predictors, "
|
||||||
|
"predictions) ");
|
||||||
|
|
||||||
|
R.parseEval("print(head(validity_vector))");
|
||||||
|
|
||||||
MSG("AI Marking accepted");
|
MSG("AI Marking accepted");
|
||||||
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
|
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
|
||||||
|
|
||||||
MSG("AI TempField");
|
MSG("AI TempField");
|
||||||
std::vector<std::vector<double>> RTempField =
|
R.parseEval("print(ai_surrogate_species_output)");
|
||||||
R.parseEval("set_valid_predictions(predictors,\
|
// R.parseEval("print(head(predictors))");
|
||||||
aipreds,\
|
std::vector<std::vector<double>> RTempField = R.parseEval(
|
||||||
|
"set_valid_predictions(predictors[ai_surrogate_species_output],\
|
||||||
|
predictions,\
|
||||||
validity_vector)");
|
validity_vector)");
|
||||||
|
|
||||||
MSG("AI Set Field");
|
MSG("AI Set Field");
|
||||||
Field predictions_field =
|
Field predictions_field = Field(
|
||||||
Field(R.parseEval("nrow(predictors)"), RTempField,
|
R.parseEval("nrow(predictors)"), RTempField,
|
||||||
R.parseEval("colnames(predictors)"));
|
R.parseEval(
|
||||||
|
"colnames(predictors[ai_surrogate_species_output])")); // FIXME:
|
||||||
|
// is this
|
||||||
|
// correct?
|
||||||
|
|
||||||
MSG("AI Update");
|
MSG("AI Update");
|
||||||
chem.getField().update(predictions_field);
|
chem.getField().update(predictions_field);
|
||||||
|
|
||||||
double ai_end_t = MPI_Wtime();
|
double ai_end_t = MPI_Wtime();
|
||||||
R["ai_prediction_time"] = ai_end_t - ai_start_t;
|
R["ai_prediction_time"] = ai_end_t - ai_start_t;
|
||||||
}
|
}
|
||||||
@ -362,7 +431,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
|||||||
chem.simulate(dt);
|
chem.simulate(dt);
|
||||||
|
|
||||||
/* AI surrogate iterative training*/
|
/* AI surrogate iterative training*/
|
||||||
if (params.use_ai_surrogate) {
|
if (params.ai == true && params.disable_retraining == false) {
|
||||||
double ai_start_t = MPI_Wtime();
|
double ai_start_t = MPI_Wtime();
|
||||||
|
|
||||||
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
|
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
|
||||||
@ -370,14 +439,55 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
|||||||
std::string("targets <- setNames(data.frame(matrix(TMP, nrow=" +
|
std::string("targets <- setNames(data.frame(matrix(TMP, nrow=" +
|
||||||
std::to_string(chem.getField().GetRequestedVecSize()) +
|
std::to_string(chem.getField().GetRequestedVecSize()) +
|
||||||
")), TMP_PROPS)"));
|
")), TMP_PROPS)"));
|
||||||
R.parseEval("targets <- targets[ai_surrogate_species]");
|
|
||||||
|
|
||||||
// TODO: Check how to get the correct columns
|
R.parseEval("predictors_retraining <- "
|
||||||
R.parseEval("target_scaled <- preprocess(targets)");
|
"get_invalid_values(predictors_scaled, validity_vector)");
|
||||||
|
R.parseEval("targets_retraining <- "
|
||||||
|
"get_invalid_values(targets[ai_surrogate_species_output], "
|
||||||
|
"validity_vector)");
|
||||||
|
R.parseEval("targets_retraining <- preprocess(targets_retraining)");
|
||||||
|
|
||||||
|
std::vector<std::vector<float>> predictors_retraining =
|
||||||
|
R["predictors_retraining"];
|
||||||
|
std::vector<std::vector<float>> targets_retraining =
|
||||||
|
R["targets_retraining"];
|
||||||
|
|
||||||
|
MSG("AI: add invalid data to buffer");
|
||||||
|
|
||||||
|
ai_ctx->data_semaphore_write.acquire();
|
||||||
|
|
||||||
|
std::cout << "size of predictors " << predictors_retraining[0].size()
|
||||||
|
<< std::endl;
|
||||||
|
std::cout << "size of targets " << targets_retraining[0].size()
|
||||||
|
<< std::endl;
|
||||||
|
|
||||||
|
ai_ctx->design_buffer.addData(predictors_retraining);
|
||||||
|
ai_ctx->results_buffer.addData(targets_retraining);
|
||||||
|
|
||||||
|
size_t elements_design_buffer =
|
||||||
|
ai_ctx->design_buffer.getSize() /
|
||||||
|
(predictors_retraining.size() * sizeof(float));
|
||||||
|
size_t elements_results_buffer =
|
||||||
|
ai_ctx->results_buffer.getSize() /
|
||||||
|
(targets_retraining.size() * sizeof(float));
|
||||||
|
|
||||||
|
std::cout << "design_buffer_size: " << elements_design_buffer
|
||||||
|
<< std::endl;
|
||||||
|
std::cout << "results_buffer_size: " << elements_results_buffer
|
||||||
|
<< std::endl;
|
||||||
|
|
||||||
|
if (elements_design_buffer >= 20 * params.batch_size &&
|
||||||
|
elements_results_buffer >= 20 * params.batch_size &&
|
||||||
|
ai_ctx->training_is_running == false) {
|
||||||
|
ai_ctx->data_semaphore_read.release();
|
||||||
|
} else if (ai_ctx->training_is_running == true) {
|
||||||
|
MSG("Training is currently running");
|
||||||
|
ai_ctx->data_semaphore_write.release();
|
||||||
|
} else {
|
||||||
|
MSG("Not enough data for retraining");
|
||||||
|
ai_ctx->data_semaphore_write.release();
|
||||||
|
}
|
||||||
|
|
||||||
MSG("AI: incremental training");
|
|
||||||
R.parseEval("model <- training_step(model, predictors_scaled, "
|
|
||||||
"target_scaled, validity_vector)");
|
|
||||||
double ai_end_t = MPI_Wtime();
|
double ai_end_t = MPI_Wtime();
|
||||||
R["ai_training_time"] = ai_end_t - ai_start_t;
|
R["ai_training_time"] = ai_end_t - ai_start_t;
|
||||||
}
|
}
|
||||||
@ -402,6 +512,10 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
|||||||
|
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
|
|
||||||
|
if (!params.disable_retraining) {
|
||||||
|
ai_ctx->training_backend->stop_training(ai_ctx->data_semaphore_read);
|
||||||
|
}
|
||||||
|
|
||||||
Rcpp::List chem_profiling;
|
Rcpp::List chem_profiling;
|
||||||
chem_profiling["simtime"] = chem.GetChemistryTime();
|
chem_profiling["simtime"] = chem.GetChemistryTime();
|
||||||
chem_profiling["loop"] = chem.GetMasterLoopTime();
|
chem_profiling["loop"] = chem.GetMasterLoopTime();
|
||||||
@ -593,7 +707,7 @@ int main(int argc, char *argv[]) {
|
|||||||
run_params.interp_bucket_entries,
|
run_params.interp_bucket_entries,
|
||||||
run_params.interp_size,
|
run_params.interp_size,
|
||||||
run_params.interp_min_entries,
|
run_params.interp_min_entries,
|
||||||
run_params.use_ai_surrogate};
|
run_params.ai};
|
||||||
|
|
||||||
chemistry.masterEnableSurrogates(surr_setup);
|
chemistry.masterEnableSurrogates(surr_setup);
|
||||||
|
|
||||||
@ -613,13 +727,15 @@ int main(int argc, char *argv[]) {
|
|||||||
R["out_ext"] = run_params.out_ext;
|
R["out_ext"] = run_params.out_ext;
|
||||||
R["out_dir"] = run_params.out_dir;
|
R["out_dir"] = run_params.out_dir;
|
||||||
|
|
||||||
if (run_params.use_ai_surrogate) {
|
if (run_params.ai) {
|
||||||
/* Incorporate ai surrogate from R */
|
/* Incorporate ai surrogate from R */
|
||||||
R.parseEvalQ(ai_surrogate_r_library);
|
R.parseEvalQ(ai_surrogate_r_library);
|
||||||
/* Use dht species for model input and output */
|
/* Use dht species for model input and output */
|
||||||
R["ai_surrogate_species"] =
|
const auto &names = init_list.getChemistryInit().dht_species.getNames();
|
||||||
init_list.getChemistryInit().dht_species.getNames();
|
for (const auto &name : names) {
|
||||||
|
std::cout << name << " ";
|
||||||
|
}
|
||||||
|
std::cout << "\n"; //
|
||||||
const std::string ai_surrogate_input_script =
|
const std::string ai_surrogate_input_script =
|
||||||
init_list.getChemistryInit().ai_surrogate_input_script;
|
init_list.getChemistryInit().ai_surrogate_input_script;
|
||||||
|
|
||||||
@ -627,8 +743,6 @@ int main(int argc, char *argv[]) {
|
|||||||
R.parseEvalQ(ai_surrogate_input_script);
|
R.parseEvalQ(ai_surrogate_input_script);
|
||||||
|
|
||||||
MSG("AI: initialize AI model");
|
MSG("AI: initialize AI model");
|
||||||
R.parseEval("model <- initiate_model()");
|
|
||||||
R.parseEval("gpu_info()");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MSG("Init done on process with rank " + std::to_string(MY_RANK));
|
MSG("Init done on process with rank " + std::to_string(MY_RANK));
|
||||||
|
|||||||
@ -22,12 +22,21 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include <MetaParameter.hpp>
|
||||||
|
#include <Model.hpp>
|
||||||
|
#include <Standardizer.hpp>
|
||||||
|
#include <TrainingBackend.hpp>
|
||||||
|
#include <TrainingData.hpp>
|
||||||
|
|
||||||
#include <Rcpp.h>
|
#include <Rcpp.h>
|
||||||
|
|
||||||
|
using ai_type_t = float;
|
||||||
|
|
||||||
static const char *poet_version = "@POET_VERSION@";
|
static const char *poet_version = "@POET_VERSION@";
|
||||||
|
|
||||||
// using the Raw string literal to avoid escaping the quotes
|
// using the Raw string literal to avoid escaping the quotes
|
||||||
@ -38,6 +47,8 @@ static const inline std::string ai_surrogate_r_library =
|
|||||||
R"(@R_AI_SURROGATE_LIB@)";
|
R"(@R_AI_SURROGATE_LIB@)";
|
||||||
static const inline std::string r_runtime_parameters = "mysetup";
|
static const inline std::string r_runtime_parameters = "mysetup";
|
||||||
|
|
||||||
|
enum BACKEND_TYPE { PYTHON_BACKEND = 1, NAA_BACKEND, CUDA_BACKEND };
|
||||||
|
|
||||||
struct RuntimeParameters {
|
struct RuntimeParameters {
|
||||||
std::string out_dir;
|
std::string out_dir;
|
||||||
std::vector<double> timesteps;
|
std::vector<double> timesteps;
|
||||||
@ -68,5 +79,27 @@ struct RuntimeParameters {
|
|||||||
static constexpr std::uint32_t INTERP_BUCKET_ENTRIES_DEFAULT = 20;
|
static constexpr std::uint32_t INTERP_BUCKET_ENTRIES_DEFAULT = 20;
|
||||||
std::uint32_t interp_bucket_entries = INTERP_BUCKET_ENTRIES_DEFAULT;
|
std::uint32_t interp_bucket_entries = INTERP_BUCKET_ENTRIES_DEFAULT;
|
||||||
|
|
||||||
bool use_ai_surrogate = false;
|
// configuration for ai surrogate approach
|
||||||
|
bool ai = false;
|
||||||
|
bool disable_retraining = false;
|
||||||
|
static constexpr std::uint8_t AI_BACKEND_DEFAULT = 1;
|
||||||
|
std::uint8_t ai_backend = 1; // 1 - python, 2 - naa, 3 - cuda
|
||||||
|
bool train_only_invalid = true;
|
||||||
|
int batch_size = 1000;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct AIContext {
|
||||||
|
TrainingData<ai_type_t> design_buffer;
|
||||||
|
TrainingData<ai_type_t> results_buffer;
|
||||||
|
Model<ai_type_t> model;
|
||||||
|
MetaParameter<ai_type_t> meta_params;
|
||||||
|
Standardizer<ai_type_t> scaler;
|
||||||
|
|
||||||
|
std::binary_semaphore data_semaphore_write{1};
|
||||||
|
std::binary_semaphore data_semaphore_read{0};
|
||||||
|
std::binary_semaphore model_semaphore{1};
|
||||||
|
std::atomic_bool training_is_running = false;
|
||||||
|
std::unique_ptr<TrainingBackend<ai_type_t>> training_backend;
|
||||||
|
|
||||||
|
AIContext(const std::string &weights_path) : model(weights_path) {}
|
||||||
};
|
};
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user