feat(control): dynamic prototype, penalty_iteration, error while disabling surrogate fixed

This commit is contained in:
rastogi 2025-10-02 13:20:53 +02:00 committed by Max Lübke
parent 41d1a9895c
commit 15e397ecf2
6 changed files with 171 additions and 114 deletions

View File

@ -28,12 +28,12 @@ if (POET_PREPROCESS_BENCHS)
endif()
# as tug will also pull in doctest as a dependency
set(TUG_ENABLE_TESTING ON CACHE BOOL "" FORCE)
set(TUG_ENABLE_TESTING OFF CACHE BOOL "" FORCE)
add_subdirectory(ext/tug EXCLUDE_FROM_ALL)
add_subdirectory(ext/iphreeqc EXCLUDE_FROM_ALL)
option(POET_ENABLE_TESTING "Build test suite for POET" ON)
option(POET_ENABLE_TESTING "Build test suite for POET" OFF)
if (POET_ENABLE_TESTING)
add_subdirectory(test)

View File

@ -299,6 +299,7 @@ namespace poet
CHEM_DHT_SIGNIF_VEC,
CHEM_DHT_SNAPS,
CHEM_DHT_READ_FILE,
CHEM_INTERP,
CHEM_IP_ENABLE,
CHEM_IP_MIN_ENTRIES,
CHEM_IP_SIGNIF_VEC,

View File

@ -448,6 +448,19 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
ftype = CHEM_WORK_LOOP;
PropagateFunctionType(ftype);
ftype = CHEM_INTERP;
PropagateFunctionType(ftype);
if(this->runtime_params->rollback_simulation){
this->interp_enabled = false;
int interp_flag = 0;
ChemBCast(&interp_flag, 1, MPI_INT);
} else {
this->interp_enabled = true;
int interp_flag = 1;
ChemBCast(&interp_flag, 1, MPI_INT);
}
MPI_Barrier(this->group_comm);
static uint32_t iteration = 0;

View File

