Fix cell_ID access in DHT/interpolation and restructure control cell metric computation to use row-major layout

This commit is contained in:
rastogi 2025-11-07 15:55:26 +01:00
parent 23d0cc2dd8
commit f15f9049b8
7 changed files with 29 additions and 30 deletions

Binary file not shown.

View File

@ -520,7 +520,7 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
std::vector<std::vector<double>> surrogate_batch;
surrogate_batch.reserve(this->control_batch.size());
for (const auto &element : this->control_batch) {
for (const auto &element : this->control_batch) {
for (size_t i = 0; i < this->n_cells; i++) {
uint32_t curr_cell_id = mpi_buffer[this->prop_count * i];
@ -536,7 +536,8 @@ void poet::ChemistryModule::MasterRunParallel(double dt) {
}
metrics_a = MPI_Wtime();
control_module->computeSpeciesErrorMetrics(this->control_batch, surrogate_batch, 1);
control_module->computeSpeciesErrorMetrics(this->control_batch,
surrogate_batch, 1);
metrics_b = MPI_Wtime();
this->metrics_t += metrics_b - metrics_a;

View File

@ -133,7 +133,7 @@ void DHT_Wrapper::fillDHT(const WorkPackage &work_package) {
continue;
}
if (work_package.input[i][0] != 2) {
if (work_package.input[i][1] != 2) {
continue;
}

View File

@ -76,7 +76,7 @@ void InterpolationModule::tryInterpolation(WorkPackage &work_package) {
const auto dht_results = this->dht_instance.getDHTResults();
for (int wp_i = 0; wp_i < work_package.size; wp_i++) {
if (work_package.input[wp_i][0] != 2) {
if (work_package.input[wp_i][1] != 2) {
interp_result.status[wp_i] = INSUFFICIENT_DATA;
continue;
}
@ -122,7 +122,7 @@ void InterpolationModule::tryInterpolation(WorkPackage &work_package) {
this->pht->incrementReadCounter(roundKey(rounded_key));
#endif
const int cell_id = static_cast<int>(work_package.input[wp_i][0]);
const int cell_id = static_cast<int>(work_package.input[wp_i][1]);
if (!to_calc_cache.contains(cell_id)) {
const std::vector<std::int32_t> &to_calc = dht_instance.getKeyElements();

View File

@ -133,6 +133,8 @@ void poet::ChemistryModule::ProcessControlWorkPackage(
WorkerRunWorkPackage(control_wp, current_sim_time, dt);
phreeqc_end = MPI_Wtime();
std::cout << "PQC RAN" << std::endl;
timings.ctrl_phreeqc_t += phreeqc_end - phreeqc_start;
for (std::size_t wp_i = 0; wp_i < control_wp.size; wp_i++) {
@ -240,8 +242,11 @@ void poet::ChemistryModule::WorkerDoWork(MPI_Status &probe_status,
for (std::size_t wp_i = 0; wp_i < s_curr_wp.size; wp_i++) {
uint32_t cell_id = s_curr_wp.input[wp_i][0];
if (this->ctrl_cell_ids.find(cell_id) != this->ctrl_cell_ids.end() &&
s_curr_wp.mapping[wp_i] != CHEM_PQC) {
bool is_control_cell = this->ctrl_cell_ids.find(cell_id) != this->ctrl_cell_ids.end();
bool used_surrogate = s_curr_wp.mapping[wp_i] != CHEM_PQC;
if (is_control_cell && used_surrogate) {
control_batch.push_back(s_curr_wp.input[wp_i]);
control_cells_processed++;

View File

@ -56,10 +56,9 @@ void poet::ControlModule::applyControlLogic(DiffusionModule &diffusion,
rollback_enabled = true;
rollback_count++;
sur_disabled_counter = penalty_interval;
MSG("Interpolation disabled for the next " +
std::to_string(penalty_interval) + ".");
}
}
@ -138,16 +137,14 @@ void poet::ControlModule::computeSpeciesErrorMetrics(
return;
}
// Loop over species (rows in the data structure)
for (size_t species_idx = 0; species_idx < reference_values.size(); species_idx++) {
for (size_t row = 0; row < reference_values.size(); row++) {
double err_sum = 0.0;
double sqr_err_sum = 0.0;
uint32_t count = 0;
// Loop over control cells (columns in the data structure)
for (size_t cell_idx = 0; cell_idx < size_per_prop; cell_idx++) {
const double ref_value = reference_values[species_idx][cell_idx];
const double sur_value = surrogate_values[species_idx][cell_idx];
for (size_t col = 0; col < this->species_names.size(); col++) {
const double ref_value = reference_values[row][col];
const double sur_value = surrogate_values[row][col];
const double ZERO_ABS = 1e-13;
if (std::isnan(ref_value) || std::isnan(sur_value)) {
@ -160,26 +157,22 @@ void poet::ControlModule::computeSpeciesErrorMetrics(
sqr_err_sum += 1.0;
count++;
}
// Both zero: skip (don't increment count)
}
else {
// Both zero: skip
} else {
double alpha = 1.0 - (sur_value / ref_value);
err_sum += std::abs(alpha);
sqr_err_sum += alpha * alpha;
count++;
}
// Store metrics for this species after processing all cells
if (count > 0) {
metrics.mape[col] = 100.0 * (err_sum / size_per_prop);
metrics.rrmse[col] = std::sqrt(sqr_err_sum / size_per_prop);
} else {
metrics.mape[col] = 0.0;
metrics.rrmse[col] = 0.0;
}
}
// Store metrics for this species after processing all cells
if (count > 0) {
metrics.mape[species_idx] = 100.0 * (err_sum / size_per_prop);
metrics.rrmse[species_idx] = std::sqrt(sqr_err_sum / size_per_prop);
} else {
metrics.mape[species_idx] = 0.0;
metrics.rrmse[species_idx] = 0.0;
}
metricsHistory.push_back(metrics);
}
// Push metrics to history once after processing all species
metricsHistory.push_back(metrics);
}