fix: distribute species names across all processes

This commit is contained in:
Max Lübke 2024-08-29 08:35:11 +02:00
parent b125016dab
commit e25ebfffdb

View File

@ -32,6 +32,7 @@
#include <Rcpp/DataFrame.h>
#include <Rcpp/Function.h>
#include <Rcpp/vector/instantiation.h>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <mpi.h>
@ -41,6 +42,7 @@
#include "Base/argh.hpp"
#include <poet.hpp>
#include <vector>
using namespace std;
using namespace poet;
@ -403,6 +405,49 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
return profiling;
}
std::vector<std::string> getSpeciesNames(const Field &&field, int root,
MPI_Comm comm) {
std::uint32_t n_elements;
std::uint32_t n_string_size;
int rank;
MPI_Comm_rank(comm, &rank);
const bool is_master = root == rank;
// first, the master sends all the species names iterative
if (is_master) {
n_elements = field.GetProps().size();
MPI_Bcast(&n_elements, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
for (std::uint32_t i = 0; i < n_elements; i++) {
n_string_size = field.GetProps()[i].size();
MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
MPI_Bcast(const_cast<char *>(field.GetProps()[i].c_str()), n_string_size,
MPI_CHAR, root, MPI_COMM_WORLD);
}
return field.GetProps();
}
// now all the worker stuff
MPI_Bcast(&n_elements, 1, MPI_UINT32_T, root, comm);
std::vector<std::string> species_names_out(n_elements);
for (std::uint32_t i = 0; i < n_elements; i++) {
MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
char recv_buf[n_string_size];
MPI_Bcast(recv_buf, n_string_size, MPI_CHAR, root, MPI_COMM_WORLD);
species_names_out[i] = std::string(recv_buf, n_string_size);
}
return species_names_out;
}
int main(int argc, char *argv[]) {
int world_size;
@ -442,7 +487,7 @@ int main(int argc, char *argv[]) {
init_list.getChemistryInit(), MPI_COMM_WORLD);
const ChemistryModule::SurrogateSetup surr_setup = {
init_list.getInitialGrid().GetProps(),
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
run_params.use_dht,
run_params.dht_size,
run_params.use_interp,