refactor: Rework deferred R function evaluation

fix: Unique pointer behaviour of `global_rt_setup` was messed up
This commit is contained in:
Max Lübke 2024-06-12 09:37:36 +02:00 committed by Marco De Lucia
parent fec92ad3d3
commit 9122e51980
7 changed files with 76 additions and 57 deletions

View File

@ -1,17 +1,13 @@
#ifndef RPOET_H_
#define RPOET_H_
#pragma once
#include <RInside.h>
#include <Rcpp.h>
#include <Rinternals.h>
#include <cstddef>
#include <exception>
#include <optional>
#include <stdexcept>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace poet {
class RInsidePOET : public RInside {
public:
static RInsidePOET &getInstance() {
@ -33,44 +29,64 @@ private:
RInsidePOET() : RInside(){};
};
template <typename T> class RHookFunction {
/**
* @brief Deferred evaluation function
*
* The class is intended to call R functions within an existing RInside
* instance. The problem with "original" Rcpp::Function is that they require:
* 1. RInside instance already present, restricting the declaration of
* Rcpp::Functions in global scope
* 2. Require the function to be present. Otherwise, they will throw an
* exception.
* This class solves both problems by deferring the evaluation of the function
* until the constructor is called and evaluating whether the function is
* present or not, wihout throwing an exception.
*
* @tparam T Return type of the function
*/
class DEFunc {
public:
RHookFunction() {}
RHookFunction(RInside &R, const std::string &f_name) {
DEFunc() {}
DEFunc(const std::string &f_name) {
try {
this->func = Rcpp::Function(Rcpp::as<SEXP>(R.parseEval(f_name.c_str())));
this->func = std::make_shared<Rcpp::Function>(f_name);
} catch (const std::exception &e) {
}
}
RHookFunction(SEXP f) {
DEFunc(SEXP f) {
try {
this->func = Rcpp::Function(f);
this->func = std::make_shared<Rcpp::Function>(f);
} catch (const std::exception &e) {
}
}
template <typename... Args> T operator()(Args... args) const {
if (func.has_value()) {
return (Rcpp::as<T>(this->func.value()(args...)));
template <typename... Args> SEXP operator()(Args... args) const {
if (func) {
return (*this->func)(args...);
} else {
throw std::exception();
}
}
RHookFunction &operator=(const RHookFunction &rhs) {
DEFunc &operator=(const DEFunc &rhs) {
this->func = rhs.func;
return *this;
}
RHookFunction(const RHookFunction &rhs) { this->func = rhs.func; }
DEFunc(const DEFunc &rhs) { this->func = rhs.func; }
bool isValid() const { return this->func.has_value(); }
bool isValid() const { return static_cast<bool>(func); }
SEXP asSEXP() const { return Rcpp::as<SEXP>(this->func.value()); }
SEXP asSEXP() const {
if (!func) {
return R_NilValue;
}
return Rcpp::as<SEXP>(*this->func.get());
}
private:
std::optional<Rcpp::Function> func;
std::shared_ptr<Rcpp::Function> func;
};
#endif // RPOET_H_
} // namespace poet

View File

@ -25,6 +25,7 @@
#include "Init/InitialList.hpp"
#include "Rounding.hpp"
#include <Rcpp/proxy/ProtectedProxy.h>
#include <algorithm>
#include <cassert>
#include <cmath>
@ -267,7 +268,8 @@ LookupKey DHT_Wrapper::fuzzForDHT_R(const std::vector<double> &cell,
NamedVector<double> input_nv(this->output_names, cell);
const std::vector<double> eval_vec = hooks.dht_fuzz(input_nv);
const std::vector<double> eval_vec =
Rcpp::as<std::vector<double>>(hooks.dht_fuzz(input_nv));
assert(eval_vec.size() == this->key_count);
LookupKey vecFuzz(this->key_count + 1, {.0});

View File

@ -9,6 +9,7 @@
#include "Rounding.hpp"
#include <Rcpp.h>
#include <Rcpp/proxy/ProtectedProxy.h>
#include <Rinternals.h>
#include <algorithm>
@ -94,7 +95,8 @@ void InterpolationModule::tryInterpolation(WorkPackage &work_package) {
if (hooks.interp_pre.isValid()) {
NamedVector<double> nv_in(this->out_names, work_package.input[wp_i]);
auto rm_indices = hooks.interp_pre(nv_in, pht_result.in_values);
std::vector<int> rm_indices = Rcpp::as<std::vector<int>>(
hooks.interp_pre(nv_in, pht_result.in_values));
pht_result.size -= rm_indices.size();

View File

@ -215,10 +215,10 @@ private:
public:
struct ChemistryHookFunctions {
RHookFunction<bool> dht_fill;
RHookFunction<std::vector<double>> dht_fuzz;
RHookFunction<std::vector<std::size_t>> interp_pre;
RHookFunction<bool> interp_post;
poet::DEFunc dht_fill;
poet::DEFunc dht_fuzz;
poet::DEFunc interp_pre;
poet::DEFunc interp_post;
};
struct ChemistryInit {

View File

@ -4,7 +4,8 @@
**
** Copyright (C) 2018-2022 Marco De Lucia, Max Luebke (GFZ Potsdam)
**
** Copyright (C) 2023-2024 Max Luebke (University of Potsdam)
** Copyright (C) 2023-2024 Marco De Lucia (GFZ Potsdam), Max Luebke (University
** of Potsdam)
**
** POET is free software; you can redistribute it and/or modify it under the
** terms of the GNU General Public License as published by the Free Software
@ -36,7 +37,6 @@
#include <cstdlib>
#include <memory>
#include <mpi.h>
#include <optional>
#include <string>
#include "Base/argh.hpp"
@ -54,21 +54,21 @@ static std::unique_ptr<Rcpp::List> global_rt_setup;
// we need some lazy evaluation, as we can't define the functions
// before the R runtime is initialized
static std::optional<Rcpp::Function> master_init_R;
static std::optional<Rcpp::Function> master_iteration_end_R;
static std::optional<Rcpp::Function> store_setup_R;
static std::optional<Rcpp::Function> ReadRObj_R;
static std::optional<Rcpp::Function> SaveRObj_R;
static std::optional<Rcpp::Function> source_R;
static poet::DEFunc master_init_R;
static poet::DEFunc master_iteration_end_R;
static poet::DEFunc store_setup_R;
static poet::DEFunc ReadRObj_R;
static poet::DEFunc SaveRObj_R;
static poet::DEFunc source_R;
static void init_global_functions(RInside &R) {
R.parseEval(kin_r_library);
master_init_R = Rcpp::Function("master_init");
master_iteration_end_R = Rcpp::Function("master_iteration_end");
store_setup_R = Rcpp::Function("StoreSetup");
source_R = Rcpp::Function("source");
ReadRObj_R = Rcpp::Function("ReadRObj");
SaveRObj_R = Rcpp::Function("SaveRObj");
master_init_R = DEFunc("master_init");
master_iteration_end_R = DEFunc("master_iteration_end");
store_setup_R = DEFunc("StoreSetup");
source_R = DEFunc("source");
ReadRObj_R = DEFunc("ReadRObj");
SaveRObj_R = DEFunc("SaveRObj");
}
// HACK: this is a step back as the order and also the count of fields is
@ -224,12 +224,12 @@ ParseRet parseInitValues(char **argv, RuntimeParameters &params) {
// Rcpp::Function ReadRObj("ReadRObj");
// Rcpp::Function SaveRObj("SaveRObj");
Rcpp::List init_params_ = ReadRObj_R.value()(init_file);
Rcpp::List init_params_(ReadRObj_R(init_file));
params.init_params = init_params_;
global_rt_setup = std::make_unique<Rcpp::List>();
*global_rt_setup = source_R.value()(runtime_file, Rcpp::Named("local", true));
*global_rt_setup = global_rt_setup->operator[]("value");
global_rt_setup = std::make_unique<Rcpp::List>(
source_R(runtime_file, Rcpp::Named("local", true)));
*global_rt_setup = (*global_rt_setup)["value"];
// MDL add "out_ext" for output format to R setup
(*global_rt_setup)["out_ext"] = params.out_ext;
@ -524,9 +524,8 @@ int main(int argc, char *argv[]) {
// R.parseEvalQ("mysetup <- setup");
// // if (MY_RANK == 0) { // get timestep vector from
// // grid_init function ... //
*global_rt_setup =
master_init_R.value()(*global_rt_setup, run_params.out_dir,
init_list.getInitialGrid().asSEXP());
*global_rt_setup = master_init_R(*global_rt_setup, run_params.out_dir,
init_list.getInitialGrid().asSEXP());
// MDL: store all parameters
// MSG("Calling R Function to store calling parameters");
// R.parseEvalQ("StoreSetup(setup=mysetup)");

View File

@ -89,14 +89,14 @@ TEST_CASE("Field") {
}
SUBCASE("Apply R function (set Na to zero)") {
RHookFunction<Field> to_call(R, "simple_field");
poet::DEFunc to_call("simple_field");
Field field_proc = to_call(dut.asSEXP());
CHECK_EQ(field_proc["Na"], FieldColumn(dut.GetRequestedVecSize(), 0));
}
SUBCASE("Apply R function (add two fields)") {
RHookFunction<Field> to_call(R, "extended_field");
poet::DEFunc to_call("extended_field");
Field field_proc = to_call(dut.asSEXP(), dut.asSEXP());
CHECK_EQ(field_proc["Na"],

View File

@ -9,7 +9,7 @@
#include "testDataStructures.hpp"
TEST_CASE("NamedVector") {
RInsidePOET &R = RInsidePOET::getInstance();
poet::RInsidePOET &R = poet::RInsidePOET::getInstance();
R["sourcefile"] = RInside_source_file;
R.parseEval("source(sourcefile)");
@ -36,14 +36,14 @@ TEST_CASE("NamedVector") {
}
SUBCASE("Apply R function (set to zero)") {
RHookFunction<poet::NamedVector<double>> to_call(R, "simple_named_vec");
poet::DEFunc to_call("simple_named_vec");
nv = to_call(nv);
CHECK_EQ(nv[2], 0);
}
SUBCASE("Apply R function (second NamedVector)") {
RHookFunction<poet::NamedVector<double>> to_call(R, "extended_named_vec");
poet::DEFunc to_call("extended_named_vec");
const std::vector<std::string> names{{"C", "H", "Mg"}};
const std::vector<double> values{{0, 1, 2}};
@ -56,8 +56,8 @@ TEST_CASE("NamedVector") {
}
SUBCASE("Apply R function (check if zero)") {
RHookFunction<bool> to_call(R, "bool_named_vec");
poet::DEFunc to_call("bool_named_vec");
CHECK_FALSE(to_call(nv));
CHECK_FALSE(Rcpp::as<bool>(to_call(nv)));
}
}