mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-16 12:54:50 +01:00
run retraining only once
This commit is contained in:
parent
d825f33b4f
commit
c7d1fc152c
20
src/poet.cpp
20
src/poet.cpp
@ -310,6 +310,8 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
|||||||
R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps());
|
R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps());
|
||||||
|
|
||||||
std::unique_ptr<AIContext> ai_ctx = nullptr;
|
std::unique_ptr<AIContext> ai_ctx = nullptr;
|
||||||
|
size_t retrain_counter = 0;
|
||||||
|
size_t field_size = 0;
|
||||||
|
|
||||||
if (params.ai) {
|
if (params.ai) {
|
||||||
|
|
||||||
@ -329,6 +331,9 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
|||||||
std::vector<float> mean = R["mean"];
|
std::vector<float> mean = R["mean"];
|
||||||
std::vector<float> scale = R["scale"];
|
std::vector<float> scale = R["scale"];
|
||||||
|
|
||||||
|
field_size = chem.getField().GetRequestedVecSize();
|
||||||
|
std::cout << field_size << std::endl;
|
||||||
|
|
||||||
ai_ctx->scaler.set_scaler(mean, scale);
|
ai_ctx->scaler.set_scaler(mean, scale);
|
||||||
|
|
||||||
// initialzie training backens only if retraining is desired
|
// initialzie training backens only if retraining is desired
|
||||||
@ -338,7 +343,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
|||||||
} else if (params.ai_backend == NAA_BACKEND) {
|
} else if (params.ai_backend == NAA_BACKEND) {
|
||||||
MSG("AI Surrogate with NAA backend enabled.")
|
MSG("AI Surrogate with NAA backend enabled.")
|
||||||
ai_ctx->training_backend =
|
ai_ctx->training_backend =
|
||||||
std::make_unique<NAABackend<ai_type_t>>(4 * params.batch_size);
|
std::make_unique<NAABackend<ai_type_t>>(4 * field_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!params.disable_retraining) {
|
if (!params.disable_retraining) {
|
||||||
@ -535,15 +540,15 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
|||||||
|
|
||||||
MSG("AI: add invalid data to buffer");
|
MSG("AI: add invalid data to buffer");
|
||||||
|
|
||||||
ai_ctx->data_semaphore_write.acquire();
|
|
||||||
|
|
||||||
std::cout << "size of predictors " << predictors_retraining[0].size()
|
std::cout << "size of predictors " << predictors_retraining[0].size()
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
std::cout << "size of targets " << targets_retraining[0].size()
|
std::cout << "size of targets " << targets_retraining[0].size()
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
|
|
||||||
|
ai_ctx->data_semaphore_write.acquire();
|
||||||
|
|
||||||
if (predictors_retraining[0].size() > 0 &&
|
if (predictors_retraining[0].size() > 0 &&
|
||||||
targets_retraining[0].size() > 0) {
|
targets_retraining[0].size() > 0 && retrain_counter == 0) {
|
||||||
ai_ctx->design_buffer.addData(predictors_retraining);
|
ai_ctx->design_buffer.addData(predictors_retraining);
|
||||||
ai_ctx->results_buffer.addData(targets_retraining);
|
ai_ctx->results_buffer.addData(targets_retraining);
|
||||||
}
|
}
|
||||||
@ -559,10 +564,11 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms,
|
|||||||
std::cout << "results_buffer_size: " << elements_results_buffer
|
std::cout << "results_buffer_size: " << elements_results_buffer
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
|
|
||||||
if (elements_design_buffer >= 4 * params.batch_size &&
|
if (elements_design_buffer >= 4 * field_size &&
|
||||||
elements_results_buffer >= 4 * params.batch_size &&
|
elements_results_buffer >= 4 * field_size &&
|
||||||
ai_ctx->training_is_running == false) {
|
ai_ctx->training_is_running == false && retrain_counter == 0) {
|
||||||
ai_ctx->data_semaphore_read.release();
|
ai_ctx->data_semaphore_read.release();
|
||||||
|
retrain_counter++;
|
||||||
} else {
|
} else {
|
||||||
ai_ctx->data_semaphore_write.release();
|
ai_ctx->data_semaphore_write.release();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -23,6 +23,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
|
#include <cstddef>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
@ -86,7 +87,7 @@ struct RuntimeParameters {
|
|||||||
static constexpr std::uint8_t AI_BACKEND_DEFAULT = 1;
|
static constexpr std::uint8_t AI_BACKEND_DEFAULT = 1;
|
||||||
std::uint8_t ai_backend = AI_BACKEND_DEFAULT; // 1 - python, 2 - naa
|
std::uint8_t ai_backend = AI_BACKEND_DEFAULT; // 1 - python, 2 - naa
|
||||||
bool train_only_invalid = true;
|
bool train_only_invalid = true;
|
||||||
int batch_size = 200 * 200;
|
int batch_size = 2500;
|
||||||
static constexpr std::uint8_t DEFAULT_FUNCTION_CODE = 0;
|
static constexpr std::uint8_t DEFAULT_FUNCTION_CODE = 0;
|
||||||
std::uint8_t function_code = DEFAULT_FUNCTION_CODE;
|
std::uint8_t function_code = DEFAULT_FUNCTION_CODE;
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user