Added Control Module

This commit is contained in:
rastogi 2025-10-24 13:01:41 +02:00
parent f9e7793635
commit c4d1682b3a
7 changed files with 142 additions and 76 deletions

Binary file not shown.

View File

@ -283,17 +283,6 @@ inline void poet::ChemistryModule::MasterSendPkgs(
MPI_Send(send_buffer.data(), send_buffer.size(), MPI_DOUBLE, p + 1,
LOOP_WORK, this->group_comm);
/* ---- DEBUG LOG (Sender side) ---- */
std::cout << "[DEBUG][rank=" << p+1
<< "] sending WP " << (count_pkgs - 1)
<< " to worker rank " << (p + 1)
<< " | len=" << send_buffer.size()
<< " | start index=" << wp_start_index
<< " | second element=" << send_buffer[1]
<< " | pkg size=" << local_work_package_size
<< std::endl;
/* -------------------------------- */
/* Mark that worker has work to do */
w_list[p].has_work = 1;
free_workers--;
@ -466,7 +455,8 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
shuffleField(chem_field.AsVector(), this->n_cells, this->prop_count,
wp_sizes_vector.size());
std::vector<double> mpi_surr_buffer{mpi_buffer};
std::vector<double> mpi_surr_buffer;
mpi_surr_buffer.resize(mpi_buffer.size());
/* setup local variables */
pkg_to_send = wp_sizes_vector.size();

View File

@ -39,6 +39,7 @@ void poet::ChemistryModule::WorkerLoop() {
// each CHEM_ITER_END message
uint32_t iteration = 1;
bool loop = true;
while (loop) {
int func_type;
PropagateFunctionType(func_type);
@ -59,7 +60,7 @@ void poet::ChemistryModule::WorkerLoop() {
break;
}
case CHEM_CTRL: {
int control_flag = 0;
int control_flag ;
ChemBCast(&control_flag, 1, MPI_INT);
this->control_enabled = (control_flag == 1);
break;
@ -143,7 +144,6 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
/* decrement count of work_package by BUFFER_OFFSET */
count -= BUFFER_OFFSET;
/* check for changes on all additional variables given by the 'header' of
* mpi_buffer */
@ -162,12 +162,6 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
// current work package start location in field
wp_start_index = mpi_buffer[count + 4];
std::cout << "[DEBUG][rank=" << this->comm_rank << "] WP " << counter
<< " len=" << count << " | second element: " << mpi_buffer[1]
<< " | iteration=" << iteration << " | dt=" << dt
<< " | simtime=" << current_sim_time
<< " | start_index=" << wp_start_index << std::endl;
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
s_curr_wp.input[wp_i] =
std::vector<double>(mpi_buffer.begin() + this->prop_count * wp_i,
@ -207,7 +201,8 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
if (control_enabled) {
for (std::size_t wp_i = 0; wp_i < s_curr_wp_control.size; wp_i++) {
s_curr_wp_control.output[wp_i] = std::vector<double>(prop_count, 0.0);
s_curr_wp_control.output[wp_i] =
std::vector<double>(this->prop_count, 0.0);
s_curr_wp_control.mapping[wp_i] = 0;
}
}
@ -219,8 +214,43 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
phreeqc_time_end = MPI_Wtime();
count =
packResultsIntoBuffer(mpi_buffer, count, s_curr_wp, s_curr_wp_control);
if (control_enabled) {
std::size_t sur_wp_offset = s_curr_wp.size * this->prop_count;
mpi_buffer.resize(count + sur_wp_offset);
for (std::size_t wp_i = 0; wp_i < s_curr_wp_control.size; wp_i++) {
std::copy(s_curr_wp_control.output[wp_i].begin(),
s_curr_wp_control.output[wp_i].end(),
mpi_buffer.begin() + this->prop_count * wp_i);
}
// s_curr_wp only contains the interpolated data
// copy surrogate output after the pqc output, mpi_buffer[pqc][interp]
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
if (s_curr_wp.mapping[wp_i] !=
CHEM_PQC) // only copy if surrogate was used
{
std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(),
mpi_buffer.begin() + sur_wp_offset + this->prop_count * wp_i);
} else {
// if pqc was used, copy pqc results again
std::copy(s_curr_wp_control.output[wp_i].begin(),
s_curr_wp_control.output[wp_i].end(),
mpi_buffer.begin() + sur_wp_offset + this->prop_count * wp_i);
}
}
count += sur_wp_offset;
} else {
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
std::copy(s_curr_wp.output[wp_i].begin(), s_curr_wp.output[wp_i].end(),
mpi_buffer.begin() + this->prop_count * wp_i);
}
}
/* send results to master */
MPI_Request send_req;
@ -245,40 +275,6 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
MPI_Wait(&send_req, MPI_STATUS_IGNORE);
}
/**
 * Copy the outputs of a processed work package into the MPI buffer.
 *
 * Control enabled: the buffer is resized to
 * base_count + wp_control.size * prop_count and filled with the layout
 * mpi_buffer[pqc][interp]. The first segment receives every output row of
 * wp_control. The second segment starts at wp_control.size * prop_count and
 * receives one row per cell of wp: the row of wp when its mapping is not
 * CHEM_PQC, otherwise the matching row of wp_control again.
 *
 * Control disabled: the output rows of wp are written from the start of the
 * buffer, row wp_i beginning at prop_count * wp_i. The buffer is not resized.
 *
 * NOTE(review): the offsets assume each output row holds prop_count values and
 * that wp and wp_control have the same number of cells -- confirm with caller.
 *
 * @param mpi_buffer Buffer the results are written into.
 * @param base_count Element count of the buffer before packing.
 * @param wp         Work package whose outputs and mapping are packed.
 * @param wp_control Work package with the pqc outputs; only read when control
 *                   is enabled.
 * @return base_count plus the control segment size when control is enabled,
 *         base_count otherwise.
 */
int poet::ChemistryModule::packResultsIntoBuffer(
    std::vector<double> &mpi_buffer, int base_count, const WorkPackage &wp,
    const WorkPackage &wp_control) {
  if (!control_enabled) {
    for (std::size_t wp_i = 0; wp_i < wp.size; wp_i++) {
      /* Row wp_i must start at prop_count * wp_i. Using `+ wp_i` here would
       * shift every row by a single element, so rows overwrite each other. */
      std::copy(wp.output[wp_i].begin(), wp.output[wp_i].end(),
                mpi_buffer.begin() + prop_count * wp_i);
    }
    return base_count;
  }

  const std::size_t wp_offset = wp_control.size * prop_count;
  mpi_buffer.resize(base_count + wp_offset);

  /* copy pqc outputs first */
  for (std::size_t wp_i = 0; wp_i < wp_control.size; wp_i++) {
    std::copy(wp_control.output[wp_i].begin(), wp_control.output[wp_i].end(),
              mpi_buffer.begin() + prop_count * wp_i);
  }

  /* copy the surrogate output after the pqc output, falling back to the pqc
   * row for cells that were not interpolated;
   * layout = mpi_buffer[pqc][interp] */
  for (std::size_t wp_i = 0; wp_i < wp.size; wp_i++) {
    const auto &wp_copy = wp.mapping[wp_i] != CHEM_PQC
                              ? wp.output[wp_i]
                              : wp_control.output[wp_i];
    std::copy(wp_copy.begin(), wp_copy.end(),
              mpi_buffer.begin() + wp_offset + prop_count * wp_i);
  }

  return base_count + static_cast<int>(wp_offset);
}
void poet::ChemistryModule::WorkerPostIter(MPI_Status &prope_status,
uint32_t iteration) {
MPI_Recv(NULL, 0, MPI_DOUBLE, 0, LOOP_END, this->group_comm,

View File

@ -42,6 +42,9 @@ void poet::ControlModule::endIteration(const uint32_t iter) {
/* Control Logic*/
if (control_interval_enabled &&
checkpoint_interval > 0 /*&& !rollback_enabled*/) {
if (!chem) {
MSG("chem pointer is null — skipping checkpoint/stats write");
} else {
MSG("Writing checkpoint of iteration " + std::to_string(iter));
write_checkpoint(out_dir, "checkpoint" + std::to_string(iter) + ".hdf5",
{.field = chem->getField(), .iteration = iter});
@ -60,6 +63,7 @@ void poet::ControlModule::endIteration(const uint32_t iter) {
*/
}
}
}
/*
void poet::ControlModule::BCastControlFlags() {
@ -97,8 +101,8 @@ bool poet::ControlModule::triggerRollbackIfExceeded(ChemistryModule &chem,
MSG("[THRESHOLD EXCEEDED] " + props[i] +
" has MAPE = " + std::to_string(mape[i]) +
" exceeding threshold = " + std::to_string(params.mape_threshold[i]) +
" → rolling back to iteration " + std::to_string(rollback_iter));
" exceeding threshold = " + std::to_string(params.mape_threshold[i])
+ " → rolling back to iteration " + std::to_string(rollback_iter));
Checkpoint_s checkpoint_read{.field = chem.getField()};
read_checkpoint(params.out_dir,
@ -122,15 +126,54 @@ void poet::ControlModule::computeSpeciesErrors(
global_iteration,
/*rollback_counter*/ 0);
if (reference_values.size() != surrogate_values.size()) {
MSG(" Reference and surrogate vectors differ in size: " +
std::to_string(reference_values.size()) + " vs " +
std::to_string(surrogate_values.size()));
return;
}
const std::size_t expected =
static_cast<std::size_t>(this->species_names.size()) * size_per_prop;
if (reference_values.size() < expected) {
std::cerr << "[CTRL ERROR] input vectors too small: expected >= "
<< expected << " entries, got " << reference_values.size()
<< "\n";
return;
}
int idxBa = -1, idxCl = -1;
for (size_t k = 0; k < this->species_names.size(); ++k) {
if (this->species_names[k] == "Ba")
idxBa = (int)k;
if (this->species_names[k] == "Cl")
idxCl = (int)k;
}
if (idxBa < 0 || idxCl < 0) {
std::cerr << "[CTRL DIAG] Ba/Cl indices not found: Ba=" << idxBa
<< " Cl=" << idxCl << "\n";
}
for (uint32_t i = 0; i < this->species_names.size(); ++i) {
double err_sum = 0.0;
double sqr_err_sum = 0.0;
uint32_t base_idx = i * size_per_prop;
uint32_t nan_count = 0;
uint32_t valid_count = 0;
double ref_sum = 0.0, sur_sum = 0.0;
for (uint32_t j = 0; j < size_per_prop; ++j) {
const double ref_value = reference_values[base_idx + j];
const double sur_value = surrogate_values[base_idx + j];
if (std::isnan(ref_value) || std::isnan(sur_value)) {
nan_count++;
continue;
}
valid_count++;
ref_sum += ref_value;
sur_sum += sur_value;
if (ref_value == 0.0) {
if (sur_value != 0.0) {
err_sum += 1.0;
@ -143,10 +186,45 @@ void poet::ControlModule::computeSpeciesErrors(
sqr_err_sum += alpha * alpha;
}
}
// sample printing (keeps previous behavior: species 5 and 6)
if (i == 5 || i == 6) {
std::cerr << "[CTRL SAMPLE] species_index=" << i
<< " name=" << this->species_names[i]
<< " base_idx=" << base_idx << " nan_count=" << nan_count
<< " valid_count=" << valid_count << std::endl;
uint32_t N = std::min<uint32_t>(size_per_prop, 20u);
std::cerr << "[CTRL SAMPLE] reference: ";
for (uint32_t j = 0; j < N; ++j)
std::cerr << reference_values[base_idx + j]
<< (j + 1 == N ? "\n" : " ");
std::cerr << "[CTRL SAMPLE] surrogate: ";
for (uint32_t j = 0; j < N; ++j)
std::cerr << surrogate_values[base_idx + j]
<< (j + 1 == N ? "\n" : " ");
}
species_error_stats.mape[i] = 100.0 * (err_sum / size_per_prop);
species_error_stats.rrmse[i] =
(size_per_prop > 0) ? std::sqrt(sqr_err_sum / size_per_prop) : 0.0;
if (valid_count > 0) {
species_error_stats.mape[i] = 100.0 * (err_sum / valid_count);
species_error_stats.rrmse[i] = std::sqrt(sqr_err_sum / valid_count);
} else {
species_error_stats.mape[i] = 0.0;
species_error_stats.rrmse[i] = 0.0;
std::cerr << "[CTRL WARN] no valid samples for species " << i << " ("
<< this->species_names[i] << "), setting errors to 0\n";
}
// DEBUG: detailed diagnostics for Ba/Cl (or whichever indices)
if (this->species_names[i] == "Ba" || this->species_names[i] == "Cl") {
double mean_ref = (valid_count > 0) ? (ref_sum / valid_count) : 0.0;
double mean_sur = (valid_count > 0) ? (sur_sum / valid_count) : 0.0;
std::cerr << "[CTRL DIAG] species=" << this->species_names[i]
<< " idx=" << i << " base_idx=" << base_idx
<< " valid_count=" << valid_count << " nan_count=" << nan_count
<< " err_sum=" << err_sum << " sqr_err_sum=" << sqr_err_sum
<< " mean_ref=" << mean_ref << " mean_sur=" << mean_sur
<< " computed_MAPE=" << species_error_stats.mape[i]
<< " computed_RRMSE=" << species_error_stats.rrmse[i] << "\n";
}
}
error_history.push_back(species_error_stats);
}

View File

@ -30,6 +30,8 @@ public:
void endIteration(const uint32_t iter);
/// Store the given chemistry module pointer in `chem`.
/// The pointer is taken as-is; no null check is performed here.
void setChemistryModule(poet::ChemistryModule *c) {
  chem = c;
}
// void BCastControlFlags();
//bool triggerRollbackIfExceeded(ChemistryModule &chem,

View File

@ -650,8 +650,8 @@ int main(int argc, char *argv[]) {
init_list.getChemistryInit(), MPI_COMM_WORLD);
ControlModule control;
chemistry.setControlModule(&control);
control.setChemistryModule(&chemistry);
const ChemistryModule::SurrogateSetup surr_setup = {
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),