mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-16 04:48:23 +01:00
fix: dynamic training buffer column number
This commit is contained in:
parent
9091117e67
commit
c4da86ef98
@ -612,8 +612,10 @@ void parallel_training(EigenModel* Eigen_model, EigenModel* Eigen_model_reactive
|
|||||||
// Get the necessary training data
|
// Get the necessary training data
|
||||||
std::cout << "AI: Training thread: Getting training data" << std::endl;
|
std::cout << "AI: Training thread: Getting training data" << std::endl;
|
||||||
// Initialize training data input and targets
|
// Initialize training data input and targets
|
||||||
std::vector<std::vector<double>> inputs(9, std::vector<double>(params.training_data_size));
|
std::vector<std::vector<double>> inputs(training_data_buffer->x.size(),
|
||||||
std::vector<std::vector<double>> targets(9, std::vector<double>(params.training_data_size));
|
std::vector<double>(params.training_data_size));
|
||||||
|
std::vector<std::vector<double>> targets(training_data_buffer->x.size(),
|
||||||
|
std::vector<double>(params.training_data_size));
|
||||||
|
|
||||||
int buffer_size = training_data_buffer->x[0].size();
|
int buffer_size = training_data_buffer->x[0].size();
|
||||||
// If clustering is used, check the current cluster
|
// If clustering is used, check the current cluster
|
||||||
@ -676,7 +678,6 @@ void parallel_training(EigenModel* Eigen_model, EigenModel* Eigen_model_reactive
|
|||||||
Python_Keras_train(inputs, targets, train_cluster, model_name, params);
|
Python_Keras_train(inputs, targets, train_cluster, model_name, params);
|
||||||
|
|
||||||
if (!params.use_Keras_predictions) {
|
if (!params.use_Keras_predictions) {
|
||||||
// TODO UPDATE EIGEN MODEL CLUSTER SPECIFIC
|
|
||||||
std::cout << "AI: Training thread: Update shared model weights" << std::endl;
|
std::cout << "AI: Training thread: Update shared model weights" << std::endl;
|
||||||
std::vector<std::vector<std::vector<double>>> cpp_weights = Python_Keras_get_weights(model_name);
|
std::vector<std::vector<std::vector<double>>> cpp_weights = Python_Keras_get_weights(model_name);
|
||||||
Eigen_model_mutex->lock();
|
Eigen_model_mutex->lock();
|
||||||
|
|||||||
@ -675,7 +675,7 @@ int main(int argc, char *argv[]) {
|
|||||||
/* Use dht species for model input and output */
|
/* Use dht species for model input and output */
|
||||||
R["ai_surrogate_species"] =
|
R["ai_surrogate_species"] =
|
||||||
init_list.getChemistryInit().dht_species.getNames();
|
init_list.getChemistryInit().dht_species.getNames();
|
||||||
|
R.parseEval("ai_surrogate_species <- ai_surrogate_species[ai_surrogate_species != \"Charge\"]");
|
||||||
const std::string ai_surrogate_input_script =
|
const std::string ai_surrogate_input_script =
|
||||||
init_list.getChemistryInit().ai_surrogate_input_script;
|
init_list.getChemistryInit().ai_surrogate_input_script;
|
||||||
|
|
||||||
@ -718,7 +718,7 @@ int main(int argc, char *argv[]) {
|
|||||||
MSG("K-Means clustering will be used for the AI surrogate")
|
MSG("K-Means clustering will be used for the AI surrogate")
|
||||||
}
|
}
|
||||||
if (!Rcpp::as<bool>(R.parseEval("exists(\"model_reactive_file_path\")"))) {
|
if (!Rcpp::as<bool>(R.parseEval("exists(\"model_reactive_file_path\")"))) {
|
||||||
R.parseEval("model_reactive_file_path <- model_reactive_file_path");
|
R.parseEval("model_reactive_file_path <- model_file_path");
|
||||||
}
|
}
|
||||||
MSG("AI: Initialize Python for the AI surrogate functions");
|
MSG("AI: Initialize Python for the AI surrogate functions");
|
||||||
std::string python_keras_file = std::string(SRC_DIR) +
|
std::string python_keras_file = std::string(SRC_DIR) +
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user