feat(control): dynamic prototype, penalty_iteration, error while disabling surrogate fixed

This commit is contained in:
rastogi 2025-10-02 13:20:53 +02:00 committed by Max Lübke
parent 41d1a9895c
commit 15e397ecf2
6 changed files with 171 additions and 114 deletions

View File

@ -28,12 +28,12 @@ if (POET_PREPROCESS_BENCHS)
endif() endif()
# as tug will also pull in doctest as a dependency # as tug will also pull in doctest as a dependency
set(TUG_ENABLE_TESTING ON CACHE BOOL "" FORCE) set(TUG_ENABLE_TESTING OFF CACHE BOOL "" FORCE)
add_subdirectory(ext/tug EXCLUDE_FROM_ALL) add_subdirectory(ext/tug EXCLUDE_FROM_ALL)
add_subdirectory(ext/iphreeqc EXCLUDE_FROM_ALL) add_subdirectory(ext/iphreeqc EXCLUDE_FROM_ALL)
option(POET_ENABLE_TESTING "Build test suite for POET" ON) option(POET_ENABLE_TESTING "Build test suite for POET" OFF)
if (POET_ENABLE_TESTING) if (POET_ENABLE_TESTING)
add_subdirectory(test) add_subdirectory(test)

View File

@ -299,6 +299,7 @@ namespace poet
CHEM_DHT_SIGNIF_VEC, CHEM_DHT_SIGNIF_VEC,
CHEM_DHT_SNAPS, CHEM_DHT_SNAPS,
CHEM_DHT_READ_FILE, CHEM_DHT_READ_FILE,
CHEM_INTERP,
CHEM_IP_ENABLE, CHEM_IP_ENABLE,
CHEM_IP_MIN_ENTRIES, CHEM_IP_MIN_ENTRIES,
CHEM_IP_SIGNIF_VEC, CHEM_IP_SIGNIF_VEC,

View File

@ -448,6 +448,19 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
ftype = CHEM_WORK_LOOP; ftype = CHEM_WORK_LOOP;
PropagateFunctionType(ftype); PropagateFunctionType(ftype);
ftype = CHEM_INTERP;
PropagateFunctionType(ftype);
if(this->runtime_params->rollback_simulation){
this->interp_enabled = false;
int interp_flag = 0;
ChemBCast(&interp_flag, 1, MPI_INT);
} else {
this->interp_enabled = true;
int interp_flag = 1;
ChemBCast(&interp_flag, 1, MPI_INT);
}
MPI_Barrier(this->group_comm); MPI_Barrier(this->group_comm);
static uint32_t iteration = 0; static uint32_t iteration = 0;

View File

