mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-15 20:38:23 +01:00
feat(control): dynamic prototype, penalty_iteration, error while disabling surrogate fixed
This commit is contained in:
parent
41d1a9895c
commit
15e397ecf2
@ -28,12 +28,12 @@ if (POET_PREPROCESS_BENCHS)
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
# as tug will also pull in doctest as a dependency
|
# as tug will also pull in doctest as a dependency
|
||||||
set(TUG_ENABLE_TESTING ON CACHE BOOL "" FORCE)
|
set(TUG_ENABLE_TESTING OFF CACHE BOOL "" FORCE)
|
||||||
|
|
||||||
add_subdirectory(ext/tug EXCLUDE_FROM_ALL)
|
add_subdirectory(ext/tug EXCLUDE_FROM_ALL)
|
||||||
add_subdirectory(ext/iphreeqc EXCLUDE_FROM_ALL)
|
add_subdirectory(ext/iphreeqc EXCLUDE_FROM_ALL)
|
||||||
|
|
||||||
option(POET_ENABLE_TESTING "Build test suite for POET" ON)
|
option(POET_ENABLE_TESTING "Build test suite for POET" OFF)
|
||||||
|
|
||||||
if (POET_ENABLE_TESTING)
|
if (POET_ENABLE_TESTING)
|
||||||
add_subdirectory(test)
|
add_subdirectory(test)
|
||||||
|
|||||||
@ -299,6 +299,7 @@ namespace poet
|
|||||||
CHEM_DHT_SIGNIF_VEC,
|
CHEM_DHT_SIGNIF_VEC,
|
||||||
CHEM_DHT_SNAPS,
|
CHEM_DHT_SNAPS,
|
||||||
CHEM_DHT_READ_FILE,
|
CHEM_DHT_READ_FILE,
|
||||||
|
CHEM_INTERP,
|
||||||
CHEM_IP_ENABLE,
|
CHEM_IP_ENABLE,
|
||||||
CHEM_IP_MIN_ENTRIES,
|
CHEM_IP_MIN_ENTRIES,
|
||||||
CHEM_IP_SIGNIF_VEC,
|
CHEM_IP_SIGNIF_VEC,
|
||||||
|
|||||||
@ -448,6 +448,19 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
|||||||
ftype = CHEM_WORK_LOOP;
|
ftype = CHEM_WORK_LOOP;
|
||||||
PropagateFunctionType(ftype);
|
PropagateFunctionType(ftype);
|
||||||
|
|
||||||
|
ftype = CHEM_INTERP;
|
||||||
|
PropagateFunctionType(ftype);
|
||||||
|
|
||||||
|
if(this->runtime_params->rollback_simulation){
|
||||||
|
this->interp_enabled = false;
|
||||||
|
int interp_flag = 0;
|
||||||
|
ChemBCast(&interp_flag, 1, MPI_INT);
|
||||||
|
} else {
|
||||||
|
this->interp_enabled = true;
|
||||||
|
int interp_flag = 1;
|
||||||
|
ChemBCast(&interp_flag, 1, MPI_INT);
|
||||||
|
}
|
||||||
|
|
||||||
MPI_Barrier(this->group_comm);
|
MPI_Barrier(this->group_comm);
|
||||||
|
|
||||||
static uint32_t iteration = 0;
|
static uint32_t iteration = 0;
|
||||||
|
|||||||
@ -34,105 +34,112 @@ namespace poet
|
|||||||
return ret_str;
|
return ret_str;
|
||||||
}
|
}
|
||||||
|
|
||||||
void poet::ChemistryModule::WorkerLoop()
|
void poet::ChemistryModule::WorkerLoop()
|
||||||
{
|
|
||||||
struct worker_s timings;
|
|
||||||
|
|
||||||
// HACK: defining the worker iteration count here, which will increment after
|
|
||||||
// each CHEM_ITER_END message
|
|
||||||
uint32_t iteration = 1;
|
|
||||||
bool loop = true;
|
|
||||||
|
|
||||||
while (loop)
|
|
||||||
{
|
{
|
||||||
int func_type;
|
struct worker_s timings;
|
||||||
PropagateFunctionType(func_type);
|
|
||||||
|
|
||||||
switch (func_type)
|
// HACK: defining the worker iteration count here, which will increment after
|
||||||
|
// each CHEM_ITER_END message
|
||||||
|
uint32_t iteration = 1;
|
||||||
|
bool loop = true;
|
||||||
|
|
||||||
|
while (loop)
|
||||||
{
|
{
|
||||||
case CHEM_FIELD_INIT:
|
int func_type;
|
||||||
{
|
PropagateFunctionType(func_type);
|
||||||
ChemBCast(&this->prop_count, 1, MPI_UINT32_T);
|
|
||||||
if (this->ai_surrogate_enabled)
|
switch (func_type)
|
||||||
{
|
{
|
||||||
this->ai_surrogate_validity_vector.resize(
|
case CHEM_FIELD_INIT:
|
||||||
this->n_cells); // resize statt reserve?
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case CHEM_AI_BCAST_VALIDITY:
|
|
||||||
{
|
|
||||||
// Receive the index vector of valid ai surrogate predictions
|
|
||||||
MPI_Bcast(&this->ai_surrogate_validity_vector.front(), this->n_cells,
|
|
||||||
MPI_INT, 0, this->group_comm);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case CHEM_WORK_LOOP:
|
|
||||||
{
|
|
||||||
WorkerProcessPkgs(timings, iteration);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case CHEM_PERF:
|
|
||||||
{
|
|
||||||
int type;
|
|
||||||
ChemBCast(&type, 1, MPI_INT);
|
|
||||||
if (type < WORKER_DHT_HITS)
|
|
||||||
{
|
{
|
||||||
WorkerPerfToMaster(type, timings);
|
ChemBCast(&this->prop_count, 1, MPI_UINT32_T);
|
||||||
|
if (this->ai_surrogate_enabled)
|
||||||
|
{
|
||||||
|
this->ai_surrogate_validity_vector.resize(
|
||||||
|
this->n_cells); // resize statt reserve?
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
WorkerMetricsToMaster(type);
|
case CHEM_AI_BCAST_VALIDITY:
|
||||||
break;
|
{
|
||||||
}
|
// Receive the index vector of valid ai surrogate predictions
|
||||||
case CHEM_BREAK_MAIN_LOOP:
|
MPI_Bcast(&this->ai_surrogate_validity_vector.front(), this->n_cells,
|
||||||
{
|
MPI_INT, 0, this->group_comm);
|
||||||
WorkerPostSim(iteration);
|
break;
|
||||||
loop = false;
|
}
|
||||||
break;
|
case CHEM_INTERP:
|
||||||
}
|
{
|
||||||
default:
|
int interp_flag;
|
||||||
{
|
ChemBCast(&interp_flag, 1, MPI_INT);
|
||||||
throw std::runtime_error("Worker received unknown tag from master.");
|
this->interp_enabled = (interp_flag == 1);
|
||||||
}
|
break;
|
||||||
|
}
|
||||||
|
case CHEM_WORK_LOOP:
|
||||||
|
{
|
||||||
|
WorkerProcessPkgs(timings, iteration);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case CHEM_PERF:
|
||||||
|
{
|
||||||
|
int type;
|
||||||
|
ChemBCast(&type, 1, MPI_INT);
|
||||||
|
if (type < WORKER_DHT_HITS)
|
||||||
|
{
|
||||||
|
WorkerPerfToMaster(type, timings);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
WorkerMetricsToMaster(type);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case CHEM_BREAK_MAIN_LOOP:
|
||||||
|
{
|
||||||
|
WorkerPostSim(iteration);
|
||||||
|
loop = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
throw std::runtime_error("Worker received unknown tag from master.");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
void poet::ChemistryModule::WorkerProcessPkgs(struct worker_s &timings,
|
void poet::ChemistryModule::WorkerProcessPkgs(struct worker_s &timings,
|
||||||
uint32_t &iteration)
|
uint32_t &iteration)
|
||||||
{
|
|
||||||
MPI_Status probe_status;
|
|
||||||
bool loop = true;
|
|
||||||
|
|
||||||
MPI_Barrier(this->group_comm);
|
|
||||||
|
|
||||||
while (loop)
|
|
||||||
{
|
{
|
||||||
double idle_a = MPI_Wtime();
|
MPI_Status probe_status;
|
||||||
MPI_Probe(0, MPI_ANY_TAG, this->group_comm, &probe_status);
|
bool loop = true;
|
||||||
double idle_b = MPI_Wtime();
|
|
||||||
|
|
||||||
switch (probe_status.MPI_TAG)
|
MPI_Barrier(this->group_comm);
|
||||||
{
|
|
||||||
case LOOP_WORK:
|
|
||||||
{
|
|
||||||
timings.idle_t += idle_b - idle_a;
|
|
||||||
int count;
|
|
||||||
MPI_Get_count(&probe_status, MPI_DOUBLE, &count);
|
|
||||||
|
|
||||||
WorkerDoWork(probe_status, count, timings);
|
while (loop)
|
||||||
break;
|
|
||||||
}
|
|
||||||
case LOOP_END:
|
|
||||||
{
|
{
|
||||||
WorkerPostIter(probe_status, iteration);
|
double idle_a = MPI_Wtime();
|
||||||
iteration++;
|
MPI_Probe(0, MPI_ANY_TAG, this->group_comm, &probe_status);
|
||||||
loop = false;
|
double idle_b = MPI_Wtime();
|
||||||
break;
|
|
||||||
}
|
switch (probe_status.MPI_TAG)
|
||||||
|
{
|
||||||
|
case LOOP_WORK:
|
||||||
|
{
|
||||||
|
timings.idle_t += idle_b - idle_a;
|
||||||
|
int count;
|
||||||
|
MPI_Get_count(&probe_status, MPI_DOUBLE, &count);
|
||||||
|
|
||||||
|
WorkerDoWork(probe_status, count, timings);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case LOOP_END:
|
||||||
|
{
|
||||||
|
WorkerPostIter(probe_status, iteration);
|
||||||
|
iteration++;
|
||||||
|
loop = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||||
int double_count,
|
int double_count,
|
||||||
@ -254,7 +261,7 @@ namespace poet
|
|||||||
|
|
||||||
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++)
|
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++)
|
||||||
{
|
{
|
||||||
if (!s_curr_wp.mapping[wp_i] == CHEM_PQC) // only copy if surrogate was used
|
if (s_curr_wp.mapping[wp_i] != CHEM_PQC) // only copy if surrogate was used
|
||||||
{
|
{
|
||||||
std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(),
|
std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(),
|
||||||
mpi_buffer.begin() + sur_wp_offset + this->prop_count * wp_i);
|
mpi_buffer.begin() + sur_wp_offset + this->prop_count * wp_i);
|
||||||
|
|||||||
87
src/poet.cpp
87
src/poet.cpp
@ -270,6 +270,10 @@ int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms)
|
|||||||
Rcpp::as<uint32_t>(global_rt_setup->operator[]("control_iteration"));
|
Rcpp::as<uint32_t>(global_rt_setup->operator[]("control_iteration"));
|
||||||
params.species_epsilon =
|
params.species_epsilon =
|
||||||
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("species_epsilon"));
|
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("species_epsilon"));
|
||||||
|
params.penalty_iteration =
|
||||||
|
Rcpp::as<uint32_t>(global_rt_setup->operator[]("penalty_iteration"));
|
||||||
|
params.max_penalty_iteration =
|
||||||
|
Rcpp::as<uint32_t>(global_rt_setup->operator[]("max_penalty_iteration"));
|
||||||
}
|
}
|
||||||
catch (const std::exception &e)
|
catch (const std::exception &e)
|
||||||
{
|
{
|
||||||
@ -302,31 +306,50 @@ void call_master_iter_end(RInside &R, const Field &trans, const Field &chem)
|
|||||||
|
|
||||||
bool checkAndRollback(ChemistryModule &chem, RuntimeParameters ¶ms, uint32_t &iter)
|
bool checkAndRollback(ChemistryModule &chem, RuntimeParameters ¶ms, uint32_t &iter)
|
||||||
{
|
{
|
||||||
for (uint32_t i = 0; i < chem.error_stats_history.size(); i++)
|
const std::vector<double> &latest_mape = chem.error_stats_history.back().mape;
|
||||||
|
|
||||||
|
for (uint32_t j = 0; j < params.species_epsilon.size(); j++)
|
||||||
{
|
{
|
||||||
if (iter == chem.error_stats_history[i].iteration)
|
if (params.species_epsilon[j] < latest_mape[j] && latest_mape[j] != 0)
|
||||||
{
|
{
|
||||||
for (uint32_t j = 0; j < params.species_epsilon.size(); j++)
|
uint32_t rollback_iter = iter - (iter % params.control_iteration);
|
||||||
{
|
|
||||||
if (params.species_epsilon[j] < chem.error_stats_history[i].mape[j] && chem.error_stats_history[i].mape[j] != 0 && chem.control_iteration_counter > 1)
|
|
||||||
{
|
|
||||||
uint32_t rollback_iter = iter - params.control_iteration;
|
|
||||||
|
|
||||||
std::cout << chem.getField().GetProps()[j] << " with a MAPE value of " << chem.error_stats_history[i].mape[j] << " exceeds epsilon of "
|
std::cout << chem.getField().GetProps()[j] << " with a MAPE value of " << latest_mape[j] << " exceeds epsilon of "
|
||||||
<< params.species_epsilon[j] << "! " << std::endl;
|
<< params.species_epsilon[j] << "! " << std::endl;
|
||||||
|
|
||||||
Checkpoint_s checkpoint_read{.field = chem.getField()};
|
Checkpoint_s checkpoint_read{.field = chem.getField()};
|
||||||
read_checkpoint("checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read);
|
read_checkpoint("checkpoint" + std::to_string(rollback_iter) + ".hdf5", checkpoint_read);
|
||||||
iter = checkpoint_read.iteration;
|
iter = checkpoint_read.iteration;
|
||||||
|
|
||||||
chem.control_iteration_counter--;
|
return true;
|
||||||
|
}
|
||||||
return true;
|
}
|
||||||
}
|
MSG("All spezies are below their threshold values");
|
||||||
}
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void updatePenaltyLogic(RuntimeParameters ¶ms, bool roolback_happend)
|
||||||
|
{
|
||||||
|
if (roolback_happend)
|
||||||
|
{
|
||||||
|
params.rollback_simulation = true;
|
||||||
|
params.penalty_counter = params.penalty_iteration;
|
||||||
|
std::cout << "Penalty counter reset to: " << params.penalty_counter << std::endl;
|
||||||
|
MSG("Rollback! Penalty phase started for " + std::to_string(params.penalty_iteration) + " iterations.");
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
if (params.rollback_simulation && params.penalty_counter == 0)
|
||||||
|
{
|
||||||
|
params.rollback_simulation = false;
|
||||||
|
MSG("Penalty phase ended. Interpolation re-enabled.");
|
||||||
|
}
|
||||||
|
else if (!params.rollback_simulation)
|
||||||
|
{
|
||||||
|
params.penalty_iteration = std::min(params.penalty_iteration *= 2, params.max_penalty_iteration);
|
||||||
|
MSG("Stable surrogate phase detected. Penalty iteration doubled to " + std::to_string(params.penalty_iteration) + " iterations.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
||||||
@ -344,13 +367,21 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
|||||||
}
|
}
|
||||||
R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps());
|
R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps());
|
||||||
|
|
||||||
|
params.next_penalty_check = params.penalty_iteration;
|
||||||
|
|
||||||
/* SIMULATION LOOP */
|
/* SIMULATION LOOP */
|
||||||
|
|
||||||
double dSimTime{0};
|
double dSimTime{0};
|
||||||
for (uint32_t iter = 1; iter < maxiter + 1; iter++)
|
for (uint32_t iter = 1; iter < maxiter + 1; iter++)
|
||||||
{
|
{
|
||||||
|
// Penalty countdown
|
||||||
|
if (params.rollback_simulation && params.penalty_counter > 0)
|
||||||
|
{
|
||||||
|
params.penalty_counter--;
|
||||||
|
std::cout << "Penalty counter: " << params.penalty_counter << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
params.control_iteration_active = (iter % params.control_iteration == 0 && iter != 0);
|
params.control_iteration_active = (iter % params.control_iteration == 0 /* && iter != 0 */);
|
||||||
|
|
||||||
double start_t = MPI_Wtime();
|
double start_t = MPI_Wtime();
|
||||||
|
|
||||||
@ -459,12 +490,6 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
|||||||
// TODO: write checkpoint
|
// TODO: write checkpoint
|
||||||
// checkpoint struct --> field and iteration
|
// checkpoint struct --> field and iteration
|
||||||
|
|
||||||
/*else if (iter == 2) {
|
|
||||||
Checkpoint_s checkpoint_read{.field = chem.getField()};
|
|
||||||
read_checkpoint("checkpoint1.hdf5", checkpoint_read);
|
|
||||||
iter = checkpoint_read.iteration;
|
|
||||||
}*/
|
|
||||||
|
|
||||||
diffusion.getField().update(chem.getField());
|
diffusion.getField().update(chem.getField());
|
||||||
|
|
||||||
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" +
|
MSG("End of *coupling* iteration " + std::to_string(iter) + "/" +
|
||||||
@ -473,12 +498,18 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, RuntimeParameters ¶ms,
|
|||||||
if (iter % params.control_iteration == 0)
|
if (iter % params.control_iteration == 0)
|
||||||
{
|
{
|
||||||
writeStatsToCSV(chem.error_stats_history, chem.getField().GetProps(), "stats_overview");
|
writeStatsToCSV(chem.error_stats_history, chem.getField().GetProps(), "stats_overview");
|
||||||
|
|
||||||
write_checkpoint("checkpoint" + std::to_string(iter) + ".hdf5",
|
write_checkpoint("checkpoint" + std::to_string(iter) + ".hdf5",
|
||||||
{.field = chem.getField(), .iteration = iter});
|
{.field = chem.getField(), .iteration = iter});
|
||||||
checkAndRollback(chem, params, iter);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (iter == params.next_penalty_check)
|
||||||
|
{
|
||||||
|
bool roolback_happend = checkAndRollback(chem, params, iter);
|
||||||
|
updatePenaltyLogic(params, roolback_happend);
|
||||||
|
|
||||||
|
params.next_penalty_check = iter + params.penalty_iteration;
|
||||||
|
}
|
||||||
|
|
||||||
// MSG();
|
// MSG();
|
||||||
} // END SIMULATION LOOP
|
} // END SIMULATION LOOP
|
||||||
|
|
||||||
|
|||||||
@ -52,6 +52,11 @@ struct RuntimeParameters {
|
|||||||
|
|
||||||
bool print_progress = false;
|
bool print_progress = false;
|
||||||
|
|
||||||
|
std::uint32_t penalty_iteration = 0;
|
||||||
|
std::uint32_t max_penalty_iteration = 0;
|
||||||
|
std::uint32_t penalty_counter = 0;
|
||||||
|
std::uint32_t next_penalty_check = 0;
|
||||||
|
bool rollback_simulation = false;
|
||||||
bool control_iteration_active = false;
|
bool control_iteration_active = false;
|
||||||
std::uint32_t control_iteration = 1;
|
std::uint32_t control_iteration = 1;
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user