mirror of
https://git.gfz-potsdam.de/naaice/poet.git
synced 2025-12-15 20:38:23 +01:00
Added Control Module
This commit is contained in:
parent
f9e7793635
commit
c4d1682b3a
Binary file not shown.
Binary file not shown.
@ -282,17 +282,6 @@ inline void poet::ChemistryModule::MasterSendPkgs(
|
||||
// LOOP_WORK, this->group_comm);
|
||||
MPI_Send(send_buffer.data(), send_buffer.size(), MPI_DOUBLE, p + 1,
|
||||
LOOP_WORK, this->group_comm);
|
||||
|
||||
/* ---- DEBUG LOG (Sender side) ---- */
|
||||
std::cout << "[DEBUG][rank=" << p+1
|
||||
<< "] sending WP " << (count_pkgs - 1)
|
||||
<< " to worker rank " << (p + 1)
|
||||
<< " | len=" << send_buffer.size()
|
||||
<< " | start index=" << wp_start_index
|
||||
<< " | second element=" << send_buffer[1]
|
||||
<< " | pkg size=" << local_work_package_size
|
||||
<< std::endl;
|
||||
/* -------------------------------- */
|
||||
|
||||
/* Mark that worker has work to do */
|
||||
w_list[p].has_work = 1;
|
||||
@ -466,7 +455,8 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
|
||||
shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count,
|
||||
wp_sizes_vector.size());
|
||||
|
||||
std::vector<double> mpi_surr_buffer{mpi_buffer};
|
||||
std::vector<double> mpi_surr_buffer;
|
||||
mpi_surr_buffer.resize(mpi_buffer.size());
|
||||
|
||||
/* setup local variables */
|
||||
pkg_to_send = wp_sizes_vector.size();
|
||||
|
||||
@ -39,6 +39,7 @@ void poet::ChemistryModule::WorkerLoop() {
|
||||
// each CHEM_ITER_END message
|
||||
uint32_t iteration = 1;
|
||||
bool loop = true;
|
||||
|
||||
while (loop) {
|
||||
int func_type;
|
||||
PropagateFunctionType(func_type);
|
||||
@ -59,7 +60,7 @@ void poet::ChemistryModule::WorkerLoop() {
|
||||
break;
|
||||
}
|
||||
case CHEM_CTRL: {
|
||||
int control_flag = 0;
|
||||
int control_flag ;
|
||||
ChemBCast(&control_flag, 1, MPI_INT);
|
||||
this->control_enabled = (control_flag == 1);
|
||||
break;
|
||||
@ -143,7 +144,6 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||
|
||||
/* decrement count of work_package by BUFFER_OFFSET */
|
||||
count -= BUFFER_OFFSET;
|
||||
|
||||
/* check for changes on all additional variables given by the 'header' of
|
||||
* mpi_buffer */
|
||||
|
||||
@ -162,12 +162,6 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||
// current work package start location in field
|
||||
wp_start_index = mpi_buffer[count + 4];
|
||||
|
||||
std::cout << "[DEBUG][rank=" << this->comm_rank << "] WP " << counter
|
||||
<< " len=" << count << " | second element: " << mpi_buffer[1]
|
||||
<< " | iteration=" << iteration << " | dt=" << dt
|
||||
<< " | simtime=" << current_sim_time
|
||||
<< " | start_index=" << wp_start_index << std::endl;
|
||||
|
||||
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
|
||||
s_curr_wp.input[wp_i] =
|
||||
std::vector<double>(mpi_buffer.begin() + this->prop_count * wp_i,
|
||||
@ -207,7 +201,8 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||
|
||||
if (control_enabled) {
|
||||
for (std::size_t wp_i = 0; wp_i < s_curr_wp_control.size; wp_i++) {
|
||||
s_curr_wp_control.output[wp_i] = std::vector<double>(prop_count, 0.0);
|
||||
s_curr_wp_control.output[wp_i] =
|
||||
std::vector<double>(this->prop_count, 0.0);
|
||||
s_curr_wp_control.mapping[wp_i] = 0;
|
||||
}
|
||||
}
|
||||
@ -219,8 +214,43 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||
|
||||
phreeqc_time_end = MPI_Wtime();
|
||||
|
||||
count =
|
||||
packResultsIntoBuffer(mpi_buffer, count, s_curr_wp, s_curr_wp_control);
|
||||
if (control_enabled) {
|
||||
|
||||
|
||||
std::size_t sur_wp_offset = s_curr_wp.size * this->prop_count;
|
||||
|
||||
mpi_buffer.resize(count + sur_wp_offset);
|
||||
|
||||
for (std::size_t wp_i = 0; wp_i < s_curr_wp_control.size; wp_i++) {
|
||||
std::copy(s_curr_wp_control.output[wp_i].begin(),
|
||||
s_curr_wp_control.output[wp_i].end(),
|
||||
mpi_buffer.begin() + this->prop_count * wp_i);
|
||||
}
|
||||
|
||||
// s_curr_wp only contains the interpolated data
|
||||
// copy surrogate output after the pqc output, mpi_buffer[pqc][interp]
|
||||
|
||||
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
|
||||
if (s_curr_wp.mapping[wp_i] !=
|
||||
CHEM_PQC) // only copy if surrogate was used
|
||||
{
|
||||
std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(),
|
||||
mpi_buffer.begin() + sur_wp_offset + this->prop_count * wp_i);
|
||||
} else {
|
||||
// if pqc was used, copy pqc results again
|
||||
std::copy(s_curr_wp_control.output[wp_i].begin(),
|
||||
s_curr_wp_control.output[wp_i].end(),
|
||||
mpi_buffer.begin() + sur_wp_offset + this->prop_count * wp_i);
|
||||
}
|
||||
}
|
||||
count += sur_wp_offset;
|
||||
|
||||
} else {
|
||||
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
|
||||
std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(),
|
||||
mpi_buffer.begin() + this->prop_count * wp_i);
|
||||
}
|
||||
}
|
||||
|
||||
/* send results to master */
|
||||
MPI_Request send_req;
|
||||
@ -245,40 +275,6 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
|
||||
MPI_Wait(&send_req, MPI_STATUS_IGNORE);
|
||||
}
|
||||
|
||||
int poet::ChemistryModule::packResultsIntoBuffer(
|
||||
std::vector<double> &mpi_buffer, int base_count, const WorkPackage &wp,
|
||||
const WorkPackage &wp_control) {
|
||||
if (control_enabled) {
|
||||
std::size_t wp_offset = wp_control.size * prop_count;
|
||||
mpi_buffer.resize(base_count + wp_offset);
|
||||
|
||||
/* copy pqc outputs first */
|
||||
for (std::size_t wp_i = 0; wp_i < wp_control.size; wp_i++) {
|
||||
std::copy(wp_control.output[wp_i].begin(), wp_control.output[wp_i].end(),
|
||||
mpi_buffer.begin() + prop_count * wp_i);
|
||||
}
|
||||
|
||||
/* copy surrogate output, only if it contains interpolated data, after the
|
||||
* the pqc output, layout = mpi_buffer[pqc][interp] */
|
||||
for (std::size_t wp_i = 0; wp_i < wp.size; wp_i++) {
|
||||
const auto &wp_copy = wp.mapping[wp_i] != CHEM_PQC
|
||||
? wp.output[wp_i]
|
||||
: wp_control.output[wp_i];
|
||||
|
||||
std::copy(wp_copy.begin(), wp_copy.end(),
|
||||
mpi_buffer.begin() + wp_offset + prop_count * wp_i);
|
||||
}
|
||||
return base_count + static_cast<int>(wp_offset);
|
||||
|
||||
} else {
|
||||
for (std::size_t wp_i = 0; wp_i < wp.size; wp_i++) {
|
||||
std::copy(wp.output[wp_i].begin(), wp.output[wp_i].end(),
|
||||
mpi_buffer.begin() + prop_count + wp_i);
|
||||
}
|
||||
return base_count;
|
||||
}
|
||||
}
|
||||
|
||||
void poet::ChemistryModule::WorkerPostIter(MPI_Status &prope_status,
|
||||
uint32_t iteration) {
|
||||
MPI_Recv(NULL, 0, MPI_DOUBLE, 0, LOOP_END, this->group_comm,
|
||||
|
||||
@ -42,22 +42,26 @@ void poet::ControlModule::endIteration(const uint32_t iter) {
|
||||
/* Control Logic*/
|
||||
if (control_interval_enabled &&
|
||||
checkpoint_interval > 0 /*&& !rollback_enabled*/) {
|
||||
MSG("Writing checkpoint of iteration " + std::to_string(iter));
|
||||
write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5",
|
||||
{.field = chem->getField(), .iteration = iter});
|
||||
writeStatsToCSV(error_history, species_names, out_dir, "stats_overview");
|
||||
if (!chem) {
|
||||
MSG("chem pointer is null — skipping checkpoint/stats write");
|
||||
} else {
|
||||
MSG("Writing checkpoint of iteration " + std::to_string(iter));
|
||||
write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5",
|
||||
{.field = chem->getField(), .iteration = iter});
|
||||
writeStatsToCSV(error_history, species_names, out_dir, "stats_overview");
|
||||
|
||||
/*
|
||||
/*
|
||||
|
||||
if (triggerRollbackIfExceeded(*chem, *params, iter)) {
|
||||
rollback_enabled = true;
|
||||
rollback_counter++;
|
||||
sur_disabled_counter = control_interval;
|
||||
MSG("Interpolation disabled for the next " +
|
||||
std::to_string(control_interval) + ".");
|
||||
if (triggerRollbackIfExceeded(*chem, *params, iter)) {
|
||||
rollback_enabled = true;
|
||||
rollback_counter++;
|
||||
sur_disabled_counter = control_interval;
|
||||
MSG("Interpolation disabled for the next " +
|
||||
std::to_string(control_interval) + ".");
|
||||
}
|
||||
|
||||
*/
|
||||
}
|
||||
|
||||
*/
|
||||
}
|
||||
}
|
||||
|
||||
@ -97,8 +101,8 @@ bool poet::ControlModule::triggerRollbackIfExceeded(ChemistryModule &chem,
|
||||
|
||||
MSG("[THRESHOLD EXCEEDED] " + props[i] +
|
||||
" has MAPE = " + std::to_string(mape[i]) +
|
||||
" exceeding threshold = " + std::to_string(params.mape_threshold[i]) +
|
||||
" → rolling back to iteration " + std::to_string(rollback_iter));
|
||||
" exceeding threshold = " + std::to_string(params.mape_threshold[i])
|
||||
+ " → rolling back to iteration " + std::to_string(rollback_iter));
|
||||
|
||||
Checkpoint_s checkpoint_read{.field = chem.getField()};
|
||||
read_checkpoint(params.out_dir,
|
||||
@ -122,15 +126,54 @@ void poet::ControlModule::computeSpeciesErrors(
|
||||
global_iteration,
|
||||
/*rollback_counter*/ 0);
|
||||
|
||||
if (reference_values.size() != surrogate_values.size()) {
|
||||
MSG(" Reference and surrogate vectors differ in size: " +
|
||||
std::to_string(reference_values.size()) + " vs " +
|
||||
std::to_string(surrogate_values.size()));
|
||||
return;
|
||||
}
|
||||
|
||||
const std::size_t expected =
|
||||
static_cast<std::size_t>(this->species_names.size()) * size_per_prop;
|
||||
if (reference_values.size() < expected) {
|
||||
std::cerr << "[CTRL ERROR] input vectors too small: expected >= "
|
||||
<< expected << " entries, got " << reference_values.size()
|
||||
<< "\n";
|
||||
return;
|
||||
}
|
||||
|
||||
int idxBa = -1, idxCl = -1;
|
||||
for (size_t k = 0; k < this->species_names.size(); ++k) {
|
||||
if (this->species_names[k] == "Ba")
|
||||
idxBa = (int)k;
|
||||
if (this->species_names[k] == "Cl")
|
||||
idxCl = (int)k;
|
||||
}
|
||||
if (idxBa < 0 || idxCl < 0) {
|
||||
std::cerr << "[CTRL DIAG] Ba/Cl indices not found: Ba=" << idxBa
|
||||
<< " Cl=" << idxCl << "\n";
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < this->species_names.size(); ++i) {
|
||||
double err_sum = 0.0;
|
||||
double sqr_err_sum = 0.0;
|
||||
uint32_t base_idx = i * size_per_prop;
|
||||
uint32_t nan_count = 0;
|
||||
uint32_t valid_count = 0;
|
||||
double ref_sum = 0.0, sur_sum = 0.0;
|
||||
|
||||
for (uint32_t j = 0; j < size_per_prop; ++j) {
|
||||
const double ref_value = reference_values[base_idx + j];
|
||||
const double sur_value = surrogate_values[base_idx + j];
|
||||
|
||||
if (std::isnan(ref_value) || std::isnan(sur_value)) {
|
||||
nan_count++;
|
||||
continue;
|
||||
}
|
||||
valid_count++;
|
||||
ref_sum += ref_value;
|
||||
sur_sum += sur_value;
|
||||
|
||||
if (ref_value == 0.0) {
|
||||
if (sur_value != 0.0) {
|
||||
err_sum += 1.0;
|
||||
@ -143,10 +186,45 @@ void poet::ControlModule::computeSpeciesErrors(
|
||||
sqr_err_sum += alpha * alpha;
|
||||
}
|
||||
}
|
||||
// sample printing (keeps previous behavior: species 5 and 6)
|
||||
if (i == 5 || i == 6) {
|
||||
std::cerr << "[CTRL SAMPLE] species_index=" << i
|
||||
<< " name=" << this->species_names[i]
|
||||
<< " base_idx=" << base_idx << " nan_count=" << nan_count
|
||||
<< " valid_count=" << valid_count << std::endl;
|
||||
uint32_t N = std::min<uint32_t>(size_per_prop, 20u);
|
||||
std::cerr << "[CTRL SAMPLE] reference: ";
|
||||
for (uint32_t j = 0; j < N; ++j)
|
||||
std::cerr << reference_values[base_idx + j]
|
||||
<< (j + 1 == N ? "\n" : " ");
|
||||
std::cerr << "[CTRL SAMPLE] surrogate: ";
|
||||
for (uint32_t j = 0; j < N; ++j)
|
||||
std::cerr << surrogate_values[base_idx + j]
|
||||
<< (j + 1 == N ? "\n" : " ");
|
||||
}
|
||||
|
||||
species_error_stats.mape[i] = 100.0 * (err_sum / size_per_prop);
|
||||
species_error_stats.rrmse[i] =
|
||||
(size_per_prop > 0) ? std::sqrt(sqr_err_sum / size_per_prop) : 0.0;
|
||||
if (valid_count > 0) {
|
||||
species_error_stats.mape[i] = 100.0 * (err_sum / valid_count);
|
||||
species_error_stats.rrmse[i] = std::sqrt(sqr_err_sum / valid_count);
|
||||
} else {
|
||||
species_error_stats.mape[i] = 0.0;
|
||||
species_error_stats.rrmse[i] = 0.0;
|
||||
std::cerr << "[CTRL WARN] no valid samples for species " << i << " ("
|
||||
<< this->species_names[i] << "), setting errors to 0\n";
|
||||
}
|
||||
|
||||
// DEBUG: detailed diagnostics for Ba/Cl (or whichever indices)
|
||||
if (this->species_names[i] == "Ba" || this->species_names[i] == "Cl") {
|
||||
double mean_ref = (valid_count > 0) ? (ref_sum / valid_count) : 0.0;
|
||||
double mean_sur = (valid_count > 0) ? (sur_sum / valid_count) : 0.0;
|
||||
std::cerr << "[CTRL DIAG] species=" << this->species_names[i]
|
||||
<< " idx=" << i << " base_idx=" << base_idx
|
||||
<< " valid_count=" << valid_count << " nan_count=" << nan_count
|
||||
<< " err_sum=" << err_sum << " sqr_err_sum=" << sqr_err_sum
|
||||
<< " mean_ref=" << mean_ref << " mean_sur=" << mean_sur
|
||||
<< " computed_MAPE=" << species_error_stats.mape[i]
|
||||
<< " computed_RRMSE=" << species_error_stats.rrmse[i] << "\n";
|
||||
}
|
||||
}
|
||||
error_history.push_back(species_error_stats);
|
||||
}
|
||||
|
||||
@ -30,6 +30,8 @@ public:
|
||||
|
||||
void endIteration(const uint32_t iter);
|
||||
|
||||
// Stores the given ChemistryModule pointer in the `chem` member.
void setChemistryModule(poet::ChemistryModule *c) { chem = c; }
|
||||
|
||||
// void BCastControlFlags();
|
||||
|
||||
//bool triggerRollbackIfExceeded(ChemistryModule &chem,
|
||||
|
||||
@ -650,8 +650,8 @@ int main(int argc, char *argv[]) {
|
||||
init_list.getChemistryInit(), MPI_COMM_WORLD);
|
||||
|
||||
ControlModule control;
|
||||
|
||||
chemistry.setControlModule(&control);
|
||||
control.setChemistryModule(&chemistry);
|
||||
|
||||
const ChemistryModule::SurrogateSetup surr_setup = {
|
||||
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user