@ -34,105 +34,112 @@ namespace poet
return ret_str; return ret_str;
} }
void poet::ChemistryModule::WorkerLoop() void poet::ChemistryModule::WorkerLoop()
{
struct worker_s timings;
// HACK: defining the worker iteration count here, which will increment after
// each CHEM_ITER_END message
uint32_t iteration = 1;
bool loop = true;
while (loop)
{ {
int func_type; struct worker_s timings;
PropagateFunctionType(func_type);
switch (func_type) // HACK: defining the worker iteration count here, which will increment after
// each CHEM_ITER_END message
uint32_t iteration = 1;
bool loop = true;
while (loop)
{ {
case CHEM_FIELD_INIT: int func_type;
{ PropagateFunctionType(func_type);
ChemBCast(&this->prop_count, 1, MPI_UINT32_T);
if (this->ai_surrogate_enabled) switch (func_type)
{ {
this->ai_surrogate_validity_vector.resize( case CHEM_FIELD_INIT:
this->n_cells); // resize statt reserve?
}
break;
}
case CHEM_AI_BCAST_VALIDITY:
{
// Receive the index vector of valid ai surrogate predictions
MPI_Bcast(&this->ai_surrogate_validity_vector.front(), this->n_cells,
MPI_INT, 0, this->group_comm);
break;
}
case CHEM_WORK_LOOP:
{
WorkerProcessPkgs(timings, iteration);
break;
}
case CHEM_PERF:
{
int type;
ChemBCast(&type, 1, MPI_INT);
if (type < WORKER_DHT_HITS)
{ {
WorkerPerfToMaster(type, timings); ChemBCast(&this->prop_count, 1, MPI_UINT32_T);
if (this->ai_surrogate_enabled)
{
this->ai_surrogate_validity_vector.resize(
this->n_cells); // resize statt reserve?
}
break; break;
} }
WorkerMetricsToMaster(type); case CHEM_AI_BCAST_VALIDITY:
break; {
} // Receive the index vector of valid ai surrogate predictions
case CHEM_BREAK_MAIN_LOOP: MPI_Bcast(&this->ai_surrogate_validity_vector.front(), this->n_cells,
{ MPI_INT, 0, this->group_comm);
WorkerPostSim(iteration); break;
loop = false; }
break; case CHEM_INTERP:
} {
default: int interp_flag;
{ ChemBCast(&interp_flag, 1, MPI_INT);
throw std::runtime_error("Worker received unknown tag from master."); this->interp_enabled = (interp_flag == 1);
} break;
}
case CHEM_WORK_LOOP:
{
WorkerProcessPkgs(timings, iteration);
break;
}
case CHEM_PERF:
{
int type;
ChemBCast(&type, 1, MPI_INT);
if (type < WORKER_DHT_HITS)
{
WorkerPerfToMaster(type, timings);
break;
}
WorkerMetricsToMaster(type);
break;
}
case CHEM_BREAK_MAIN_LOOP:
{
WorkerPostSim(iteration);
loop = false;
break;
}
default:
{
throw std::runtime_error("Worker received unknown tag from master.");
}
}
} }
} }
}
/**
 * Processes messages from rank 0 for one iteration.
 *
 * After a barrier on the group communicator, the worker probes for messages
 * from rank 0 and dispatches on the message tag: LOOP_WORK hands the message
 * to WorkerDoWork(), LOOP_END calls WorkerPostIter(), advances the iteration
 * counter and leaves the loop.
 *
 * @param timings   Timing record; the time spent waiting in MPI_Probe before
 *                  a LOOP_WORK message is added to `idle_t`.
 * @param iteration Iteration counter, incremented once when LOOP_END arrives.
 * @throws std::runtime_error if a message with an unhandled tag is probed.
 */
