mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-15 20:38:23 +01:00
fix: set training wait predicate with buffer threshold check
This commit is contained in:
parent
997ae32092
commit
4d254250e1
@ -5,6 +5,7 @@ import os
|
||||
def initiate_model(model_file_path, cuda_dir):
    """Set the XLA-related environment variables and load a Keras model.

    Args:
        model_file_path: Path of the saved model. It is concatenated onto the
            log message and handed to ``tf.keras.models.load_model``.
        cuda_dir: String appended to ``--xla_gpu_cuda_data_dir=`` to build the
            value stored in the ``XLA_FLAGS`` environment variable.

    Returns:
        Whatever ``tf.keras.models.load_model`` returns for the given path.
    """
    env = os.environ

    # Both variables are written before the model is loaded.
    env["TF_XLA_FLAGS"] = "--tf_xla_cpu_global_jit"
    xla_flags = "--xla_gpu_cuda_data_dir=" + cuda_dir
    env["XLA_FLAGS"] = xla_flags

    # The message is emitted (and flushed) before the load call below runs.
    message = "AI: Model loaded from: " + model_file_path
    print(message, flush=True)

    return tf.keras.models.load_model(model_file_path)
|
||||
|
||||
|
||||
@ -310,16 +310,11 @@ void parallel_training(EigenModel* Eigen_model,
|
||||
// wait for a signal on training_data_buffer_full but starts the next round immediately.
|
||||
std::unique_lock<std::mutex> lock(*training_data_buffer_mutex);
|
||||
training_data_buffer_full->wait(lock, [start_training] { return *start_training;});
|
||||
|
||||
// TODO: lock training_data_buffer_mutex again here / test the lock?
|
||||
|
||||
// Return if program is about to end
|
||||
if (*end_training) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Reset the waiting predicate
|
||||
*start_training = false;
|
||||
// Get the necessary training data
|
||||
std::cout << "AI: Training thread: Getting training data" << std::endl;
|
||||
// Initialize training data input and targets
|
||||
@ -342,6 +337,8 @@ void parallel_training(EigenModel* Eigen_model,
|
||||
training_data_buffer->y[col].erase(training_data_buffer->y[col].begin(),
|
||||
training_data_buffer->y[col].begin() + params.training_data_size);
|
||||
}
|
||||
// Keep the waiting predicate true only while y[0] still holds at least params.training_data_size entries; otherwise set it to false
|
||||
*start_training = training_data_buffer->y[0].size() >= params.training_data_size;
|
||||
//update number of training runs
|
||||
training_data_buffer->n_training_runs += 1;
|
||||
// Unlock the training_data_buffer_mutex
|
||||
@ -366,7 +363,7 @@ void parallel_training(EigenModel* Eigen_model,
|
||||
// Release the Python GIL
|
||||
PyGILState_Release(gstate);
|
||||
std::cout << "AI: Training thread: Finished training, waiting for new data" << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Default-constructed std::thread, so it does not yet represent a running thread.
// NOTE(review): presumably assigned the Python training thread elsewhere — confirm at the point of use.
std::thread python_train_thread;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user