This commit is contained in:
straile 2024-10-15 11:36:12 +02:00
parent 8691370abb
commit a4a1eedcac
3 changed files with 36 additions and 39 deletions

View File

@ -182,51 +182,48 @@ Eigen::MatrixXd eigen_inference_batched(const Eigen::Ref<const Eigen::MatrixXd>&
* @return Predictions that the neural network made from the input values x. The predictions are
* represented as a vector similar to the representation from the Field.AsVector() method
*/
std::vector<double> Eigen_predict(const EigenModel& model, std::vector<std::vector<double>> x, int batch_size,
                                  std::mutex& Eigen_model_mutex) {
    // Validate cheap preconditions before building the (potentially large) input matrix.
    const int num_features = static_cast<int>(x.size());
    if (num_features != model.weight_matrices[0].cols()) {
        throw std::runtime_error("Input data size " + std::to_string(num_features) +
                                 " does not match model input layer of size " +
                                 std::to_string(model.weight_matrices[0].cols()));
    }
    if (batch_size <= 0) {
        // Guard: ceil(n / 0) below would be undefined behavior when cast to int.
        throw std::invalid_argument("batch_size must be positive");
    }
    const int num_samples = static_cast<int>(x[0].size());
    if (num_samples == 0) {
        return {};  // nothing to predict
    }
    // Convert input data (column-major vector-of-vectors) to an Eigen matrix:
    // feature j of sample i lands at (j, i).
    Eigen::MatrixXd full_input_matrix(num_features, num_samples);
    for (int i = 0; i < num_samples; ++i) {
        for (int j = 0; j < num_features; ++j) {
            full_input_matrix(j, i) = x[j][i];
        }
    }
    std::vector<double> result;
    // Reserve based on the model's OUTPUT layer size (rows of the last weight
    // matrix), not the input feature count: the inference result has one
    // output value per output neuron per sample.
    result.reserve(static_cast<std::size_t>(model.weight_matrices.back().rows()) *
                   static_cast<std::size_t>(num_samples));
    const int num_batches = static_cast<int>(std::ceil(static_cast<double>(num_samples) / batch_size));
    // RAII lock: the model weights may be swapped concurrently by the training
    // thread, so hold the mutex for the whole inference pass. lock_guard
    // releases it on every exit path (including exceptions).
    std::lock_guard<std::mutex> lock(Eigen_model_mutex);
    for (int batch = 0; batch < num_batches; ++batch) {
        const int start_idx = batch * batch_size;
        const int end_idx = std::min((batch + 1) * batch_size, num_samples);
        const int current_batch_size = end_idx - start_idx;
        // Extract the current input data batch and run inference on it.
        Eigen::MatrixXd batch_data = full_input_matrix.block(0, start_idx, num_features, current_batch_size);
        Eigen::MatrixXd output = eigen_inference_batched(batch_data, model);
        // Append the batch predictions (column-major, matching Field.AsVector()).
        result.insert(result.end(), output.data(), output.data() + output.size());
    }
    return result;
}
/**
* @brief Appends data from one matrix (column major std::vector<std::vector<double>>) to another
* @param training_data_buffer Matrix that the values are appended to

View File

@ -64,7 +64,7 @@ void update_weights(EigenModel* model, const std::vector<std::vector<std::vector
std::vector<std::vector<std::vector<double>>> Python_Keras_get_weights();
std::vector<double> Eigen_predict(const EigenModel& model, std::vector<std::vector<double>> x, int batch_size,
                                  std::mutex& Eigen_model_mutex);
// Otherwise, define the necessary stubs
#else
@ -80,7 +80,7 @@ inline int Python_Keras_training_thread(EigenModel*, std::mutex*,
// Stub implementations used when the Python/Keras backend is compiled out.
// Fix: a void function must not `return {};` — that is ill-formed C++.
inline void update_weights(EigenModel*, const std::vector<std::vector<std::vector<double>>>&) {}
inline std::vector<std::vector<std::vector<double>>> Python_Keras_get_weights() { return {}; }
inline std::vector<double> Eigen_predict(const EigenModel&, std::vector<std::vector<double>>, int, std::mutex&) { return {}; }
#endif
} // namespace poet

View File

@ -365,7 +365,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
R["TMP"] = Python_Keras_predict(R["predictors_scaled"], params.batch_size);
} else { // Predict with custom Eigen function
R["TMP"] = Eigen_predict(Eigen_model, R["predictors_scaled"], params.batch_size, Eigen_model_mutex);
}
// Apply postprocessing