docs: add info about ReLU and C++ inference

straile 2024-11-08 19:18:17 +01:00
parent 8adeffa315
commit f648f618de
2 changed files with 9 additions and 4 deletions


@@ -276,8 +276,12 @@ the buffer has been filled, the model starts training and removes this amount
 of data from the front of the buffer. Defaults to the size of the Field.
 - `use_Keras_predictions` [*bool*]: Decides if the Keras prediction function
-should be used instead of the custom C++ implementation (Keras might be faster
-for larger models, especially on GPU). Defaults to false.
+should be used instead of the custom C++ implementation. Keras might be faster
+for larger models, especially on GPU. The C++ inference function assumes that
+the Keras model is a standard feed-forward network with either 32- or 64-bit
+precision and ReLU activation. Any model that deviates from this architecture
+should activate the Keras prediction function to ensure correct calculation.
+Defaults to false.
 - `disable_training` [*bool*]: Deactivates the training functions. Defaults to
 false.
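
To illustrate why the `use_Keras_predictions` entry in this hunk restricts the C++ path to plain feed-forward ReLU networks, here is a minimal sketch of such a forward pass. It is not the project's implementation: it uses `std::vector` instead of Eigen, and `DenseLayer` and `relu_forward` are hypothetical names introduced only for this example. It also assumes that the output layer is left linear and that weights are stored as `[inputs][outputs]`; neither detail is confirmed by the diff.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for one dense layer (not the project's EigenModel).
struct DenseLayer {
    std::vector<std::vector<double>> weights;  // assumed shape: [inputs][outputs]
    std::vector<double> bias;                  // shape: [outputs]
};

// Forward pass for a single sample through a fully connected network.
// Each layer computes y = W^T x + b. ReLU is applied after every hidden
// layer; the last layer is left linear (an assumption of this sketch).
std::vector<double> relu_forward(const std::vector<DenseLayer>& layers,
                                 std::vector<double> x) {
    for (std::size_t l = 0; l < layers.size(); ++l) {
        const DenseLayer& layer = layers[l];
        assert(layer.weights.size() == x.size());

        std::vector<double> y(layer.bias);  // start from the bias term
        for (std::size_t i = 0; i < x.size(); ++i) {
            assert(layer.weights[i].size() == y.size());
            for (std::size_t j = 0; j < y.size(); ++j) {
                y[j] += layer.weights[i][j] * x[i];
            }
        }

        const bool is_hidden = (l + 1 < layers.size());
        if (is_hidden) {
            for (double& v : y) {
                if (v < 0.0) v = 0.0;  // ReLU: clamp negative values to zero
            }
        }
        x = y;
    }
    return x;
}
```

A loop of this kind hard-codes both the activation and the layer type, so a model using, for example, `tanh` or a convolutional layer would run without error but return wrong values. That is the case in which the documentation advises switching to the Keras prediction function.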


@@ -178,7 +178,8 @@ std::vector<double> Python_Keras_predict(std::vector<std::vector<double>>& x, in
 }
 /**
- * @brief Uses Eigen for fast inference with the weights and biases of a neural network
+ * @brief Uses Eigen for fast inference with the weights and biases of a neural network.
+ * This function assumes ReLU activation for each layer.
  * @param input_batch Batch of input data that must fit the size of the neural networks input layer
  * @param model Struct of aligned Eigen vectors that hold the neural networks weights and biases.
  * Only supports simple fully connected feed forward networks.
@@ -607,7 +608,7 @@ int Python_Keras_training_thread(EigenModel* Eigen_model, EigenModel* Eigen_mode
 /**
  * @brief Updates the EigenModel's weights and biases from the weight vector
  * @param model Pointer to an EigenModel struct
- * @param weights Cector of model weights from keras as returned by Python_Keras_get_weights()
+ * @param weights Vector of model weights from keras as returned by Python_Keras_get_weights()
  */
 void update_weights(EigenModel* model,
     const std::vector<std::vector<std::vector<double>>>& weights) {