Merge branch 'ml/fix-dht' into 'main'

fix: Error during DHT usage

See merge request naaice/poet!32
This commit is contained in:
Max Lübke 2024-08-29 08:47:45 +02:00
commit b037359db9
2 changed files with 47 additions and 2 deletions

@ -1 +1 @@
Subproject commit e6e5e0d5156c093241a53e6ce074ef346d64ae26
Subproject commit 48e65d87ad70f84aec01c27d9560cd3094a8129c

View File

@ -32,6 +32,7 @@
#include <Rcpp/DataFrame.h>
#include <Rcpp/Function.h>
#include <Rcpp/vector/instantiation.h>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <mpi.h>
@ -41,6 +42,7 @@
#include "Base/argh.hpp"
#include <poet.hpp>
#include <vector>
using namespace std;
using namespace poet;
@ -403,6 +405,49 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
return profiling;
}
std::vector<std::string> getSpeciesNames(const Field &&field, int root,
MPI_Comm comm) {
std::uint32_t n_elements;
std::uint32_t n_string_size;
int rank;
MPI_Comm_rank(comm, &rank);
const bool is_master = root == rank;
// first, the master sends all the species names iterative
if (is_master) {
n_elements = field.GetProps().size();
MPI_Bcast(&n_elements, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
for (std::uint32_t i = 0; i < n_elements; i++) {
n_string_size = field.GetProps()[i].size();
MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
MPI_Bcast(const_cast<char *>(field.GetProps()[i].c_str()), n_string_size,
MPI_CHAR, root, MPI_COMM_WORLD);
}
return field.GetProps();
}
// now all the worker stuff
MPI_Bcast(&n_elements, 1, MPI_UINT32_T, root, comm);
std::vector<std::string> species_names_out(n_elements);
for (std::uint32_t i = 0; i < n_elements; i++) {
MPI_Bcast(&n_string_size, 1, MPI_UINT32_T, root, MPI_COMM_WORLD);
char recv_buf[n_string_size];
MPI_Bcast(recv_buf, n_string_size, MPI_CHAR, root, MPI_COMM_WORLD);
species_names_out[i] = std::string(recv_buf, n_string_size);
}
return species_names_out;
}
int main(int argc, char *argv[]) {
int world_size;
@ -442,7 +487,7 @@ int main(int argc, char *argv[]) {
init_list.getChemistryInit(), MPI_COMM_WORLD);
const ChemistryModule::SurrogateSetup surr_setup = {
init_list.getInitialGrid().GetProps(),
getSpeciesNames(init_list.getInitialGrid(), 0, MPI_COMM_WORLD),
run_params.use_dht,
run_params.dht_size,
run_params.use_interp,