fix: set training wait predicate with buffer threshold check

straile 2024-10-20 12:38:58 +02:00
parent 997ae32092
commit 4d254250e1
2 changed files with 4 additions and 6 deletions


@@ -5,6 +5,7 @@ import os
def initiate_model(model_file_path, cuda_dir):
    os.environ["TF_XLA_FLAGS"] = "--tf_xla_cpu_global_jit"
    os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=" + cuda_dir
    print("AI: Loading model from: " + model_file_path, flush=True)
    model = tf.keras.models.load_model(model_file_path)
    return model


@@ -310,16 +310,11 @@ void parallel_training(EigenModel* Eigen_model,
// Wait for a signal on training_data_buffer_full; if the predicate is already
// true, the next round starts immediately without blocking.
std::unique_lock<std::mutex> lock(*training_data_buffer_mutex);
training_data_buffer_full->wait(lock, [start_training] { return *start_training; });
// TODO: test locking training_data_buffer_mutex again here?
// Return if the program is about to end
if (*end_training) {
    return;
}
// Reset the waiting predicate
*start_training = false;
// Get the necessary training data
std::cout << "AI: Training thread: Getting training data" << std::endl;
// Initialize training data input and targets
@@ -342,6 +337,8 @@ void parallel_training(EigenModel* Eigen_model,
training_data_buffer->y[col].erase(training_data_buffer->y[col].begin(),
                                   training_data_buffer->y[col].begin() + params.training_data_size);
}
// Keep the waiting predicate set if the buffer still holds enough data for
// another run; otherwise reset it so the thread waits for new data
*start_training = training_data_buffer->y[0].size() >= params.training_data_size;
// Update the number of training runs
training_data_buffer->n_training_runs += 1;
// Unlock the training_data_buffer_mutex
@@ -366,7 +363,7 @@ void parallel_training(EigenModel* Eigen_model,
// Release the Python GIL
PyGILState_Release(gstate);
std::cout << "AI: Training thread: Finished training, waiting for new data" << std::endl;
}
}
}
std::thread python_train_thread;
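
The substance of the fix is in how the condition-variable predicate is managed: rather than unconditionally resetting *start_training after each round, the trainer now keeps it set while the buffer still holds at least params.training_data_size samples, so wait() returns immediately and the next round starts without a fresh notification. Below is a minimal, self-contained sketch of that pattern; Buffer, kBatchSize, and trainer are illustrative stand-ins, not this repository's actual types.

#include <chrono>
#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>
#include <vector>

// Minimal stand-ins for the real training buffer and parameters.
struct Buffer {
    std::vector<double> y;            // pending training samples
    bool start_training = false;      // the wait predicate
    bool end_training = false;        // shutdown flag
};

constexpr std::size_t kBatchSize = 4; // stand-in for params.training_data_size

std::mutex buf_mutex;
std::condition_variable buffer_full;
Buffer buf;

void trainer() {
    while (true) {
        std::unique_lock<std::mutex> lock(buf_mutex);
        // Wakes immediately if the predicate is already true, e.g. because the
        // previous round left at least one full batch in the buffer.
        buffer_full.wait(lock, [] { return buf.start_training || buf.end_training; });
        if (buf.end_training) {
            return;
        }
        // Consume one batch.
        buf.y.erase(buf.y.begin(), buf.y.begin() + kBatchSize);
        // The fix: keep the predicate set while another full batch remains, so
        // the next iteration does not block waiting for the producer.
        buf.start_training = buf.y.size() >= kBatchSize;
        bool more = buf.start_training;
        lock.unlock();
        std::printf("trained one batch, %s\n",
                    more ? "next batch already buffered" : "waiting for new data");
    }
}

int main() {
    std::thread t(trainer);
    {
        std::lock_guard<std::mutex> lock(buf_mutex);
        buf.y.assign(3 * kBatchSize, 1.0); // enough data for three rounds
        buf.start_training = true;
    }
    buffer_full.notify_one(); // a single notify triggers three training rounds
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    {
        std::lock_guard<std::mutex> lock(buf_mutex);
        buf.end_training = true;
    }
    buffer_full.notify_one();
    t.join();
    return 0;
}

Note that the sketch folds end_training into the wait predicate so a shutdown notification cannot be missed while the thread is blocked; the committed code instead checks *end_training after waking.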