mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-16 04:48:23 +01:00
fix: dynamic training buffer column number
This commit is contained in:
parent
9091117e67
commit
c4da86ef98
@ -612,8 +612,10 @@ void parallel_training(EigenModel* Eigen_model, EigenModel* Eigen_model_reactive
|
||||
// Get the necessary training data
|
||||
std::cout << "AI: Training thread: Getting training data" << std::endl;
|
||||
// Initialize training data input and targets
|
||||
std::vector<std::vector<double>> inputs(9, std::vector<double>(params.training_data_size));
|
||||
std::vector<std::vector<double>> targets(9, std::vector<double>(params.training_data_size));
|
||||
std::vector<std::vector<double>> inputs(training_data_buffer->x.size(),
|
||||
std::vector<double>(params.training_data_size));
|
||||
std::vector<std::vector<double>> targets(training_data_buffer->x.size(),
|
||||
std::vector<double>(params.training_data_size));
|
||||
|
||||
int buffer_size = training_data_buffer->x[0].size();
|
||||
// If clustering is used, check the current cluster
|
||||
@ -676,7 +678,6 @@ void parallel_training(EigenModel* Eigen_model, EigenModel* Eigen_model_reactive
|
||||
Python_Keras_train(inputs, targets, train_cluster, model_name, params);
|
||||
|
||||
if (!params.use_Keras_predictions) {
|
||||
// TODO UPDATE EIGEN MODEL CLUSTER SPECIFIC
|
||||
std::cout << "AI: Training thread: Update shared model weights" << std::endl;
|
||||
std::vector<std::vector<std::vector<double>>> cpp_weights = Python_Keras_get_weights(model_name);
|
||||
Eigen_model_mutex->lock();
|
||||
|
||||
@ -675,7 +675,7 @@ int main(int argc, char *argv[]) {
|
||||
/* Use dht species for model input and output */
|
||||
R["ai_surrogate_species"] =
|
||||
init_list.getChemistryInit().dht_species.getNames();
|
||||
|
||||
R.parseEval("ai_surrogate_species <- ai_surrogate_species[ai_surrogate_species != \"Charge\"]");
|
||||
const std::string ai_surrogate_input_script =
|
||||
init_list.getChemistryInit().ai_surrogate_input_script;
|
||||
|
||||
@ -718,7 +718,7 @@ int main(int argc, char *argv[]) {
|
||||
MSG("K-Means clustering will be used for the AI surrogate")
|
||||
}
|
||||
if (!Rcpp::as<bool>(R.parseEval("exists(\"model_reactive_file_path\")"))) {
|
||||
R.parseEval("model_reactive_file_path <- model_reactive_file_path");
|
||||
R.parseEval("model_reactive_file_path <- model_file_path");
|
||||
}
|
||||
MSG("AI: Initialize Python for the AI surrogate functions");
|
||||
std::string python_keras_file = std::string(SRC_DIR) +
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user