From 6ec98077d7a00a4537dcb457ea756b4121679994 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Max=20L=C3=BCbke?= Date: Tue, 19 Mar 2024 14:07:25 +0000 Subject: [PATCH] Update alpha matrix parsing --- src/Init/DiffusionInit.cpp | 103 ++++++++++++++++--------------------- src/Init/GridInit.cpp | 10 ++-- src/Init/InitialList.hpp | 2 +- 3 files changed, 50 insertions(+), 65 deletions(-) diff --git a/src/Init/DiffusionInit.cpp b/src/Init/DiffusionInit.cpp index c740091a5..a30b66ec6 100644 --- a/src/Init/DiffusionInit.cpp +++ b/src/Init/DiffusionInit.cpp @@ -18,7 +18,7 @@ namespace poet { -enum SEXP_TYPE { SEXP_IS_LIST, SEXP_IS_MAT, SEXP_IS_NUM }; +enum SEXP_TYPE { SEXP_IS_LIST, SEXP_IS_VEC }; const std::map tug_side_mapping = { {tug::BC_SIDE_RIGHT, "E"}, @@ -61,20 +61,38 @@ static Rcpp::List parseBoundaries2D(const Rcpp::List &boundaries_list, static inline SEXP_TYPE get_datatype(const SEXP &input) { Rcpp::Function check_list("is.list"); - Rcpp::Function check_mat("is.matrix"); if (Rcpp::as(check_list(input))) { return SEXP_IS_LIST; - } else if (Rcpp::as(check_mat(input))) { - return SEXP_IS_MAT; } else { - return SEXP_IS_NUM; + return SEXP_IS_VEC; } } -static Rcpp::List parseAlphas2D(const SEXP &input, - const std::vector &transport_names, - std::uint32_t n_cols, std::uint32_t n_rows) { +static std::vector colMajToRowMaj(const Rcpp::NumericVector &vec, + std::uint32_t n_cols, + std::uint32_t n_rows) { + if (vec.size() == 1) { + return std::vector(n_cols * n_rows, vec[0]); + } else { + if (vec.size() != n_cols * n_rows) { + throw std::runtime_error("Alpha matrix does not match grid dimensions"); + } + + std::vector alpha(n_cols * n_rows); + + for (std::uint32_t i = 0; i < n_cols; i++) { + for (std::uint32_t j = 0; j < n_rows; j++) { + alpha[i * n_rows + j] = vec[j * n_cols + i]; + } + } + return alpha; + } +} + +static Rcpp::List parseAlphas(const SEXP &input, + const std::vector &transport_names, + std::uint32_t n_cols, std::uint32_t n_rows) { Rcpp::List out_list; SEXP_TYPE input_type = get_datatype(input); @@ -93,49 +111,17 @@ static Rcpp::List parseAlphas2D(const SEXP &input, throw std::runtime_error("Alphas list does not contain transport name"); } - const Rcpp::NumericMatrix alpha_mat(input_list[name]); + const Rcpp::NumericVector &alpha_col_order_vec = input_list[name]; - if (alpha_mat.size() == 1) { - Rcpp::NumericVector alpha(n_cols * n_rows, alpha_mat(0, 0)); - out_list[name] = Rcpp::wrap(alpha); - } else { - if (alpha_mat.nrow() != n_rows || alpha_mat.ncol() != n_cols) { - throw std::runtime_error( - "Alpha matrix does not match grid dimensions"); - } - - out_list[name] = alpha_mat; - } + out_list[name] = + Rcpp::wrap(colMajToRowMaj(alpha_col_order_vec, n_cols, n_rows)); } break; } - case SEXP_IS_MAT: { - Rcpp::NumericMatrix input_mat(input); - - Rcpp::NumericVector alpha(n_cols * n_rows, input_mat(0, 0)); - - if (input_mat.size() != 1) { - if (input_mat.nrow() != n_rows || input_mat.ncol() != n_cols) { - throw std::runtime_error("Alpha matrix does not match grid dimensions"); - } - - for (std::size_t i = 0; i < n_rows; i++) { - for (std::size_t j = 0; j < n_cols; j++) { - alpha[i * n_cols + j] = input_mat(i, j); - } - } - } - + case SEXP_IS_VEC: { + const Rcpp::NumericVector alpha(input); for (const auto &name : transport_names) { - out_list[name] = alpha; - } - - break; - } - case SEXP_IS_NUM: { - Rcpp::NumericVector alpha(n_cols * n_rows, Rcpp::as(input)); - for (const auto &name : transport_names) { - out_list[name] = alpha; + out_list[name] = Rcpp::wrap(colMajToRowMaj(alpha, n_cols, n_rows)); } break; } @@ -146,24 +132,23 @@ static Rcpp::List parseAlphas2D(const SEXP &input, return out_list; } void InitialList::initDiffusion(const Rcpp::List &diffusion_input) { - const Rcpp::List &boundaries = - diffusion_input[DIFFU_MEMBER_STR(DiffusionMembers::BOUNDARIES)]; - const Rcpp::NumericVector &alpha_x = + // const Rcpp::List &boundaries = + // diffusion_input[DIFFU_MEMBER_STR(DiffusionMembers::BOUNDARIES)]; + const SEXP &alpha_x = diffusion_input[DIFFU_MEMBER_STR(DiffusionMembers::ALPHA_X)]; - const Rcpp::NumericVector &alpha_y = + const SEXP &alpha_y = diffusion_input[DIFFU_MEMBER_STR(DiffusionMembers::ALPHA_Y)]; - std::vector colnames = - Rcpp::as>(this->initial_grid.names()); - - std::vector transport_names(colnames.begin() + 1, - colnames.begin() + 1 + - this->module_sizes[POET_SOL]); - this->alpha_x = - parseAlphas2D(alpha_x, transport_names, this->n_cols, this->n_rows); + parseAlphas(alpha_x, this->transport_names, this->n_cols, this->n_rows); this->alpha_y = - parseAlphas2D(alpha_y, transport_names, this->n_cols, this->n_rows); + parseAlphas(alpha_y, this->transport_names, this->n_cols, this->n_rows); + + R["alpha_x"] = this->alpha_x; + R["alpha_y"] = this->alpha_y; + + R.parseEval("print(alpha_x)"); + R.parseEval("print(alpha_y)"); } } // namespace poet \ No newline at end of file diff --git a/src/Init/GridInit.cpp b/src/Init/GridInit.cpp index 8e5061c76..7e77a65d4 100644 --- a/src/Init/GridInit.cpp +++ b/src/Init/GridInit.cpp @@ -160,7 +160,7 @@ void InitialList::initGrid(const Rcpp::List &grid_input) { std::vector colnames = Rcpp::as>(this->initial_grid.names()); - this->to_transport = this->pqc_sol_order = std::vector( + this->transport_names = this->pqc_sol_order = std::vector( colnames.begin() + 1, colnames.begin() + 1 + this->module_sizes[POET_SOL]); @@ -171,10 +171,10 @@ void InitialList::initGrid(const Rcpp::List &grid_input) { this->pqc_raw_dumps = replaceRawKeywordIDs(phreeqc.raw_dumps()); - R["pqc_mat"] = this->phreeqc_mat; - R["grid_def"] = initial_grid; + // R["pqc_mat"] = this->phreeqc_mat; + // R["grid_def"] = initial_grid; - R.parseEval("print(pqc_mat)"); - R.parseEval("print(grid_def)"); + // R.parseEval("print(pqc_mat)"); + // R.parseEval("print(grid_def)"); } } // namespace poet \ No newline at end of file diff --git a/src/Init/InitialList.hpp b/src/Init/InitialList.hpp index bc4a4dafd..3c37f84ee 100644 --- a/src/Init/InitialList.hpp +++ b/src/Init/InitialList.hpp @@ -98,7 +98,7 @@ private: Rcpp::List alpha_x; Rcpp::List alpha_y; - std::vector to_transport; + std::vector transport_names; // Chemistry Members static constexpr const char *chemistry_key = "Chemistry";