From f648f618de193c7bcce587fe72c350aa485a72cc Mon Sep 17 00:00:00 2001 From: straile Date: Fri, 8 Nov 2024 19:18:17 +0100 Subject: [PATCH] docs: add info about ReLU and C++ inference --- README.md | 8 ++++++-- src/Chemistry/SurrogateModels/AI_functions.cpp | 5 +++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index f02a8679c..dba14f7e5 100644 --- a/README.md +++ b/README.md @@ -276,8 +276,12 @@ the buffer has been filled, the model starts training and removes this amount of data from the front of the buffer. Defaults to the size of the Field. - `use_Keras_predictions` [*bool*]: Decides if the Keras prediction function -should be used instead of the custom C++ implementation (Keras might be faster -for larger models, especially on GPU). Defaults to false. +should be used instead of the custom C++ implementation. Keras might be faster +for larger models, especially on GPU. The C++ inference function assumes that +the Keras model is a standard feed forward network with either 32 or 64 bit +precision and ReLU activation. Any model that deviates from this architecture +should activate the Keras prediction function to ensure correct calculation. +Defaults to false. - `disable_training` [*bool*]: Deactivates the training functions. Defaults to false. diff --git a/src/Chemistry/SurrogateModels/AI_functions.cpp b/src/Chemistry/SurrogateModels/AI_functions.cpp index c966c1a4b..4b40b9107 100644 --- a/src/Chemistry/SurrogateModels/AI_functions.cpp +++ b/src/Chemistry/SurrogateModels/AI_functions.cpp @@ -178,7 +178,8 @@ std::vector Python_Keras_predict(std::vector>& x, in } /** - * @brief Uses Eigen for fast inference with the weights and biases of a neural network + * @brief Uses Eigen for fast inference with the weights and biases of a neural network. + * This function assumes ReLU activation for each layer. 
* @param input_batch Batch of input data that must fit the size of the neural networks input layer * @param model Struct of aligned Eigen vectors that hold the neural networks weights and biases. * Only supports simple fully connected feed forward networks. @@ -607,7 +608,7 @@ int Python_Keras_training_thread(EigenModel* Eigen_model, EigenModel* Eigen_mode /** * @brief Updates the EigenModels weigths and biases from the weight vector * @param model Pounter to an EigenModel struct - * @param weights Cector of model weights from keras as returned by Python_Keras_get_weights() + * @param weights Vector of model weights from keras as returned by Python_Keras_get_weights() */ void update_weights(EigenModel* model, const std::vector>>& weights) {