@ -34,105 +34,112 @@ namespace poet
return ret_str;
}
// Worker-side dispatch loop: repeatedly obtains a function tag through
// PropagateFunctionType() and runs the matching handler until the
// CHEM_BREAK_MAIN_LOOP tag is handled.
// Throws std::runtime_error when a tag without a matching case is received.
// NOTE(review): this span is a rendered diff without +/- markers, so lines of
// the previous and the new revision are interleaved (the signature and several
// statements appear twice). Read it as two overlapping versions, not as one
// compilable body.
void poet::ChemistryModule::WorkerLoop()
{
struct worker_s timings;
// HACK: defining the worker iteration count here, which will increment after
// each CHEM_ITER_END message
uint32_t iteration = 1;
bool loop = true;
while (loop)
void poet::ChemistryModule::WorkerLoop()
{
int func_type;
PropagateFunctionType(func_type);
// Timing record handed by reference to the work-loop and perf handlers below.
struct worker_s timings;
switch (func_type)
// HACK: defining the worker iteration count here, which will increment after
// each CHEM_ITER_END message
uint32_t iteration = 1;
bool loop = true;
while (loop)
{
case CHEM_FIELD_INIT:
{
ChemBCast(&this->prop_count, 1, MPI_UINT32_T);
if (this->ai_surrogate_enabled)
int func_type;
PropagateFunctionType(func_type);
switch (func_type)
{
this->ai_surrogate_validity_vector.resize(
this->n_cells); // resize instead of reserve?
}
break;
}
case CHEM_AI_BCAST_VALIDITY:
{
// Receive the index vector of valid ai surrogate predictions
MPI_Bcast(&this->ai_surrogate_validity_vector.front(), this->n_cells,
MPI_INT, 0, this->group_comm);
break;
}
case CHEM_WORK_LOOP:
{
WorkerProcessPkgs(timings, iteration);
break;
}
case CHEM_PERF:
{
int type;
ChemBCast(&type, 1, MPI_INT);
if (type < WORKER_DHT_HITS)
// Field initialisation: receive prop_count and, when the AI surrogate is
// enabled, size the validity vector to n_cells.
case CHEM_FIELD_INIT:
{
WorkerPerfToMaster(type, timings);
ChemBCast(&this->prop_count, 1, MPI_UINT32_T);
if (this->ai_surrogate_enabled)
{
this->ai_surrogate_validity_vector.resize(
this->n_cells); // resize instead of reserve?
}
break;
}
WorkerMetricsToMaster(type);
break;
}
case CHEM_BREAK_MAIN_LOOP:
{
WorkerPostSim(iteration);
loop = false;
break;
}
default:
{
throw std::runtime_error("Worker received unknown tag from master.");
}
case CHEM_AI_BCAST_VALIDITY:
{
// Receive the index vector of valid ai surrogate predictions
MPI_Bcast(&this->ai_surrogate_validity_vector.front(), this->n_cells,
MPI_INT, 0, this->group_comm);
break;
}
// Interpolation toggle: an int flag is received and interp_enabled is set to
// true only when the flag equals 1.
case CHEM_INTERP:
{
int interp_flag;
ChemBCast(&interp_flag, 1, MPI_INT);
this->interp_enabled = (interp_flag == 1);
break;
}
case CHEM_WORK_LOOP:
{
WorkerProcessPkgs(timings, iteration);
break;
}
// Performance query: the received type selects between the timing report
// (type < WORKER_DHT_HITS) and the metrics report (all other values).
case CHEM_PERF:
{
int type;
ChemBCast(&type, 1, MPI_INT);
if (type < WORKER_DHT_HITS)
{
WorkerPerfToMaster(type, timings);
break;
}
WorkerMetricsToMaster(type);
break;
}
// Shutdown: run the post-simulation step and leave the dispatch loop.
case CHEM_BREAK_MAIN_LOOP:
{
WorkerPostSim(iteration);
loop = false;
break;
}
default:
{
throw std::runtime_error("Worker received unknown tag from master.");
}
}
}
}
}
// Processes incoming messages for one iteration: after a barrier on
// group_comm, probes for messages from rank 0 and dispatches on the message
// tag until a LOOP_END message has been handled.
// @param timings   timing record; idle_t is increased by the time spent
//                  waiting in MPI_Probe before a LOOP_WORK message.
// @param iteration iteration counter, incremented once when LOOP_END arrives.
// NOTE(review): this span is a rendered diff without +/- markers, so lines of
// the previous and the new revision are interleaved (the signature and several
// statements appear twice). Read it as two overlapping versions.
void poet::ChemistryModule::WorkerProcessPkgs(struct worker_s &timings,
uint32_t &iteration)
{
MPI_Status probe_status;
bool loop = true;
MPI_Barrier(this->group_comm);
while (loop)
void poet::ChemistryModule::WorkerProcessPkgs(struct worker_s &timings,
uint32_t &iteration)
{
double idle_a = MPI_Wtime();
MPI_Probe(0, MPI_ANY_TAG, this->group_comm, &probe_status);
double idle_b = MPI_Wtime();
MPI_Status probe_status;
bool loop = true;
switch (probe_status.MPI_TAG)
{
case LOOP_WORK:
{
timings.idle_t += idle_b - idle_a;
int count;
MPI_Get_count(&probe_status, MPI_DOUBLE, &count);
MPI_Barrier(this->group_comm);
WorkerDoWork(probe_status, count, timings);
break;
}
case LOOP_END:
while (loop)
{
WorkerPostIter(probe_status, iteration);
iteration++;
loop = false;
break;
}
// Measure how long the probe blocks; the difference is only added to
// timings.idle_t in the LOOP_WORK case.
double idle_a = MPI_Wtime();
MPI_Probe(0, MPI_ANY_TAG, this->group_comm, &probe_status);
double idle_b = MPI_Wtime();
// NOTE(review): there is no default case; a probed message with any other tag
// is not handled in this function, so the next MPI_Probe would presumably
// match the same message again — confirm only these two tags can arrive here.
switch (probe_status.MPI_TAG)
{
case LOOP_WORK:
{
timings.idle_t += idle_b - idle_a;
// Size of the probed message, counted in MPI_DOUBLE elements.
int count;
MPI_Get_count(&probe_status, MPI_DOUBLE, &count);
WorkerDoWork(probe_status, count, timings);
break;
}
case LOOP_END:
{
WorkerPostIter(probe_status, iteration);
iteration++;
loop = false;
break;
}
}
}
}
}
void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
int double_count,
@ -254,7 +261,7 @@ namespace poet
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++)
{
if (!s_curr_wp.mapping[wp_i] == CHEM_PQC) // only copy if surrogate was used
if (s_curr_wp.mapping[wp_i] != CHEM_PQC) // only copy if surrogate was used
{
std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(),
mpi_buffer.begin() + sur_wp_offset + this->prop_count * wp_i);

View File

@ -270,6 +270,10 @@ int parseInitValues(int argc, char **argv, RuntimeParameters &params)
Rcpp::as<uint32_t>(global_rt_setup->operator[]("control_iteration"));
params.species_epsilon =
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("species_epsilon"));
params.penalty_iteration =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("penalty_iteration"));
params.max_penalty_iteration =
Rcpp::as<uint32_t>(global_rt_setup->operator[]("max_penalty_iteration"));
}
catch (const std::exception &e)
{
@ -302,31 +306,50 @@ void call_master_iter_end(RInside &R, const Field &trans, const Field &chem)
// Compares per-species MAPE values from chem.error_stats_history against the
// thresholds in params.species_epsilon. For the first species whose MAPE is
// non-zero and above its threshold, a message is printed, the checkpoint file
// "checkpoint<rollback_iter>.hdf5" is read into chem's field, and iter is
// overwritten with the iteration stored in that checkpoint.
// @param chem   chemistry module providing the error history and the field.
// @param params runtime parameters (species_epsilon, control_iteration).
// @param iter   current iteration; modified when a checkpoint is read.
// @return true when a checkpoint was read, false otherwise.
// NOTE(review): this span is a rendered diff without +/- markers, so lines of
// the previous and the new revision are interleaved (duplicated loops,
// conditions and statements). Read it as two overlapping versions.
bool checkAndRollback(ChemistryModule &chem, RuntimeParameters &params, uint32_t &iter)
{
for (uint32_t i = 0; i < chem.error_stats_history.size(); i++)
// NOTE(review): back() on an empty vector is undefined behaviour — confirm
// error_stats_history always holds at least one entry when this is called.
const std::vector<double> &latest_mape = chem.error_stats_history.back().mape;
// Assumes mape has at least species_epsilon.size() entries — TODO confirm.
for (uint32_t j = 0; j < params.species_epsilon.size(); j++)
{
if (iter == chem.error_stats_history[i].iteration)
// A MAPE of exactly 0 never triggers a rollback.
if (params.species_epsilon[j] < latest_mape[j] && latest_mape[j] != 0)
{
for (uint32_t j = 0; j < params.species_epsilon.size(); j++)
{
if (params.species_epsilon[j] < chem.error_stats_history[i].mape[j] && chem.error_stats_history[i].mape[j] != 0 && chem.control_iteration_counter > 1)
{
// Two variants of the rollback target are shown: one goes back a full
// control_iteration interval, the other rounds iter down to a multiple of
// control_iteration (which equals iter itself when iter is such a multiple).
uint32_t rollback_iter = iter - params.control_iteration;
uint32_t rollback_iter = iter - (iter % params.control_iteration);
std::cout << chem.getField().GetProps()[j] << " with a MAPE value of " << chem.error_stats_history[i].mape[j] << " exceeds epsilon of "
<< params.species_epsilon[j] << "! " << std::endl;
std::cout << chem.getField().GetProps()[j] << " with a MAPE value of " << latest_mape[j] << " exceeds epsilon of "
<< params.species_epsilon[j] << "! " << std::endl;
Checkpoint_s checkpoint_read{.field = chem.getField()};
read_checkpoint("checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read);
iter = checkpoint_read.iteration;
// Read the checkpoint and continue from the iteration stored in it.
Checkpoint_s checkpoint_read{.field = chem.getField()};
read_checkpoint("checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read);
iter = checkpoint_read.iteration;
chem.control_iteration_counter--;
return true;
}
}
return true;
}
}
// Reached when no species triggered a rollback.
MSG("All spezies are below their threshold values");
return false;
}
// Updates the penalty-phase bookkeeping in params after a rollback check.
// @param params           runtime parameters, modified in place
//                         (rollback_simulation, penalty_counter,
//                         penalty_iteration).
// @param roolback_happend true when the preceding check performed a rollback.
void updatePenaltyLogic(RuntimeParameters &params, bool roolback_happend)
{
if (roolback_happend)
{
// Enter (or restart) the penalty phase: the counter is reloaded from the
// current penalty_iteration value.
params.rollback_simulation = true;
params.penalty_counter = params.penalty_iteration;
std::cout << "Penalty counter reset to: " << params.penalty_counter << std::endl;
MSG("Rollback! Penalty phase started for " + std::to_string(params.penalty_iteration) + " iterations.");
}
else
{
// No rollback. Three cases: the penalty phase has counted down to 0 (leave
// it), no penalty phase is active (lengthen the interval), or the penalty
// phase is still counting down (nothing is changed).
if (params.rollback_simulation && params.penalty_counter == 0)
{
params.rollback_simulation = false;
MSG("Penalty phase ended. Interpolation re-enabled.");
}
else if (!params.rollback_simulation)
{
// Double penalty_iteration, capped at max_penalty_iteration.
// NOTE(review): the `*=` inside std::min already modifies penalty_iteration
// before the capped value is assigned back; the doubling is unsigned
// arithmetic and wraps on overflow; and a max_penalty_iteration of 0 clamps
// penalty_iteration to 0 — confirm the configuration always provides a
// non-zero maximum.
params.penalty_iteration = std::min(params.penalty_iteration *= 2, params.max_penalty_iteration);
MSG("Stable surrogate phase detected. Penalty iteration doubled to " + std::to_string(params.penalty_iteration) + " iterations.");
}
}
// NOTE(review): this function is declared void, so returning a value is not
// valid here. In this diff rendering the line is presumably left over from
// the end of the previous function — confirm against the actual file.
return false;
}
static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
@ -344,13 +367,21 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
}
R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps());
params.next_penalty_check = params.penalty_iteration;
/* SIMULATION LOOP */
double dSimTime{0};
for (uint32_t iter = 1; iter < maxiter + 1; iter++)
{
// Penalty countdown
if (params.rollback_simulation && params.penalty_counter > 0)
{
params.penalty_counter--;
std::cout << "Penalty counter: " << params.penalty_counter << std::endl;
}
params.control_iteration_active = (iter % params.control_iteration == 0 && iter != 0);
params.control_iteration_active = (iter % params.control_iteration == 0 /* && iter != 0 */);
double start_t = MPI_Wtime();
@ -459,12 +490,6 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
// TODO: write checkpoint
// checkpoint struct --> field and iteration
/*else if (iter == 2) {
Checkpoint_s checkpoint_read{.field = chem.getField()};
read_checkpoint("checkpoint1.hdf5", checkpoint_read);
iter = checkpoint_read.iteration;
}*/
diffusion.getField().update(chem.getField());
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" +
@ -473,12 +498,18 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters &params,
if (iter % params.control_iteration == 0)
{
writeStatsToCSV(chem.error_stats_history, chem.getField().GetProps(), "stats_overview");
write_checkpoint("checkpoint" + std::to_string(iter) + ".hdf5",
{.field = chem.getField(), .iteration = iter});
checkAndRollback(chem, params, iter);
}
if (iter == params.next_penalty_check)
{
bool roolback_happend = checkAndRollback(chem, params, iter);
updatePenaltyLogic(params, roolback_happend);
params.next_penalty_check = iter + params.penalty_iteration;
}
// MSG();
} // END SIMULATION LOOP

View File

@ -52,6 +52,11 @@ struct RuntimeParameters {
bool print_progress = false;
std::uint32_t penalty_iteration = 0;
std::uint32_t max_penalty_iteration = 0;
std::uint32_t penalty_counter = 0;
std::uint32_t next_penalty_check = 0;
bool rollback_simulation = false;
bool control_iteration_active = false;
std::uint32_t control_iteration = 1;