refactor: Rework deferred R function evaluation

fix: Unique pointer behaviour of `global_rt_setup` was messed up
This commit is contained in:
Max Lübke 2024-06-12 09:37:36 +02:00 committed by Marco De Lucia
parent fec92ad3d3
commit 9122e51980
7 changed files with 76 additions and 57 deletions

View File

@ -1,17 +1,13 @@
#ifndef RPOET_H_ #pragma once
#define RPOET_H_
#include <RInside.h> #include <RInside.h>
#include <Rcpp.h> #include <Rcpp.h>
#include <Rinternals.h> #include <Rinternals.h>
#include <cstddef>
#include <exception> #include <exception>
#include <optional> #include <memory>
#include <stdexcept>
#include <string> #include <string>
#include <utility>
#include <vector>
namespace poet {
class RInsidePOET : public RInside { class RInsidePOET : public RInside {
public: public:
static RInsidePOET &getInstance() { static RInsidePOET &getInstance() {
@ -33,44 +29,64 @@ private:
RInsidePOET() : RInside(){}; RInsidePOET() : RInside(){};
}; };
template <typename T> class RHookFunction { /**
* @brief Deferred evaluation function
*
* The class is intended to call R functions within an existing RInside
* instance. The problem with "original" Rcpp::Function is that they require:
* 1. RInside instance already present, restricting the declaration of
* Rcpp::Functions in global scope
* 2. Require the function to be present. Otherwise, they will throw an
* exception.
* This class solves both problems by deferring the evaluation of the function
* until the constructor is called and evaluating whether the function is
* present or not, without throwing an exception.
*
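* @note DEFunc is not a template: operator() returns the raw SEXP result, and
* the caller converts it to the desired C++ type with Rcpp::as<T>().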
*/
class DEFunc {
public: public:
RHookFunction() {} DEFunc() {}
RHookFunction(RInside &R, const std::string &f_name) { DEFunc(const std::string &f_name) {
try { try {
this->func = Rcpp::Function(Rcpp::as<SEXP>(R.parseEval(f_name.c_str()))); this->func = std::make_shared<Rcpp::Function>(f_name);
} catch (const std::exception &e) { } catch (const std::exception &e) {
} }
} }
RHookFunction(SEXP f) { DEFunc(SEXP f) {
try { try {
this->func = Rcpp::Function(f); this->func = std::make_shared<Rcpp::Function>(f);
} catch (const std::exception &e) { } catch (const std::exception &e) {
} }
} }
template <typename... Args> T operator()(Args... args) const { template <typename... Args> SEXP operator()(Args... args) const {
if (func.has_value()) { if (func) {
return (Rcpp::as<T>(this->func.value()(args...))); return (*this->func)(args...);
} else { } else {
throw std::exception(); throw std::exception();
} }
} }
RHookFunction &operator=(const RHookFunction &rhs) { DEFunc &operator=(const DEFunc &rhs) {
this->func = rhs.func; this->func = rhs.func;
return *this; return *this;
} }
RHookFunction(const RHookFunction &rhs) { this->func = rhs.func; } DEFunc(const DEFunc &rhs) { this->func = rhs.func; }
bool isValid() const { return this->func.has_value(); } bool isValid() const { return static_cast<bool>(func); }
SEXP asSEXP() const { return Rcpp::as<SEXP>(this->func.value()); } SEXP asSEXP() const {
if (!func) {
return R_NilValue;
}
return Rcpp::as<SEXP>(*this->func.get());
}
private: private:
std::optional<Rcpp::Function> func; std::shared_ptr<Rcpp::Function> func;
}; };
#endif // RPOET_H_ } // namespace poet

View File

@ -25,6 +25,7 @@
#include "Init/InitialList.hpp" #include "Init/InitialList.hpp"
#include "Rounding.hpp" #include "Rounding.hpp"
#include <Rcpp/proxy/ProtectedProxy.h>
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
@ -267,7 +268,8 @@ LookupKey DHT_Wrapper::fuzzForDHT_R(const std::vector<double> &cell,
NamedVector<double> input_nv(this->output_names, cell); NamedVector<double> input_nv(this->output_names, cell);
const std::vector<double> eval_vec = hooks.dht_fuzz(input_nv); const std::vector<double> eval_vec =
Rcpp::as<std::vector<double>>(hooks.dht_fuzz(input_nv));
assert(eval_vec.size() == this->key_count); assert(eval_vec.size() == this->key_count);
LookupKey vecFuzz(this->key_count + 1, {.0}); LookupKey vecFuzz(this->key_count + 1, {.0});

View File

@ -9,6 +9,7 @@
#include "Rounding.hpp" #include "Rounding.hpp"
#include <Rcpp.h> #include <Rcpp.h>
#include <Rcpp/proxy/ProtectedProxy.h>
#include <Rinternals.h> #include <Rinternals.h>
#include <algorithm> #include <algorithm>
@ -94,7 +95,8 @@ void InterpolationModule::tryInterpolation(WorkPackage &work_package) {
if (hooks.interp_pre.isValid()) { if (hooks.interp_pre.isValid()) {
NamedVector<double> nv_in(this->out_names, work_package.input[wp_i]); NamedVector<double> nv_in(this->out_names, work_package.input[wp_i]);
auto rm_indices = hooks.interp_pre(nv_in, pht_result.in_values); std::vector<int> rm_indices = Rcpp::as<std::vector<int>>(
hooks.interp_pre(nv_in, pht_result.in_values));
pht_result.size -= rm_indices.size(); pht_result.size -= rm_indices.size();

View File

@ -215,10 +215,10 @@ private:
public: public:
struct ChemistryHookFunctions { struct ChemistryHookFunctions {
RHookFunction<bool> dht_fill; poet::DEFunc dht_fill;
RHookFunction<std::vector<double>> dht_fuzz; poet::DEFunc dht_fuzz;
RHookFunction<std::vector<std::size_t>> interp_pre; poet::DEFunc interp_pre;
RHookFunction<bool> interp_post; poet::DEFunc interp_post;
}; };
struct ChemistryInit { struct ChemistryInit {

View File

@ -4,7 +4,8 @@
** **
** Copyright (C) 2018-2022 Marco De Lucia, Max Luebke (GFZ Potsdam) ** Copyright (C) 2018-2022 Marco De Lucia, Max Luebke (GFZ Potsdam)
** **
** Copyright (C) 2023-2024 Max Luebke (University of Potsdam) ** Copyright (C) 2023-2024 Marco De Lucia (GFZ Potsdam), Max Luebke (University
** of Potsdam)
** **
** POET is free software; you can redistribute it and/or modify it under the ** POET is free software; you can redistribute it and/or modify it under the
** terms of the GNU General Public License as published by the Free Software ** terms of the GNU General Public License as published by the Free Software
@ -36,7 +37,6 @@
#include <cstdlib> #include <cstdlib>
#include <memory> #include <memory>
#include <mpi.h> #include <mpi.h>
#include <optional>
#include <string> #include <string>
#include "Base/argh.hpp" #include "Base/argh.hpp"
@ -54,21 +54,21 @@ static std::unique_ptr<Rcpp::List> global_rt_setup;
// we need some lazy evaluation, as we can't define the functions // we need some lazy evaluation, as we can't define the functions
// before the R runtime is initialized // before the R runtime is initialized
static std::optional<Rcpp::Function> master_init_R; static poet::DEFunc master_init_R;
static std::optional<Rcpp::Function> master_iteration_end_R; static poet::DEFunc master_iteration_end_R;
static std::optional<Rcpp::Function> store_setup_R; static poet::DEFunc store_setup_R;
static std::optional<Rcpp::Function> ReadRObj_R; static poet::DEFunc ReadRObj_R;
static std::optional<Rcpp::Function> SaveRObj_R; static poet::DEFunc SaveRObj_R;
static std::optional<Rcpp::Function> source_R; static poet::DEFunc source_R;
static void init_global_functions(RInside &R) { static void init_global_functions(RInside &R) {
R.parseEval(kin_r_library); R.parseEval(kin_r_library);
master_init_R = Rcpp::Function("master_init"); master_init_R = DEFunc("master_init");
master_iteration_end_R = Rcpp::Function("master_iteration_end"); master_iteration_end_R = DEFunc("master_iteration_end");
store_setup_R = Rcpp::Function("StoreSetup"); store_setup_R = DEFunc("StoreSetup");
source_R = Rcpp::Function("source"); source_R = DEFunc("source");
ReadRObj_R = Rcpp::Function("ReadRObj"); ReadRObj_R = DEFunc("ReadRObj");
SaveRObj_R = Rcpp::Function("SaveRObj"); SaveRObj_R = DEFunc("SaveRObj");
} }
// HACK: this is a step back as the order and also the count of fields is // HACK: this is a step back as the order and also the count of fields is
@ -224,12 +224,12 @@ ParseRet parseInitValues(char **argv, RuntimeParameters &params) {
// Rcpp::Function ReadRObj("ReadRObj"); // Rcpp::Function ReadRObj("ReadRObj");
// Rcpp::Function SaveRObj("SaveRObj"); // Rcpp::Function SaveRObj("SaveRObj");
Rcpp::List init_params_ = ReadRObj_R.value()(init_file); Rcpp::List init_params_(ReadRObj_R(init_file));
params.init_params = init_params_; params.init_params = init_params_;
global_rt_setup = std::make_unique<Rcpp::List>(); global_rt_setup = std::make_unique<Rcpp::List>(
*global_rt_setup = source_R.value()(runtime_file, Rcpp::Named("local", true)); source_R(runtime_file, Rcpp::Named("local", true)));
*global_rt_setup = global_rt_setup->operator[]("value"); *global_rt_setup = (*global_rt_setup)["value"];
// MDL add "out_ext" for output format to R setup // MDL add "out_ext" for output format to R setup
(*global_rt_setup)["out_ext"] = params.out_ext; (*global_rt_setup)["out_ext"] = params.out_ext;
@ -524,9 +524,8 @@ int main(int argc, char *argv[]) {
// R.parseEvalQ("mysetup <- setup"); // R.parseEvalQ("mysetup <- setup");
// // if (MY_RANK == 0) { // get timestep vector from // // if (MY_RANK == 0) { // get timestep vector from
// // grid_init function ... // // // grid_init function ... //
*global_rt_setup = *global_rt_setup = master_init_R(*global_rt_setup, run_params.out_dir,
master_init_R.value()(*global_rt_setup, run_params.out_dir, init_list.getInitialGrid().asSEXP());
init_list.getInitialGrid().asSEXP());
// MDL: store all parameters // MDL: store all parameters
// MSG("Calling R Function to store calling parameters"); // MSG("Calling R Function to store calling parameters");
// R.parseEvalQ("StoreSetup(setup=mysetup)"); // R.parseEvalQ("StoreSetup(setup=mysetup)");

View File

@ -89,14 +89,14 @@ TEST_CASE("Field") {
} }
SUBCASE("Apply R function (set Na to zero)") { SUBCASE("Apply R function (set Na to zero)") {
RHookFunction<Field> to_call(R, "simple_field"); poet::DEFunc to_call("simple_field");
Field field_proc = to_call(dut.asSEXP()); Field field_proc = to_call(dut.asSEXP());
CHECK_EQ(field_proc["Na"], FieldColumn(dut.GetRequestedVecSize(), 0)); CHECK_EQ(field_proc["Na"], FieldColumn(dut.GetRequestedVecSize(), 0));
} }
SUBCASE("Apply R function (add two fields)") { SUBCASE("Apply R function (add two fields)") {
RHookFunction<Field> to_call(R, "extended_field"); poet::DEFunc to_call("extended_field");
Field field_proc = to_call(dut.asSEXP(), dut.asSEXP()); Field field_proc = to_call(dut.asSEXP(), dut.asSEXP());
CHECK_EQ(field_proc["Na"], CHECK_EQ(field_proc["Na"],

View File

@ -9,7 +9,7 @@
#include "testDataStructures.hpp" #include "testDataStructures.hpp"
TEST_CASE("NamedVector") { TEST_CASE("NamedVector") {
RInsidePOET &R = RInsidePOET::getInstance(); poet::RInsidePOET &R = poet::RInsidePOET::getInstance();
R["sourcefile"] = RInside_source_file; R["sourcefile"] = RInside_source_file;
R.parseEval("source(sourcefile)"); R.parseEval("source(sourcefile)");
@ -36,14 +36,14 @@ TEST_CASE("NamedVector") {
} }
SUBCASE("Apply R function (set to zero)") { SUBCASE("Apply R function (set to zero)") {
RHookFunction<poet::NamedVector<double>> to_call(R, "simple_named_vec"); poet::DEFunc to_call("simple_named_vec");
nv = to_call(nv); nv = to_call(nv);
CHECK_EQ(nv[2], 0); CHECK_EQ(nv[2], 0);
} }
SUBCASE("Apply R function (second NamedVector)") { SUBCASE("Apply R function (second NamedVector)") {
RHookFunction<poet::NamedVector<double>> to_call(R, "extended_named_vec"); poet::DEFunc to_call("extended_named_vec");
const std::vector<std::string> names{{"C", "H", "Mg"}}; const std::vector<std::string> names{{"C", "H", "Mg"}};
const std::vector<double> values{{0, 1, 2}}; const std::vector<double> values{{0, 1, 2}};
@ -56,8 +56,8 @@ TEST_CASE("NamedVector") {
} }
SUBCASE("Apply R function (check if zero)") { SUBCASE("Apply R function (check if zero)") {
RHookFunction<bool> to_call(R, "bool_named_vec"); poet::DEFunc to_call("bool_named_vec");
CHECK_FALSE(to_call(nv)); CHECK_FALSE(Rcpp::as<bool>(to_call(nv)));
} }
} }