mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-16 12:54:50 +01:00
Added Control Module
This commit is contained in:
parent
f9e7793635
commit
c4d1682b3a
Binary file not shown.
Binary file not shown.
@ -283,17 +283,6 @@ inline void poet::ChemistryModule::MasterSendPkgs(
|
|||||||
MPI_Send(send_buffer.data(), send_buffer.size(), MPI_DOUBLE, p + 1,
|
MPI_Send(send_buffer.data(), send_buffer.size(), MPI_DOUBLE, p + 1,
|
||||||
LOOP_WORK, this->group_comm);
|
LOOP_WORK, this->group_comm);
|
||||||
|
|
||||||
/* ---- DEBUG LOG (Sender side) ---- */
|
|
||||||
std::cout << "[DEBUG][rank=" << p+1
|
|
||||||
<< "] sending WP " << (count_pkgs - 1)
|
|
||||||
<< " to worker rank " << (p + 1)
|
|
||||||
<< " | len=" << send_buffer.size()
|
|
||||||
<< " | start index=" << wp_start_index
|
|
||||||
<< " | second element=" << send_buffer[1]
|
|
||||||
<< " | pkg size=" << local_work_package_size
|
|
||||||
<< std::endl;
|
|
||||||
/* -------------------------------- */
|
|
||||||
|
|
||||||
/* Mark that worker has work to do */
|
/* Mark that worker has work to do */
|
||||||
w_list[p].has_work = 1;
|
w_list[p].has_work = 1;
|
||||||
free_workers--;
|
free_workers--;
|
||||||
@ -466,7 +455,8 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
|||||||
shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count,
|
shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count,
|
||||||
wp_sizes_vector.size());
|
wp_sizes_vector.size());
|
||||||
|
|
||||||
std::vector<double> mpi_surr_buffer{mpi_buffer};
|
std::vector<double> mpi_surr_buffer;
|
||||||
|
mpi_surr_buffer.resize(mpi_buffer.size());
|
||||||
|
|
||||||
/* setup local variables */
|
/* setup local variables */
|
||||||
pkg_to_send = wp_sizes_vector.size();
|
pkg_to_send = wp_sizes_vector.size();
|
||||||
|
|||||||
@ -39,6 +39,7 @@ void poet::ChemistryModule::WorkerLoop() {
|
|||||||
// each CHEM_ITER_END message
|
// each CHEM_ITER_END message
|
||||||
uint32_t iteration = 1;
|
uint32_t iteration = 1;
|
||||||
bool loop = true;
|
bool loop = true;
|
||||||
|
|
||||||
while (loop) {
|
while (loop) {
|
||||||
int func_type;
|
int func_type;
|
||||||
PropagateFunctionType(func_type);
|
PropagateFunctionType(func_type);
|
||||||
@ -59,7 +60,7 @@ void poet::ChemistryModule::WorkerLoop() {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case CHEM_CTRL: {
|
case CHEM_CTRL: {
|
||||||
int control_flag = 0;
|
int control_flag ;
|
||||||
ChemBCast(&control_flag, 1, MPI_INT);
|
ChemBCast(&control_flag, 1, MPI_INT);
|
||||||
this->control_enabled = (control_flag == 1);
|
this->control_enabled = (control_flag == 1);
|
||||||
break;
|
break;
|
||||||
@ -143,7 +144,6 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
|||||||
|
|
||||||
/* decrement count of work_package by BUFFER_OFFSET */
|
/* decrement count of work_package by BUFFER_OFFSET */
|
||||||
count -= BUFFER_OFFSET;
|
count -= BUFFER_OFFSET;
|
||||||
|
|
||||||
/* check for changes on all additional variables given by the 'header' of
|
/* check for changes on all additional variables given by the 'header' of
|
||||||
* mpi_buffer */
|
* mpi_buffer */
|
||||||
|
|
||||||
@ -162,12 +162,6 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
|||||||
// current work package start location in field
|
// current work package start location in field
|
||||||
wp_start_index = mpi_buffer[count + 4];
|
wp_start_index = mpi_buffer[count + 4];
|
||||||
|
|
||||||
std::cout << "[DEBUG][rank=" << this->comm_rank << "] WP " << counter
|
|
||||||
<< " len=" << count << " | second element: " << mpi_buffer[1]
|
|
||||||
<< " | iteration=" << iteration << " | dt=" << dt
|
|
||||||
<< " | simtime=" << current_sim_time
|
|
||||||
<< " | start_index=" << wp_start_index << std::endl;
|
|
||||||
|
|
||||||
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
|
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
|
||||||
s_curr_wp.input[wp_i] =
|
s_curr_wp.input[wp_i] =
|
||||||
std::vector<double>(mpi_buffer.begin() + this->prop_count * wp_i,
|
std::vector<double>(mpi_buffer.begin() + this->prop_count * wp_i,
|
||||||
@ -207,7 +201,8 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
|||||||
|
|
||||||
if (control_enabled) {
|
if (control_enabled) {
|
||||||
for (std::size_t wp_i = 0; wp_i < s_curr_wp_control.size; wp_i++) {
|
for (std::size_t wp_i = 0; wp_i < s_curr_wp_control.size; wp_i++) {
|
||||||
s_curr_wp_control.output[wp_i] = std::vector<double>(prop_count, 0.0);
|
s_curr_wp_control.output[wp_i] =
|
||||||
|
std::vector<double>(this->prop_count, 0.0);
|
||||||
s_curr_wp_control.mapping[wp_i] = 0;
|
s_curr_wp_control.mapping[wp_i] = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -219,8 +214,43 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
|||||||
|
|
||||||
phreeqc_time_end = MPI_Wtime();
|
phreeqc_time_end = MPI_Wtime();
|
||||||
|
|
||||||
count =
|
if (control_enabled) {
|
||||||
packResultsIntoBuffer(mpi_buffer, count, s_curr_wp, s_curr_wp_control);
|
|
||||||
|
|
||||||
|
std::size_t sur_wp_offset = s_curr_wp.size * this->prop_count;
|
||||||
|
|
||||||
|
mpi_buffer.resize(count + sur_wp_offset);
|
||||||
|
|
||||||
|
for (std::size_t wp_i = 0; wp_i < s_curr_wp_control.size; wp_i++) {
|
||||||
|
std::copy(s_curr_wp_control.output[wp_i].begin(),
|
||||||
|
s_curr_wp_control.output[wp_i].end(),
|
||||||
|
mpi_buffer.begin() + this->prop_count * wp_i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// s_curr_wp only contains the interpolated data
|
||||||
|
// copy surrogate output after the pqc output, mpi_buffer[pqc][interp]
|
||||||
|
|
||||||
|
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
|
||||||
|
if (s_curr_wp.mapping[wp_i] !=
|
||||||
|
CHEM_PQC) // only copy if surrogate was used
|
||||||
|
{
|
||||||
|
std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(),
|
||||||
|
mpi_buffer.begin() + sur_wp_offset + this->prop_count * wp_i);
|
||||||
|
} else {
|
||||||
|
// if pqc was used, copy pqc results again
|
||||||
|
std::copy(s_curr_wp_control.output[wp_i].begin(),
|
||||||
|
s_curr_wp_control.output[wp_i].end(),
|
||||||
|
mpi_buffer.begin() + sur_wp_offset + this->prop_count * wp_i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
count += sur_wp_offset;
|
||||||
|
|
||||||
|
} else {
|
||||||
|
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
|
||||||
|
std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(),
|
||||||
|
mpi_buffer.begin() + this->prop_count * wp_i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/* send results to master */
|
/* send results to master */
|
||||||
MPI_Request send_req;
|
MPI_Request send_req;
|
||||||
@ -245,40 +275,6 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
|||||||
MPI_Wait(&send_req, MPI_STATUS_IGNORE);
|
MPI_Wait(&send_req, MPI_STATUS_IGNORE);
|
||||||
}
|
}
|
||||||
|
|
||||||
int poet::ChemistryModule::packResultsIntoBuffer(
|
|
||||||
std::vector<double> &mpi_buffer, int base_count, const WorkPackage &wp,
|
|
||||||
const WorkPackage &wp_control) {
|
|
||||||
if (control_enabled) {
|
|
||||||
std::size_t wp_offset = wp_control.size * prop_count;
|
|
||||||
mpi_buffer.resize(base_count + wp_offset);
|
|
||||||
|
|
||||||
/* copy pqc outputs first */
|
|
||||||
for (std::size_t wp_i = 0; wp_i < wp_control.size; wp_i++) {
|
|
||||||
std::copy(wp_control.output[wp_i].begin(), wp_control.output[wp_i].end(),
|
|
||||||
mpi_buffer.begin() + prop_count * wp_i);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* copy surrogate output, only if it contains interpolated data, after the
|
|
||||||
* the pqc output, layout = mpi_buffer[pqc][interp] */
|
|
||||||
for (std::size_t wp_i = 0; wp_i < wp.size; wp_i++) {
|
|
||||||
const auto &wp_copy = wp.mapping[wp_i] != CHEM_PQC
|
|
||||||
? wp.output[wp_i]
|
|
||||||
: wp_control.output[wp_i];
|
|
||||||
|
|
||||||
std::copy(wp_copy.begin(), wp_copy.end(),
|
|
||||||
mpi_buffer.begin() + wp_offset + prop_count * wp_i);
|
|
||||||
}
|
|
||||||
return base_count + static_cast<int>(wp_offset);
|
|
||||||
|
|
||||||
} else {
|
|
||||||
for (std::size_t wp_i = 0; wp_i < wp.size; wp_i++) {
|
|
||||||
std::copy(wp.output[wp_i].begin(), wp.output[wp_i].end(),
|
|
||||||
mpi_buffer.begin() + prop_count + wp_i);
|
|
||||||
}
|
|
||||||
return base_count;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void poet::ChemistryModule::WorkerPostIter(MPI_Status &prope_status,
|
void poet::ChemistryModule::WorkerPostIter(MPI_Status &prope_status,
|
||||||
uint32_t iteration) {
|
uint32_t iteration) {
|
||||||
MPI_Recv(NULL, 0, MPI_DOUBLE, 0, LOOP_END, this->group_comm,
|
MPI_Recv(NULL, 0, MPI_DOUBLE, 0, LOOP_END, this->group_comm,
|
||||||
|
|||||||
@ -42,22 +42,26 @@ void poet::ControlModule::endIteration(const uint32_t iter) {
|
|||||||
/* Control Logic*/
|
/* Control Logic*/
|
||||||
if (control_interval_enabled &&
|
if (control_interval_enabled &&
|
||||||
checkpoint_interval > 0 /*&& !rollback_enabled*/) {
|
checkpoint_interval > 0 /*&& !rollback_enabled*/) {
|
||||||
MSG("Writing checkpoint of iteration " + std::to_string(iter));
|
if (!chem) {
|
||||||
write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5",
|
MSG("chem pointer is null — skipping checkpoint/stats write");
|
||||||
{.field = chem->getField(), .iteration = iter});
|
} else {
|
||||||
writeStatsToCSV(error_history, species_names, out_dir, "stats_overview");
|
MSG("Writing checkpoint of iteration " + std::to_string(iter));
|
||||||
|
write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5",
|
||||||
|
{.field = chem->getField(), .iteration = iter});
|
||||||
|
writeStatsToCSV(error_history, species_names, out_dir, "stats_overview");
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|
||||||
if (triggerRollbackIfExceeded(*chem, *params, iter)) {
|
if (triggerRollbackIfExceeded(*chem, *params, iter)) {
|
||||||
rollback_enabled = true;
|
rollback_enabled = true;
|
||||||
rollback_counter++;
|
rollback_counter++;
|
||||||
sur_disabled_counter = control_interval;
|
sur_disabled_counter = control_interval;
|
||||||
MSG("Interpolation disabled for the next " +
|
MSG("Interpolation disabled for the next " +
|
||||||
std::to_string(control_interval) + ".");
|
std::to_string(control_interval) + ".");
|
||||||
|
}
|
||||||
|
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
*/
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -97,8 +101,8 @@ bool poet::ControlModule::triggerRollbackIfExceeded(ChemistryModule &chem,
|
|||||||
|
|
||||||
MSG("[THRESHOLD EXCEEDED] " + props[i] +
|
MSG("[THRESHOLD EXCEEDED] " + props[i] +
|
||||||
" has MAPE = " + std::to_string(mape[i]) +
|
" has MAPE = " + std::to_string(mape[i]) +
|
||||||
" exceeding threshold = " + std::to_string(params.mape_threshold[i]) +
|
" exceeding threshold = " + std::to_string(params.mape_threshold[i])
|
||||||
" → rolling back to iteration " + std::to_string(rollback_iter));
|
+ " → rolling back to iteration " + std::to_string(rollback_iter));
|
||||||
|
|
||||||
Checkpoint_s checkpoint_read{.field = chem.getField()};
|
Checkpoint_s checkpoint_read{.field = chem.getField()};
|
||||||
read_checkpoint(params.out_dir,
|
read_checkpoint(params.out_dir,
|
||||||
@ -122,15 +126,54 @@ void poet::ControlModule::computeSpeciesErrors(
|
|||||||
global_iteration,
|
global_iteration,
|
||||||
/*rollback_counter*/ 0);
|
/*rollback_counter*/ 0);
|
||||||
|
|
||||||
|
if (reference_values.size() != surrogate_values.size()) {
|
||||||
|
MSG(" Reference and surrogate vectors differ in size: " +
|
||||||
|
std::to_string(reference_values.size()) + " vs " +
|
||||||
|
std::to_string(surrogate_values.size()));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::size_t expected =
|
||||||
|
static_cast<std::size_t>(this->species_names.size()) * size_per_prop;
|
||||||
|
if (reference_values.size() < expected) {
|
||||||
|
std::cerr << "[CTRL ERROR] input vectors too small: expected >= "
|
||||||
|
<< expected << " entries, got " << reference_values.size()
|
||||||
|
<< "\n";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int idxBa = -1, idxCl = -1;
|
||||||
|
for (size_t k = 0; k < this->species_names.size(); ++k) {
|
||||||
|
if (this->species_names[k] == "Ba")
|
||||||
|
idxBa = (int)k;
|
||||||
|
if (this->species_names[k] == "Cl")
|
||||||
|
idxCl = (int)k;
|
||||||
|
}
|
||||||
|
if (idxBa < 0 || idxCl < 0) {
|
||||||
|
std::cerr << "[CTRL DIAG] Ba/Cl indices not found: Ba=" << idxBa
|
||||||
|
<< " Cl=" << idxCl << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
for (uint32_t i = 0; i < this->species_names.size(); ++i) {
|
for (uint32_t i = 0; i < this->species_names.size(); ++i) {
|
||||||
double err_sum = 0.0;
|
double err_sum = 0.0;
|
||||||
double sqr_err_sum = 0.0;
|
double sqr_err_sum = 0.0;
|
||||||
uint32_t base_idx = i * size_per_prop;
|
uint32_t base_idx = i * size_per_prop;
|
||||||
|
uint32_t nan_count = 0;
|
||||||
|
uint32_t valid_count = 0;
|
||||||
|
double ref_sum = 0.0, sur_sum = 0.0;
|
||||||
|
|
||||||
for (uint32_t j = 0; j < size_per_prop; ++j) {
|
for (uint32_t j = 0; j < size_per_prop; ++j) {
|
||||||
const double ref_value = reference_values[base_idx + j];
|
const double ref_value = reference_values[base_idx + j];
|
||||||
const double sur_value = surrogate_values[base_idx + j];
|
const double sur_value = surrogate_values[base_idx + j];
|
||||||
|
|
||||||
|
if (std::isnan(ref_value) || std::isnan(sur_value)) {
|
||||||
|
nan_count++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
valid_count++;
|
||||||
|
ref_sum += ref_value;
|
||||||
|
sur_sum += sur_value;
|
||||||
|
|
||||||
if (ref_value == 0.0) {
|
if (ref_value == 0.0) {
|
||||||
if (sur_value != 0.0) {
|
if (sur_value != 0.0) {
|
||||||
err_sum += 1.0;
|
err_sum += 1.0;
|
||||||
@ -143,10 +186,45 @@ void poet::ControlModule::computeSpeciesErrors(
|
|||||||
sqr_err_sum += alpha * alpha;
|
sqr_err_sum += alpha * alpha;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// sample printing (keeps previous behavior: species 5 and 6)
|
||||||
|
if (i == 5 || i == 6) {
|
||||||
|
std::cerr << "[CTRL SAMPLE] species_index=" << i
|
||||||
|
<< " name=" << this->species_names[i]
|
||||||
|
<< " base_idx=" << base_idx << " nan_count=" << nan_count
|
||||||
|
<< " valid_count=" << valid_count << std::endl;
|
||||||
|
uint32_t N = std::min<uint32_t>(size_per_prop, 20u);
|
||||||
|
std::cerr << "[CTRL SAMPLE] reference: ";
|
||||||
|
for (uint32_t j = 0; j < N; ++j)
|
||||||
|
std::cerr << reference_values[base_idx + j]
|
||||||
|
<< (j + 1 == N ? "\n" : " ");
|
||||||
|
std::cerr << "[CTRL SAMPLE] surrogate: ";
|
||||||
|
for (uint32_t j = 0; j < N; ++j)
|
||||||
|
std::cerr << surrogate_values[base_idx + j]
|
||||||
|
<< (j + 1 == N ? "\n" : " ");
|
||||||
|
}
|
||||||
|
|
||||||
species_error_stats.mape[i] = 100.0 * (err_sum / size_per_prop);
|
if (valid_count > 0) {
|
||||||
species_error_stats.rrmse[i] =
|
species_error_stats.mape[i] = 100.0 * (err_sum / valid_count);
|
||||||
(size_per_prop > 0) ? std::sqrt(sqr_err_sum / size_per_prop) : 0.0;
|
species_error_stats.rrmse[i] = std::sqrt(sqr_err_sum / valid_count);
|
||||||
|
} else {
|
||||||
|
species_error_stats.mape[i] = 0.0;
|
||||||
|
species_error_stats.rrmse[i] = 0.0;
|
||||||
|
std::cerr << "[CTRL WARN] no valid samples for species " << i << " ("
|
||||||
|
<< this->species_names[i] << "), setting errors to 0\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
// DEBUG: detailed diagnostics for Ba/Cl (or whichever indices)
|
||||||
|
if (this->species_names[i] == "Ba" || this->species_names[i] == "Cl") {
|
||||||
|
double mean_ref = (valid_count > 0) ? (ref_sum / valid_count) : 0.0;
|
||||||
|
double mean_sur = (valid_count > 0) ? (sur_sum / valid_count) : 0.0;
|
||||||
|
std::cerr << "[CTRL DIAG] species=" << this->species_names[i]
|
||||||
|
<< " idx=" << i << " base_idx=" << base_idx
|
||||||
|
<< " valid_count=" << valid_count << " nan_count=" << nan_count
|
||||||
|
<< " err_sum=" << err_sum << " sqr_err_sum=" << sqr_err_sum
|
||||||
|
<< " mean_ref=" << mean_ref << " mean_sur=" << mean_sur
|
||||||
|
<< " computed_MAPE=" << species_error_stats.mape[i]
|
||||||
|
<< " computed_RRMSE=" << species_error_stats.rrmse[i] << "\n";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
error_history.push_back(species_error_stats);
|
error_history.push_back(species_error_stats);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -30,6 +30,8 @@ public:
|
|||||||
|
|
||||||
void endIteration(const uint32_t iter);
|
void endIteration(const uint32_t iter);
|
||||||
|
|
||||||
|
// Stores the passed ChemistryModule pointer as-is in the `chem` member.
void setChemistryModule(poet::ChemistryModule *c) { this->chem = c; }
|
||||||
|
|
||||||
// void BCastControlFlags();
|
// void BCastControlFlags();
|
||||||
|
|
||||||
//bool triggerRollbackIfExceeded(ChemistryModule &chem,
|
//bool triggerRollbackIfExceeded(ChemistryModule &chem,
|
||||||
|
|||||||
@ -650,8 +650,8 @@ int main(int argc, char *argv[]) {
|
|||||||
init_list.getChemistryInit(), MPI_COMM_WORLD);
|
init_list.getChemistryInit(), MPI_COMM_WORLD);
|
||||||
|
|
||||||
ControlModule control;
|
ControlModule control;
|
||||||
|
|
||||||
chemistry.setControlModule(&control);
|
chemistry.setControlModule(&control);
|
||||||
|
control.setChemistryModule(&chemistry);
|
||||||
|
|
||||||
const ChemistryModule::SurrogateSetup surr_setup = {
|
const ChemistryModule::SurrogateSetup surr_setup = {
|
||||||
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user