void poet::ChemistryModule::WorkerProcessPkgs(struct worker_s &timings,
                                              uint32_t &iteration)
{
  MPI_Status probe_status;
  bool loop = true;

  MPI_Barrier(this->group_comm);

  while (loop)
  {
    double idle_a = MPI_Wtime();
    MPI_Probe(0, MPI_ANY_TAG, this->group_comm, &probe_status);
    double idle_b = MPI_Wtime();

    switch (probe_status.MPI_TAG)
    {
    case LOOP_WORK:
    {
      timings.idle_t += idle_b - idle_a;
      int count;
      MPI_Get_count(&probe_status, MPI_DOUBLE, &count);

      WorkerDoWork(probe_status, count, timings);
      break;
    }
    case LOOP_END:
    {
      WorkerPostIter(probe_status, iteration);
      iteration++;
      loop = false;
      break;
    }
    default:
    {
      // MPI_Probe does not consume the message. Without this branch an
      // unexpected tag would be probed again on every pass and the worker
      // would spin forever.
      throw std::runtime_error(
          "Worker received unknown work package tag from master.");
    }
    }
  }
}
void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status, void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
int double_count, int double_count,
@ -254,7 +261,7 @@ namespace poet
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++)
{ {
if (!s_curr_wp.mapping[wp_i] == CHEM_PQC) // only copy if surrogate was used if (s_curr_wp.mapping[wp_i] != CHEM_PQC) // only copy if surrogate was used
{ {
std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(), std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(),
mpi_buffer.begin() + sur_wp_offset + this->prop_count * wp_i); mpi_buffer.begin() + sur_wp_offset + this->prop_count * wp_i);

View File

@ -270,6 +270,10 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params)
Rcpp::as<uint32_t>(global_rt_setup->operator[]("control_iteration")); Rcpp::as<uint32_t>(global_rt_setup->operator[]("control_iteration"));
params.species_epsilon = params.species_epsilon =
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("species_epsilon")); Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("species_epsilon"));
params.penalty_iteration =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("penalty_iteration"));
params.max_penalty_iteration =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("max_penalty_iteration"));
} }
catch (const std::exception &e) catch (const std::exception &e)
{ {
@ -302,31 +306,50 @@ void call_master_iter_end(RInside &R, const Field &trans, const Field &chem)
/**
 * Compares the most recent MAPE values against the per-species thresholds and
 * restores a checkpoint when one of them is exceeded.
 *
 * For the first species whose latest MAPE is non-zero and larger than its
 * entry in `params.species_epsilon`, the checkpoint file
 * "checkpoint<k>.hdf5" is read, where k is `iter` rounded down to a multiple
 * of `params.control_iteration`, and `iter` is overwritten with the iteration
 * stored in that checkpoint.
 *
 * @param chem   Chemistry module providing the error history and the field.
 * @param params Runtime parameters (thresholds and control interval).
 * @param iter   Current iteration; replaced by the checkpoint's iteration on
 *               rollback.
 * @return true if a rollback was performed, false otherwise (including the
 *         case that no error statistics have been recorded yet).
 */
bool checkAndRollback(ChemistryModule &chem, RuntimeParameters &params, uint32_t &iter)
{
  // back() on an empty history is undefined behaviour; without statistics
  // there is nothing to compare against.
  if (chem.error_stats_history.empty())
  {
    return false;
  }

  const std::vector<double> &latest_mape = chem.error_stats_history.back().mape;

  // Only compare entries that exist in both vectors to avoid reading past the
  // end of the MAPE vector.
  const std::size_t n_species =
      std::min(params.species_epsilon.size(), latest_mape.size());

  for (std::size_t j = 0; j < n_species; j++)
  {
    if (params.species_epsilon[j] < latest_mape[j] && latest_mape[j] != 0)
    {
      // NOTE(review): for iter < control_iteration this evaluates to 0;
      // confirm that a "checkpoint0.hdf5" exists for that case.
      uint32_t rollback_iter = iter - (iter % params.control_iteration);

      std::cout << chem.getField().GetProps()[j] << " with a MAPE value of " << latest_mape[j] << " exceeds epsilon of "
                << params.species_epsilon[j] << "! " << std::endl;

      Checkpoint_s checkpoint_read{.field = chem.getField()};
      read_checkpoint("checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read);
      iter = checkpoint_read.iteration;
      return true;
    }
  }
  MSG("All spezies are below their threshold values");
  return false;
}
/**
 * Updates the penalty state after a rollback check.
 *
 * - Rollback happened: the penalty phase is (re)started and the counter is
 *   set to `penalty_iteration`.
 * - No rollback while a penalty phase is active: the phase ends once the
 *   counter has reached zero, otherwise nothing changes.
 * - No rollback and no active penalty phase: `penalty_iteration` is doubled,
 *   capped at `max_penalty_iteration`.
 *
 * @param params           Runtime parameters holding the penalty state.
 * @param roolback_happend Result of the preceding rollback check.
 */
void updatePenaltyLogic(RuntimeParameters &params, bool roolback_happend)
{
  if (roolback_happend)
  {
    params.rollback_simulation = true;
    params.penalty_counter = params.penalty_iteration;
    std::cout << "Penalty counter reset to: " << params.penalty_counter << std::endl;
    MSG("Rollback! Penalty phase started for " + std::to_string(params.penalty_iteration) + " iterations.");
    return;
  }

  if (params.rollback_simulation)
  {
    // Penalty phase still running unless the countdown has finished.
    if (params.penalty_counter == 0)
    {
      params.rollback_simulation = false;
      MSG("Penalty phase ended. Interpolation re-enabled.");
    }
    return;
  }

  // Double in 64 bit so the multiplication cannot wrap around before the
  // result is clamped to the configured maximum.
  const std::uint64_t doubled =
      static_cast<std::uint64_t>(params.penalty_iteration) * 2;
  const std::uint64_t capped =
      std::min<std::uint64_t>(doubled, params.max_penalty_iteration);
  params.penalty_iteration = static_cast<std::uint32_t>(capped);
  MSG("Stable surrogate phase detected. Penalty iteration doubled to " + std::to_string(params.penalty_iteration) + " iterations.");
}
static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params, static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
@ -344,13 +367,21 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
} }
R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps()); R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps());
params.next_penalty_check = params.penalty_iteration;
/* SIMULATION LOOP */ /* SIMULATION LOOP */
double dSimTime{0}; double dSimTime{0};
for (uint32_t iter = 1; iter < maxiter + 1; iter++) for (uint32_t iter = 1; iter < maxiter + 1; iter++)
{ {
// Penalty countdown
if (params.rollback_simulation && params.penalty_counter > 0)
{
params.penalty_counter--;
std::cout << "Penalty counter: " << params.penalty_counter << std::endl;
}
params.control_iteration_active = (iter % params.control_iteration == 0 && iter != 0); params.control_iteration_active = (iter % params.control_iteration == 0 /* && iter != 0 */);
double start_t = MPI_Wtime(); double start_t = MPI_Wtime();
@ -459,12 +490,6 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
// TODO: write checkpoint // TODO: write checkpoint
// checkpoint struct --> field and iteration // checkpoint struct --> field and iteration
/*else if (iter == 2) {
Checkpoint_s checkpoint_read{.field = chem.getField()};
read_checkpoint("checkpoint1.hdf5", checkpoint_read);
iter = checkpoint_read.iteration;
}*/
diffusion.getField().update(chem.getField()); diffusion.getField().update(chem.getField());
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" + MSG("End of *coupling* iteration " + std::to_string(iter) + "/" +
@ -473,12 +498,18 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
if (iter % params.control_iteration == 0) if (iter % params.control_iteration == 0)
{ {
writeStatsToCSV(chem.error_stats_history, chem.getField().GetProps(), "stats_overview"); writeStatsToCSV(chem.error_stats_history, chem.getField().GetProps(), "stats_overview");
write_checkpoint("checkpoint" + std::to_string(iter) + ".hdf5", write_checkpoint("checkpoint" + std::to_string(iter) + ".hdf5",
{.field = chem.getField(), .iteration = iter}); {.field = chem.getField(), .iteration = iter});
checkAndRollback(chem, params, iter);
} }
if (iter == params.next_penalty_check)
{
bool roolback_happend = checkAndRollback(chem, params, iter);
updatePenaltyLogic(params, roolback_happend);
params.next_penalty_check = iter + params.penalty_iteration;
}
// MSG(); // MSG();
} // END SIMULATION LOOP } // END SIMULATION LOOP

View File

@ -52,6 +52,11 @@ struct RuntimeParameters {
bool print_progress = false; bool print_progress = false;
std::uint32_t penalty_iteration = 0;
std::uint32_t max_penalty_iteration = 0;
std::uint32_t penalty_counter = 0;
std::uint32_t next_penalty_check = 0;
bool rollback_simulation = false;
bool control_iteration_active = false; bool control_iteration_active = false;
std::uint32_t control_iteration = 1; std::uint32_t control_iteration = 1;