mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-15 12:28:22 +01:00
feat(control): dynamic prototype, penalty_iteration, error while disabling surrogate fixed
This commit is contained in:
parent
c10d35fabe
commit
7467bbe50a
@ -28,12 +28,12 @@ if (POET_PREPROCESS_BENCHS)
|
||||
endif()
|
||||
|
||||
# as tug will also pull in doctest as a dependency
|
||||
set(TUG_ENABLE_TESTING ON CACHE BOOL "" FORCE)
|
||||
set(TUG_ENABLE_TESTING OFF CACHE BOOL "" FORCE)
|
||||
|
||||
add_subdirectory(ext/tug EXCLUDE_FROM_ALL)
|
||||
add_subdirectory(ext/iphreeqc EXCLUDE_FROM_ALL)
|
||||
|
||||
option(POET_ENABLE_TESTING "Build test suite for POET" ON)
|
||||
option(POET_ENABLE_TESTING "Build test suite for POET" OFF)
|
||||
|
||||
if (POET_ENABLE_TESTING)
|
||||
add_subdirectory(test)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@ -299,6 +299,7 @@ namespace poet
|
||||
CHEM_DHT_SIGNIF_VEC,
|
||||
CHEM_DHT_SNAPS,
|
||||
CHEM_DHT_READ_FILE,
|
||||
CHEM_INTERP,
|
||||
CHEM_IP_ENABLE,
|
||||
CHEM_IP_MIN_ENTRIES,
|
||||
CHEM_IP_SIGNIF_VEC,
|
||||
|
||||
@ -448,6 +448,19 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
||||
ftype = CHEM_WORK_LOOP;
|
||||
PropagateFunctionType(ftype);
|
||||
|
||||
ftype = CHEM_INTERP;
|
||||
PropagateFunctionType(ftype);
|
||||
|
||||
if(this->runtime_params->rollback_simulation){
|
||||
this->interp_enabled = false;
|
||||
int interp_flag = 0;
|
||||
ChemBCast(&interp_flag, 1, MPI_INT);
|
||||
} else {
|
||||
this->interp_enabled = true;
|
||||
int interp_flag = 1;
|
||||
ChemBCast(&interp_flag, 1, MPI_INT);
|
||||
}
|
||||
|
||||
MPI_Barrier(this->group_comm);
|
||||
|
||||
static uint32_t iteration = 0;
|
||||
|
||||
@ -34,105 +34,112 @@ namespace poet
|
||||
return ret_str;
|
||||
}
|
||||
|
||||
void poet::ChemistryModule::WorkerLoop()
|
||||
{
|
||||
struct worker_s timings;
|
||||
|
||||
// HACK: defining the worker iteration count here, which will increment after
|
||||
// each CHEM_ITER_END message
|
||||
uint32_t iteration = 1;
|
||||
bool loop = true;
|
||||
|
||||
while (loop)
|
||||
void poet::ChemistryModule::WorkerLoop()
|
||||
{
|
||||
int func_type;
|
||||
PropagateFunctionType(func_type);
|
||||
struct worker_s timings;
|
||||
|
||||
switch (func_type)
|
||||
// HACK: defining the worker iteration count here, which will increment after
|
||||
// each CHEM_ITER_END message
|
||||
uint32_t iteration = 1;
|
||||
bool loop = true;
|
||||
|
||||
while (loop)
|
||||
{
|
||||
case CHEM_FIELD_INIT:
|
||||
{
|
||||
ChemBCast(&this->prop_count, 1, MPI_UINT32_T);
|
||||
if (this->ai_surrogate_enabled)
|
||||
int func_type;
|
||||
PropagateFunctionType(func_type);
|
||||
|
||||
switch (func_type)
|
||||
{
|
||||
this->ai_surrogate_validity_vector.resize(
|
||||
this->n_cells); // resize statt reserve?
|
||||
}
|
||||
break;
|
||||
}
|
||||
case CHEM_AI_BCAST_VALIDITY:
|
||||
{
|
||||
// Receive the index vector of valid ai surrogate predictions
|
||||
MPI_Bcast(&this->ai_surrogate_validity_vector.front(), this->n_cells,
|
||||
MPI_INT, 0, this->group_comm);
|
||||
break;
|
||||
}
|
||||
case CHEM_WORK_LOOP:
|
||||
{
|
||||
WorkerProcessPkgs(timings, iteration);
|
||||
break;
|
||||
}
|
||||
case CHEM_PERF:
|
||||
{
|
||||
int type;
|
||||
ChemBCast(&type, 1, MPI_INT);
|
||||
if (type < WORKER_DHT_HITS)
|
||||
case CHEM_FIELD_INIT:
|
||||
{
|
||||
WorkerPerfToMaster(type, timings);
|
||||
ChemBCast(&this->prop_count, 1, MPI_UINT32_T);
|
||||
if (this->ai_surrogate_enabled)
|
||||
{
|
||||
this->ai_surrogate_validity_vector.resize(
|
||||
this->n_cells); // resize statt reserve?
|
||||
}
|
||||
break;
|
||||
}
|
||||
WorkerMetricsToMaster(type);
|
||||
break;
|
||||
}
|
||||
case CHEM_BREAK_MAIN_LOOP:
|
||||
{
|
||||
WorkerPostSim(iteration);
|
||||
loop = false;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
throw std::runtime_error("Worker received unknown tag from master.");
|
||||
}
|
||||
case CHEM_AI_BCAST_VALIDITY:
|
||||
{
|
||||
// Receive the index vector of valid ai surrogate predictions
|
||||
MPI_Bcast(&this->ai_surrogate_validity_vector.front(), this->n_cells,
|
||||
MPI_INT, 0, this->group_comm);
|
||||
break;
|
||||
}
|
||||
case CHEM_INTERP:
|
||||
{
|
||||
int interp_flag;
|
||||
ChemBCast(&interp_flag, 1, MPI_INT);
|
||||
this->interp_enabled = (interp_flag == 1);
|
||||
break;
|
||||
}
|
||||
case CHEM_WORK_LOOP:
|
||||
{
|
||||
WorkerProcessPkgs(timings, iteration);
|
||||
break;
|
||||
}
|
||||
case CHEM_PERF:
|
||||
{
|
||||
int type;
|
||||
ChemBCast(&type, 1, MPI_INT);
|
||||
if (type < WORKER_DHT_HITS)
|
||||
{
|
||||
WorkerPerfToMaster(type, timings);
|
||||
break;
|
||||
}
|
||||
WorkerMetricsToMaster(type);
|
||||
break;
|
||||
}
|
||||
case CHEM_BREAK_MAIN_LOOP:
|
||||
{
|
||||
WorkerPostSim(iteration);
|
||||
loop = false;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
throw std::runtime_error("Worker received unknown tag from master.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void poet::ChemistryModule::WorkerProcessPkgs(struct worker_s &timings,
|
||||
uint32_t &iteration)
|
||||
{
|
||||
MPI_Status probe_status;
|
||||
bool loop = true;
|
||||
|
||||
MPI_Barrier(this->group_comm);
|
||||
|
||||
while (loop)
|
||||
void poet::ChemistryModule::WorkerProcessPkgs(struct worker_s &timings,
|
||||
uint32_t &iteration)
|
||||
{
|
||||
double idle_a = MPI_Wtime();
|
||||
MPI_Probe(0, MPI_ANY_TAG, this->group_comm, &probe_status);
|
||||
double idle_b = MPI_Wtime();
|
||||
MPI_Status probe_status;
|
||||
bool loop = true;
|
||||
|
||||
switch (probe_status.MPI_TAG)
|
||||
{
|
||||
case LOOP_WORK:
|
||||
{
|
||||
timings.idle_t += idle_b - idle_a;
|
||||
int count;
|
||||
MPI_Get_count(&probe_status, MPI_DOUBLE, &count);
|
||||
MPI_Barrier(this->group_comm);
|
||||
|
||||
WorkerDoWork(probe_status, count, timings);
|
||||
break;
|
||||
}
|
||||
case LOOP_END:
|
||||
while (loop)
|
||||
{
|
||||
WorkerPostIter(probe_status, iteration);
|
||||
iteration++;
|
||||
loop = false;
|
||||
break;
|
||||
}
|
||||
double idle_a = MPI_Wtime();
|
||||
MPI_Probe(0, MPI_ANY_TAG, this->group_comm, &probe_status);
|
||||
double idle_b = MPI_Wtime();
|
||||
|
||||
switch (probe_status.MPI_TAG)
|
||||
{
|
||||
case LOOP_WORK:
|
||||
{
|
||||
timings.idle_t += idle_b - idle_a;
|
||||
int count;
|
||||
MPI_Get_count(&probe_status, MPI_DOUBLE, &count);
|
||||
|
||||
WorkerDoWork(probe_status, count, timings);
|
||||
break;
|
||||
}
|
||||
case LOOP_END:
|
||||
{
|
||||
WorkerPostIter(probe_status, iteration);
|
||||
iteration++;
|
||||
loop = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||
int double_count,
|
||||
@ -254,7 +261,7 @@ namespace poet
|
||||
|
||||
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++)
|
||||
{
|
||||
if (!s_curr_wp.mapping[wp_i] == CHEM_PQC) // only copy if surrogate was used
|
||||
if (s_curr_wp.mapping[wp_i] != CHEM_PQC) // only copy if surrogate was used
|
||||
{
|
||||
std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(),
|
||||
mpi_buffer.begin() + sur_wp_offset + this->prop_count * wp_i);
|
||||
|
||||
87
src/poet.cpp
87
src/poet.cpp
@ -270,6 +270,10 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms)
|
||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("control_iteration"));
|
||||
params.species_epsilon =
|
||||
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("species_epsilon"));
|
||||
params.penalty_iteration =
|
||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("penalty_iteration"));
|
||||
params.max_penalty_iteration =
|
||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("max_penalty_iteration"));
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
@ -302,31 +306,50 @@ void call_master_iter_end(RInside &R, const Field &trans, const Field &chem)
|
||||
|
||||
bool checkAndRollback(ChemistryModule &chem, RuntimeParameters ¶ms, uint32_t &iter)
|
||||
{
|
||||
for (uint32_t i = 0; i < chem.error_stats_history.size(); i++)
|
||||
const std::vector<double> &latest_mape = chem.error_stats_history.back().mape;
|
||||
|
||||
for (uint32_t j = 0; j < params.species_epsilon.size(); j++)
|
||||
{
|
||||
if (iter == chem.error_stats_history[i].iteration)
|
||||
if (params.species_epsilon[j] < latest_mape[j] && latest_mape[j] != 0)
|
||||
{
|
||||
for (uint32_t j = 0; j < params.species_epsilon.size(); j++)
|
||||
{
|
||||
if (params.species_epsilon[j] < chem.error_stats_history[i].mape[j] && chem.error_stats_history[i].mape[j] != 0 && chem.control_iteration_counter > 1)
|
||||
{
|
||||
uint32_t rollback_iter = iter - params.control_iteration;
|
||||
uint32_t rollback_iter = iter - (iter % params.control_iteration);
|
||||
|
||||
std::cout << chem.getField().GetProps()[j] << " with a MAPE value of " << chem.error_stats_history[i].mape[j] << " exceeds epsilon of "
|
||||
<< params.species_epsilon[j] << "! " << std::endl;
|
||||
std::cout << chem.getField().GetProps()[j] << " with a MAPE value of " << latest_mape[j] << " exceeds epsilon of "
|
||||
<< params.species_epsilon[j] << "! " << std::endl;
|
||||
|
||||
Checkpoint_s checkpoint_read{.field = chem.getField()};
|
||||
read_checkpoint("checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read);
|
||||
iter = checkpoint_read.iteration;
|
||||
Checkpoint_s checkpoint_read{.field = chem.getField()};
|
||||
read_checkpoint("checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read);
|
||||
iter = checkpoint_read.iteration;
|
||||
|
||||
chem.control_iteration_counter--;
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
MSG("All spezies are below their threshold values");
|
||||
return false;
|
||||
}
|
||||
|
||||
void updatePenaltyLogic(RuntimeParameters ¶ms, bool roolback_happend)
|
||||
{
|
||||
if (roolback_happend)
|
||||
{
|
||||
params.rollback_simulation = true;
|
||||
params.penalty_counter = params.penalty_iteration;
|
||||
std::cout << "Penalty counter reset to: " << params.penalty_counter << std::endl;
|
||||
MSG("Rollback! Penalty phase started for " + std::to_string(params.penalty_iteration) + " iterations.");
|
||||
}
|
||||
else
|
||||
{
|
||||
if (params.rollback_simulation && params.penalty_counter == 0)
|
||||
{
|
||||
params.rollback_simulation = false;
|
||||
MSG("Penalty phase ended. Interpolation re-enabled.");
|
||||
}
|
||||
else if (!params.rollback_simulation)
|
||||
{
|
||||
params.penalty_iteration = std::min(params.penalty_iteration *= 2, params.max_penalty_iteration);
|
||||
MSG("Stable surrogate phase detected. Penalty iteration doubled to " + std::to_string(params.penalty_iteration) + " iterations.");
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||
@ -344,13 +367,21 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||
}
|
||||
R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps());
|
||||
|
||||
params.next_penalty_check = params.penalty_iteration;
|
||||
|
||||
/* SIMULATION LOOP */
|
||||
|
||||
double dSimTime{0};
|
||||
for (uint32_t iter = 1; iter < maxiter + 1; iter++)
|
||||
{
|
||||
// Penalty countdown
|
||||
if (params.rollback_simulation && params.penalty_counter > 0)
|
||||
{
|
||||
params.penalty_counter--;
|
||||
std::cout << "Penalty counter: " << params.penalty_counter << std::endl;
|
||||
}
|
||||
|
||||
params.control_iteration_active = (iter % params.control_iteration == 0 && iter != 0);
|
||||
params.control_iteration_active = (iter % params.control_iteration == 0 /* && iter != 0 */);
|
||||
|
||||
double start_t = MPI_Wtime();
|
||||
|
||||
@ -459,12 +490,6 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||
// TODO: write checkpoint
|
||||
// checkpoint struct --> field and iteration
|
||||
|
||||
/*else if (iter == 2) {
|
||||
Checkpoint_s checkpoint_read{.field = chem.getField()};
|
||||
read_checkpoint("checkpoint1.hdf5", checkpoint_read);
|
||||
iter = checkpoint_read.iteration;
|
||||
}*/
|
||||
|
||||
diffusion.getField().update(chem.getField());
|
||||
|
||||
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" +
|
||||
@ -473,12 +498,18 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||
if (iter % params.control_iteration == 0)
|
||||
{
|
||||
writeStatsToCSV(chem.error_stats_history, chem.getField().GetProps(), "stats_overview");
|
||||
|
||||
write_checkpoint("checkpoint" + std::to_string(iter) + ".hdf5",
|
||||
{.field = chem.getField(), .iteration = iter});
|
||||
checkAndRollback(chem, params, iter);
|
||||
|
||||
}
|
||||
|
||||
if (iter == params.next_penalty_check)
|
||||
{
|
||||
bool roolback_happend = checkAndRollback(chem, params, iter);
|
||||
updatePenaltyLogic(params, roolback_happend);
|
||||
|
||||
params.next_penalty_check = iter + params.penalty_iteration;
|
||||
}
|
||||
|
||||
// MSG();
|
||||
} // END SIMULATION LOOP
|
||||
|
||||
|
||||
@ -52,6 +52,11 @@ struct RuntimeParameters {
|
||||
|
||||
bool print_progress = false;
|
||||
|
||||
std::uint32_t penalty_iteration = 0;
|
||||
std::uint32_t max_penalty_iteration = 0;
|
||||
std::uint32_t penalty_counter = 0;
|
||||
std::uint32_t next_penalty_check = 0;
|
||||
bool rollback_simulation = false;
|
||||
bool control_iteration_active = false;
|
||||
std::uint32_t control_iteration = 1;
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user