mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-15 20:38:23 +01:00
feat(control): dynamic prototype, penalty_iteration, error while disabling surrogate fixed
This commit is contained in:
parent
41d1a9895c
commit
15e397ecf2
@ -28,12 +28,12 @@ if (POET_PREPROCESS_BENCHS)
|
||||
endif()
|
||||
|
||||
# as tug will also pull in doctest as a dependency
|
||||
set(TUG_ENABLE_TESTING ON CACHE BOOL "" FORCE)
|
||||
set(TUG_ENABLE_TESTING OFF CACHE BOOL "" FORCE)
|
||||
|
||||
add_subdirectory(ext/tug EXCLUDE_FROM_ALL)
|
||||
add_subdirectory(ext/iphreeqc EXCLUDE_FROM_ALL)
|
||||
|
||||
option(POET_ENABLE_TESTING "Build test suite for POET" ON)
|
||||
option(POET_ENABLE_TESTING "Build test suite for POET" OFF)
|
||||
|
||||
if (POET_ENABLE_TESTING)
|
||||
add_subdirectory(test)
|
||||
|
||||
@ -299,6 +299,7 @@ namespace poet
|
||||
CHEM_DHT_SIGNIF_VEC,
|
||||
CHEM_DHT_SNAPS,
|
||||
CHEM_DHT_READ_FILE,
|
||||
CHEM_INTERP,
|
||||
CHEM_IP_ENABLE,
|
||||
CHEM_IP_MIN_ENTRIES,
|
||||
CHEM_IP_SIGNIF_VEC,
|
||||
|
||||
@ -448,6 +448,19 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
||||
ftype = CHEM_WORK_LOOP;
|
||||
PropagateFunctionType(ftype);
|
||||
|
||||
ftype = CHEM_INTERP;
|
||||
PropagateFunctionType(ftype);
|
||||
|
||||
if(this->runtime_params->rollback_simulation){
|
||||
this->interp_enabled = false;
|
||||
int interp_flag = 0;
|
||||
ChemBCast(&interp_flag, 1, MPI_INT);
|
||||
} else {
|
||||
this->interp_enabled = true;
|
||||
int interp_flag = 1;
|
||||
ChemBCast(&interp_flag, 1, MPI_INT);
|
||||
}
|
||||
|
||||
MPI_Barrier(this->group_comm);
|
||||
|
||||
static uint32_t iteration = 0;
|
||||
|
||||
@ -67,6 +67,13 @@ namespace poet
|
||||
MPI_INT, 0, this->group_comm);
|
||||
break;
|
||||
}
|
||||
case CHEM_INTERP:
|
||||
{
|
||||
int interp_flag;
|
||||
ChemBCast(&interp_flag, 1, MPI_INT);
|
||||
this->interp_enabled = (interp_flag == 1);
|
||||
break;
|
||||
}
|
||||
case CHEM_WORK_LOOP:
|
||||
{
|
||||
WorkerProcessPkgs(timings, iteration);
|
||||
@ -254,7 +261,7 @@ namespace poet
|
||||
|
||||
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++)
|
||||
{
|
||||
if (!s_curr_wp.mapping[wp_i] == CHEM_PQC) // only copy if surrogate was used
|
||||
if (s_curr_wp.mapping[wp_i] != CHEM_PQC) // only copy if surrogate was used
|
||||
{
|
||||
std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(),
|
||||
mpi_buffer.begin() + sur_wp_offset + this->prop_count * wp_i);
|
||||
|
||||
73
src/poet.cpp
73
src/poet.cpp
@ -270,6 +270,10 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms)
|
||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("control_iteration"));
|
||||
params.species_epsilon =
|
||||
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("species_epsilon"));
|
||||
params.penalty_iteration =
|
||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("penalty_iteration"));
|
||||
params.max_penalty_iteration =
|
||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("max_penalty_iteration"));
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
@ -302,33 +306,52 @@ void call_master_iter_end(RInside &R, const Field &trans, const Field &chem)
|
||||
|
||||
bool checkAndRollback(ChemistryModule &chem, RuntimeParameters ¶ms, uint32_t &iter)
|
||||
{
|
||||
for (uint32_t i = 0; i < chem.error_stats_history.size(); i++)
|
||||
{
|
||||
if (iter == chem.error_stats_history[i].iteration)
|
||||
{
|
||||
const std::vector<double> &latest_mape = chem.error_stats_history.back().mape;
|
||||
|
||||
for (uint32_t j = 0; j < params.species_epsilon.size(); j++)
|
||||
{
|
||||
if (params.species_epsilon[j] < chem.error_stats_history[i].mape[j] && chem.error_stats_history[i].mape[j] != 0 && chem.control_iteration_counter > 1)
|
||||
if (params.species_epsilon[j] < latest_mape[j] && latest_mape[j] != 0)
|
||||
{
|
||||
uint32_t rollback_iter = iter - params.control_iteration;
|
||||
uint32_t rollback_iter = iter - (iter % params.control_iteration);
|
||||
|
||||
std::cout << chem.getField().GetProps()[j] << " with a MAPE value of " << chem.error_stats_history[i].mape[j] << " exceeds epsilon of "
|
||||
std::cout << chem.getField().GetProps()[j] << " with a MAPE value of " << latest_mape[j] << " exceeds epsilon of "
|
||||
<< params.species_epsilon[j] << "! " << std::endl;
|
||||
|
||||
Checkpoint_s checkpoint_read{.field = chem.getField()};
|
||||
read_checkpoint("checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read);
|
||||
iter = checkpoint_read.iteration;
|
||||
|
||||
chem.control_iteration_counter--;
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
MSG("All spezies are below their threshold values");
|
||||
return false;
|
||||
}
|
||||
|
||||
void updatePenaltyLogic(RuntimeParameters ¶ms, bool roolback_happend)
|
||||
{
|
||||
if (roolback_happend)
|
||||
{
|
||||
params.rollback_simulation = true;
|
||||
params.penalty_counter = params.penalty_iteration;
|
||||
std::cout << "Penalty counter reset to: " << params.penalty_counter << std::endl;
|
||||
MSG("Rollback! Penalty phase started for " + std::to_string(params.penalty_iteration) + " iterations.");
|
||||
}
|
||||
else
|
||||
{
|
||||
if (params.rollback_simulation && params.penalty_counter == 0)
|
||||
{
|
||||
params.rollback_simulation = false;
|
||||
MSG("Penalty phase ended. Interpolation re-enabled.");
|
||||
}
|
||||
else if (!params.rollback_simulation)
|
||||
{
|
||||
params.penalty_iteration = std::min(params.penalty_iteration *= 2, params.max_penalty_iteration);
|
||||
MSG("Stable surrogate phase detected. Penalty iteration doubled to " + std::to_string(params.penalty_iteration) + " iterations.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||
DiffusionModule &diffusion,
|
||||
ChemistryModule &chem)
|
||||
@ -344,13 +367,21 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||
}
|
||||
R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps());
|
||||
|
||||
params.next_penalty_check = params.penalty_iteration;
|
||||
|
||||
/* SIMULATION LOOP */
|
||||
|
||||
double dSimTime{0};
|
||||
for (uint32_t iter = 1; iter < maxiter + 1; iter++)
|
||||
{
|
||||
// Penalty countdown
|
||||
if (params.rollback_simulation && params.penalty_counter > 0)
|
||||
{
|
||||
params.penalty_counter--;
|
||||
std::cout << "Penalty counter: " << params.penalty_counter << std::endl;
|
||||
}
|
||||
|
||||
params.control_iteration_active = (iter % params.control_iteration == 0 && iter != 0);
|
||||
params.control_iteration_active = (iter % params.control_iteration == 0 /* && iter != 0 */);
|
||||
|
||||
double start_t = MPI_Wtime();
|
||||
|
||||
@ -459,12 +490,6 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||
// TODO: write checkpoint
|
||||
// checkpoint struct --> field and iteration
|
||||
|
||||
/*else if (iter == 2) {
|
||||
Checkpoint_s checkpoint_read{.field = chem.getField()};
|
||||
read_checkpoint("checkpoint1.hdf5", checkpoint_read);
|
||||
iter = checkpoint_read.iteration;
|
||||
}*/
|
||||
|
||||
diffusion.getField().update(chem.getField());
|
||||
|
||||
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" +
|
||||
@ -473,12 +498,18 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||
if (iter % params.control_iteration == 0)
|
||||
{
|
||||
writeStatsToCSV(chem.error_stats_history, chem.getField().GetProps(), "stats_overview");
|
||||
|
||||
write_checkpoint("checkpoint" + std::to_string(iter) + ".hdf5",
|
||||
{.field = chem.getField(), .iteration = iter});
|
||||
checkAndRollback(chem, params, iter);
|
||||
|
||||
}
|
||||
|
||||
if (iter == params.next_penalty_check)
|
||||
{
|
||||
bool roolback_happend = checkAndRollback(chem, params, iter);
|
||||
updatePenaltyLogic(params, roolback_happend);
|
||||
|
||||
params.next_penalty_check = iter + params.penalty_iteration;
|
||||
}
|
||||
|
||||
// MSG();
|
||||
} // END SIMULATION LOOP
|
||||
|
||||
|
||||
@ -52,6 +52,11 @@ struct RuntimeParameters {
|
||||
|
||||
bool print_progress = false;
|
||||
|
||||
std::uint32_t penalty_iteration = 0;
|
||||
std::uint32_t max_penalty_iteration = 0;
|
||||
std::uint32_t penalty_counter = 0;
|
||||
std::uint32_t next_penalty_check = 0;
|
||||
bool rollback_simulation = false;
|
||||
bool control_iteration_active = false;
|
||||
std::uint32_t control_iteration = 1;
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user