fix: dynamic training buffer column number

This commit is contained in:
straile 2024-11-03 18:02:32 +01:00
parent 9091117e67
commit c4da86ef98
2 changed files with 6 additions and 5 deletions

View File

@ -612,8 +612,10 @@ void parallel_training(EigenModel* Eigen_model, EigenModel* Eigen_model_reactive
// Get the necessary training data // Get the necessary training data
std::cout << "AI: Training thread: Getting training data" << std::endl; std::cout << "AI: Training thread: Getting training data" << std::endl;
// Initialize training data input and targets // Initialize training data input and targets
std::vector<std::vector<double>> inputs(9, std::vector<double>(params.training_data_size)); std::vector<std::vector<double>> inputs(training_data_buffer->x.size(),
std::vector<std::vector<double>> targets(9, std::vector<double>(params.training_data_size)); std::vector<double>(params.training_data_size));
std::vector<std::vector<double>> targets(training_data_buffer->x.size(),
std::vector<double>(params.training_data_size));
int buffer_size = training_data_buffer->x[0].size(); int buffer_size = training_data_buffer->x[0].size();
// If clustering is used, check the current cluster // If clustering is used, check the current cluster
@ -676,7 +678,6 @@ void parallel_training(EigenModel* Eigen_model, EigenModel* Eigen_model_reactive
Python_Keras_train(inputs, targets, train_cluster, model_name, params); Python_Keras_train(inputs, targets, train_cluster, model_name, params);
if (!params.use_Keras_predictions) { if (!params.use_Keras_predictions) {
// TODO UPDATE EIGEN MODEL CLUSTER SPECIFIC
std::cout << "AI: Training thread: Update shared model weights" << std::endl; std::cout << "AI: Training thread: Update shared model weights" << std::endl;
std::vector<std::vector<std::vector<double>>> cpp_weights = Python_Keras_get_weights(model_name); std::vector<std::vector<std::vector<double>>> cpp_weights = Python_Keras_get_weights(model_name);
Eigen_model_mutex->lock(); Eigen_model_mutex->lock();

View File

@ -675,7 +675,7 @@ int main(int argc, char *argv[]) {
/* Use dht species for model input and output */ /* Use dht species for model input and output */
R["ai_surrogate_species"] = R["ai_surrogate_species"] =
init_list.getChemistryInit().dht_species.getNames(); init_list.getChemistryInit().dht_species.getNames();
R.parseEval("ai_surrogate_species <- ai_surrogate_species[ai_surrogate_species != \"Charge\"]");
const std::string ai_surrogate_input_script = const std::string ai_surrogate_input_script =
init_list.getChemistryInit().ai_surrogate_input_script; init_list.getChemistryInit().ai_surrogate_input_script;
@ -718,7 +718,7 @@ int main(int argc, char *argv[]) {
MSG("K-Means clustering will be used for the AI surrogate") MSG("K-Means clustering will be used for the AI surrogate")
} }
if (!Rcpp::as<bool>(R.parseEval("exists(\"model_reactive_file_path\")"))) { if (!Rcpp::as<bool>(R.parseEval("exists(\"model_reactive_file_path\")"))) {
R.parseEval("model_reactive_file_path <- model_reactive_file_path"); R.parseEval("model_reactive_file_path <- model_file_path");
} }
MSG("AI: Initialize Python for the AI surrogate functions"); MSG("AI: Initialize Python for the AI surrogate functions");
std::string python_keras_file = std::string(SRC_DIR) + std::string python_keras_file = std::string(SRC_DIR) +