mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-15 20:38:23 +01:00
what is the problem?
This commit is contained in:
parent
74cd827c68
commit
c323705f34
@ -205,17 +205,11 @@ std::vector<double> Python_Keras_predict(std::vector<std::vector<double>> x, int
|
||||
*/
|
||||
std::vector<double> Eigen_predict(const EigenModel& model, std::vector<std::vector<double>> x, int batch_size,
|
||||
std::mutex* Eigen_model_mutex) {
|
||||
|
||||
std::cout << "GETTING DIMS" << std::endl;
|
||||
|
||||
// Convert input data to Eigen matrix
|
||||
const int num_samples = x[0].size();
|
||||
const int num_features = x.size();
|
||||
|
||||
std::cout << "SETTING MATRIX" << num_samples << num_features << std::endl;
|
||||
Eigen::MatrixXd full_input_matrix(num_features, num_samples);
|
||||
|
||||
std::cout << "SETTING VALUES" << std::endl;
|
||||
for (int i = 0; i < num_samples; ++i) {
|
||||
for (int j = 0; j < num_features; ++j) {
|
||||
full_input_matrix(j, i) = x[j][i];
|
||||
@ -223,38 +217,27 @@ std::vector<double> Eigen_predict(const EigenModel& model, std::vector<std::vect
|
||||
}
|
||||
|
||||
std::vector<double> result;
|
||||
std::cout << "RESERVING RESULT" << std::endl;
|
||||
|
||||
result.reserve(num_samples * num_features);
|
||||
|
||||
if (num_features != model.weight_matrices[0].cols()) {
|
||||
throw std::runtime_error("Input data size " + std::to_string(num_features) + \
|
||||
" does not match model input layer of size " + std::to_string(model.weight_matrices[0].cols()));
|
||||
}
|
||||
std::cout << "MAKING BATCHESS OF SIZE " << batch_size << std::endl;
|
||||
int num_batches = std::ceil(static_cast<double>(num_samples) / batch_size);
|
||||
std::cout << "LOOKING MUTEX"<< std::endl;
|
||||
|
||||
Eigen_model_mutex->lock();
|
||||
std::cout << "STARTING CALCULATIONS"<< std::endl;
|
||||
for (int batch = 0; batch < num_batches; ++batch) {
|
||||
std::cout << "BATCH "<< batch << std::endl;
|
||||
int start_idx = batch * batch_size;
|
||||
int end_idx = std::min((batch + 1) * batch_size, num_samples);
|
||||
int current_batch_size = end_idx - start_idx;
|
||||
// Extract the current input data batch
|
||||
std::cout << "BATCH SIZE CALCULATED "<< batch << std::endl;
|
||||
Eigen::MatrixXd batch_data(num_features, current_batch_size);
|
||||
std::cout << "BATCH INPUT DECLARED "<< batch << std::endl;
|
||||
batch_data = full_input_matrix.block(0, start_idx, num_features, current_batch_size);
|
||||
std::cout << "BATCH INPUT SET "<< batch << std::endl;
|
||||
// Predict
|
||||
std::cout << "BATCH INPUT CLCULATE "<< batch << std::endl;
|
||||
batch_data = eigen_inference_batched(batch_data, model);
|
||||
std::cout << "RESULT INSERT "<< batch << std::endl;
|
||||
|
||||
result.insert(result.end(), batch_data.data(), batch_data.data() + batch_data.size());
|
||||
}
|
||||
std::cout << "UNLOCKING MUTEX"<< std::endl;
|
||||
Eigen_model_mutex->unlock();
|
||||
return result;
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user