Update alpha matrix parsing

This commit is contained in:
Max Lübke 2024-03-19 14:07:25 +00:00
parent badb01b4fe
commit ebfa10c236
3 changed files with 50 additions and 65 deletions

View File

@ -18,7 +18,7 @@
namespace poet {
enum SEXP_TYPE { SEXP_IS_LIST, SEXP_IS_MAT, SEXP_IS_NUM };
enum SEXP_TYPE { SEXP_IS_LIST, SEXP_IS_VEC };
const std::map<std::uint8_t, std::string> tug_side_mapping = {
{tug::BC_SIDE_RIGHT, "E"},
@ -61,20 +61,38 @@ static Rcpp::List parseBoundaries2D(const Rcpp::List &boundaries_list,
static inline SEXP_TYPE get_datatype(const SEXP &input) {
Rcpp::Function check_list("is.list");
Rcpp::Function check_mat("is.matrix");
if (Rcpp::as<bool>(check_list(input))) {
return SEXP_IS_LIST;
} else if (Rcpp::as<bool>(check_mat(input))) {
return SEXP_IS_MAT;
} else {
return SEXP_IS_NUM;
return SEXP_IS_VEC;
}
}
static Rcpp::List parseAlphas2D(const SEXP &input,
const std::vector<std::string> &transport_names,
std::uint32_t n_cols, std::uint32_t n_rows) {
static std::vector<TugType> colMajToRowMaj(const Rcpp::NumericVector &vec,
std::uint32_t n_cols,
std::uint32_t n_rows) {
if (vec.size() == 1) {
return std::vector<TugType>(n_cols * n_rows, vec[0]);
} else {
if (vec.size() != n_cols * n_rows) {
throw std::runtime_error("Alpha matrix does not match grid dimensions");
}
std::vector<TugType> alpha(n_cols * n_rows);
for (std::uint32_t i = 0; i < n_cols; i++) {
for (std::uint32_t j = 0; j < n_rows; j++) {
alpha[i * n_rows + j] = vec[j * n_cols + i];
}
}
return alpha;
}
}
static Rcpp::List parseAlphas(const SEXP &input,
const std::vector<std::string> &transport_names,
std::uint32_t n_cols, std::uint32_t n_rows) {
Rcpp::List out_list;
SEXP_TYPE input_type = get_datatype(input);
@ -93,49 +111,17 @@ static Rcpp::List parseAlphas2D(const SEXP &input,
throw std::runtime_error("Alphas list does not contain transport name");
}
const Rcpp::NumericMatrix alpha_mat(input_list[name]);
const Rcpp::NumericVector &alpha_col_order_vec = input_list[name];
if (alpha_mat.size() == 1) {
Rcpp::NumericVector alpha(n_cols * n_rows, alpha_mat(0, 0));
out_list[name] = Rcpp::wrap(alpha);
} else {
if (alpha_mat.nrow() != n_rows || alpha_mat.ncol() != n_cols) {
throw std::runtime_error(
"Alpha matrix does not match grid dimensions");
}
out_list[name] = alpha_mat;
}
out_list[name] =
Rcpp::wrap(colMajToRowMaj(alpha_col_order_vec, n_cols, n_rows));
}
break;
}
case SEXP_IS_MAT: {
Rcpp::NumericMatrix input_mat(input);
Rcpp::NumericVector alpha(n_cols * n_rows, input_mat(0, 0));
if (input_mat.size() != 1) {
if (input_mat.nrow() != n_rows || input_mat.ncol() != n_cols) {
throw std::runtime_error("Alpha matrix does not match grid dimensions");
}
for (std::size_t i = 0; i < n_rows; i++) {
for (std::size_t j = 0; j < n_cols; j++) {
alpha[i * n_cols + j] = input_mat(i, j);
}
}
}
case SEXP_IS_VEC: {
const Rcpp::NumericVector alpha(input);
for (const auto &name : transport_names) {
out_list[name] = alpha;
}
break;
}
case SEXP_IS_NUM: {
Rcpp::NumericVector alpha(n_cols * n_rows, Rcpp::as<TugType>(input));
for (const auto &name : transport_names) {
out_list[name] = alpha;
out_list[name] = Rcpp::wrap(colMajToRowMaj(alpha, n_cols, n_rows));
}
break;
}
@ -146,24 +132,23 @@ static Rcpp::List parseAlphas2D(const SEXP &input,
return out_list;
}
void InitialList::initDiffusion(const Rcpp::List &diffusion_input) {
const Rcpp::List &boundaries =
diffusion_input[DIFFU_MEMBER_STR(DiffusionMembers::BOUNDARIES)];
const Rcpp::NumericVector &alpha_x =
// const Rcpp::List &boundaries =
// diffusion_input[DIFFU_MEMBER_STR(DiffusionMembers::BOUNDARIES)];
const SEXP &alpha_x =
diffusion_input[DIFFU_MEMBER_STR(DiffusionMembers::ALPHA_X)];
const Rcpp::NumericVector &alpha_y =
const SEXP &alpha_y =
diffusion_input[DIFFU_MEMBER_STR(DiffusionMembers::ALPHA_Y)];
std::vector<std::string> colnames =
Rcpp::as<std::vector<std::string>>(this->initial_grid.names());
std::vector<std::string> transport_names(colnames.begin() + 1,
colnames.begin() + 1 +
this->module_sizes[POET_SOL]);
this->alpha_x =
parseAlphas2D(alpha_x, transport_names, this->n_cols, this->n_rows);
parseAlphas(alpha_x, this->transport_names, this->n_cols, this->n_rows);
this->alpha_y =
parseAlphas2D(alpha_y, transport_names, this->n_cols, this->n_rows);
parseAlphas(alpha_y, this->transport_names, this->n_cols, this->n_rows);
R["alpha_x"] = this->alpha_x;
R["alpha_y"] = this->alpha_y;
R.parseEval("print(alpha_x)");
R.parseEval("print(alpha_y)");
}
} // namespace poet

View File

@ -160,7 +160,7 @@ void InitialList::initGrid(const Rcpp::List &grid_input) {
std::vector<std::string> colnames =
Rcpp::as<std::vector<std::string>>(this->initial_grid.names());
this->to_transport = this->pqc_sol_order = std::vector<std::string>(
this->transport_names = this->pqc_sol_order = std::vector<std::string>(
colnames.begin() + 1,
colnames.begin() + 1 + this->module_sizes[POET_SOL]);
@ -171,10 +171,10 @@ void InitialList::initGrid(const Rcpp::List &grid_input) {
this->pqc_raw_dumps = replaceRawKeywordIDs(phreeqc.raw_dumps());
R["pqc_mat"] = this->phreeqc_mat;
R["grid_def"] = initial_grid;
// R["pqc_mat"] = this->phreeqc_mat;
// R["grid_def"] = initial_grid;
R.parseEval("print(pqc_mat)");
R.parseEval("print(grid_def)");
// R.parseEval("print(pqc_mat)");
// R.parseEval("print(grid_def)");
}
} // namespace poet

View File

@ -98,7 +98,7 @@ private:
Rcpp::List alpha_x;
Rcpp::List alpha_y;
std::vector<std::string> to_transport;
std::vector<std::string> transport_names;
// Chemistry Members
static constexpr const char *chemistry_key = "Chemistry";