feat: use CLI11 as argument parser

feat: improve poet_initializer
This commit is contained in:
Max Luebke 2024-09-16 10:26:13 +02:00
parent 1e14ba6d69
commit 0f6ff06c4a
7 changed files with 194 additions and 647 deletions

View File

@ -15,11 +15,11 @@ pages](https://naaice.git-pages.gfz-potsdam.de/poet).
The following external libraries are shipped with POET:
- **argh** - https://github.com/adishavit/argh (BSD license)
- **IPhreeqc** with patches from GFZ -
https://github.com/usgs-coupled/iphreeqc -
https://git.gfz-potsdam.de/naaice/iphreeqc
- **tug** - https://git.gfz-potsdam.de/naaice/tug
- **CLI11** - <https://github.com/CLIUtils/CLI11>
- **IPhreeqc** with patches from GFZ/UP -
<https://github.com/usgs-coupled/iphreeqc> -
<https://git.gfz-potsdam.de/naaice/iphreeqc>
- **tug** - <https://git.gfz-potsdam.de/naaice/tug>
## Installation

View File

@ -6,18 +6,19 @@ function(ADD_BENCH_TARGET TARGET POET_BENCH_LIST RT_FILES OUT_PATH)
foreach(BENCH_FILE ${${POET_BENCH_LIST}})
get_filename_component(BENCH_NAME ${BENCH_FILE} NAME_WE)
set(OUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/${BENCH_NAME}.qs)
set(OUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/${BENCH_NAME})
set(OUT_FILE_EXT ${OUT_FILE}.qs)
add_custom_command(
OUTPUT ${OUT_FILE}
COMMAND $<TARGET_FILE:poet_init> -o ${OUT_FILE} -s ${CMAKE_CURRENT_SOURCE_DIR}/${BENCH_FILE}
OUTPUT ${OUT_FILE_EXT}
COMMAND $<TARGET_FILE:poet_init> -n ${OUT_FILE} -s ${CMAKE_CURRENT_SOURCE_DIR}/${BENCH_FILE}
COMMENT "Running poet_init on ${BENCH_FILE}"
DEPENDS poet_init ${CMAKE_CURRENT_SOURCE_DIR}/${BENCH_FILE}
VERBATIM
COMMAND_EXPAND_LISTS
)
list(APPEND OUT_FILES_LIST ${OUT_FILE})
list(APPEND OUT_FILES_LIST ${OUT_FILE_EXT})
endforeach(BENCH_FILE ${${POET_BENCH_LIST}})

View File

@ -1,459 +0,0 @@
/*
** Copyright (c) 2016, Adi Shavit All rights reserved.
**
** Redistribution and use in source and binary forms, with or without
** modification, are permitted provided that the following conditions are met:
**
** * Redistributions of source code must retain the above copyright notice, this
** list of conditions and the following disclaimer. * Redistributions in
** binary form must reproduce the above copyright notice, this list of
** conditions and the following disclaimer in the documentation and/or other
** materials provided with the distribution. * Neither the name of nor the
** names of its contributors may be used to endorse or promote products
** derived from this software without specific prior written permission.
**
** THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
** AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
** IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
** ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
** LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
** CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
** SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
** INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
** CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
** ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
** POSSIBILITY OF SUCH DAMAGE.
*/
#pragma once
#include <algorithm>
#include <sstream>
#include <string>
#include <vector>
#include <set>
#include <map>
#include <cassert>
namespace argh
{
    // Terminology:
    // A command line is composed of 2 types of args:
    // 1. Positional args, i.e. free standing values
    // 2. Options: args beginning with '-'. We identify two kinds:
    //    2.1: Flags: boolean options => (exist ? true : false)
    //    2.2: Parameters: a name followed by a non-option value

#if !defined(__GNUC__) || (__GNUC__ >= 5)
    using string_stream = std::istringstream;
#else
    // Until GCC 5, istringstream did not have a move constructor.
    // stringstream_proxy is used instead, as a workaround.
    class stringstream_proxy
    {
    public:
        stringstream_proxy() = default;

        // Construct with a value.
        stringstream_proxy(std::string const& value) :
            stream_(value)
        {}

        // Copy constructor: duplicates the buffered text and the stream state.
        stringstream_proxy(const stringstream_proxy& other) :
            stream_(other.stream_.str())
        {
            stream_.setstate(other.stream_.rdstate());
        }

        void setstate(std::ios_base::iostate state) { stream_.setstate(state); }

        // Stream out the value of the parameter.
        // If the conversion was not possible, the stream will enter the fail state,
        // and operator bool will return false.
        template<typename T>
        stringstream_proxy& operator >> (T& thing)
        {
            stream_ >> thing;
            return *this;
        }

        // Get the string value.
        std::string str() const { return stream_.str(); }

        std::stringbuf* rdbuf() const { return stream_.rdbuf(); }

        // Check the state of the stream.
        // False when the most recent stream operation failed
        operator bool() const { return !!stream_; }

        ~stringstream_proxy() = default;
    private:
        std::istringstream stream_;
    };
    using string_stream = stringstream_proxy;
#endif

    // Command line parser. Splits an argv-style array into positional
    // arguments, flags and "name value" parameters, and exposes them through
    // operator[] (flags / positionals) and operator() (typed streams).
    class parser
    {
    public:
        // Bit flags controlling how parse() classifies options.
        enum Mode {
            // An unregistered option followed by a non-option is stored as a
            // flag; the following argument stays positional.
            PREFER_FLAG_FOR_UNREG_OPTION = 1 << 0,
            // An unregistered option followed by a non-option is stored as a
            // parameter whose value is that following argument.
            PREFER_PARAM_FOR_UNREG_OPTION = 1 << 1,
            // Do not split "name=value" into a parameter.
            NO_SPLIT_ON_EQUALSIGN = 1 << 2,
            // An unregistered single-dash option is expanded into one flag
            // per character.
            SINGLE_DASH_IS_MULTIFLAG = 1 << 3,
        };

        parser() = default;

        // Pre-register parameter names (options that always take a value).
        parser(std::initializer_list<char const* const> pre_reg_names)
        { add_params(pre_reg_names); }

        // Parse a null-terminated argv array.
        parser(const char* const argv[], int mode = PREFER_FLAG_FOR_UNREG_OPTION)
        { parse(argv, mode); }

        // Parse the first argc entries of argv.
        parser(int argc, const char* const argv[], int mode = PREFER_FLAG_FOR_UNREG_OPTION)
        { parse(argc, argv, mode); }

        void add_param(std::string const& name);
        void add_params(std::initializer_list<char const* const> init_list);

        void parse(const char* const argv[], int mode = PREFER_FLAG_FOR_UNREG_OPTION);
        void parse(int argc, const char* const argv[], int mode = PREFER_FLAG_FOR_UNREG_OPTION);

        std::multiset<std::string> const& flags() const { return flags_; }
        std::map<std::string, std::string> const& params() const { return params_; }
        std::vector<std::string> const& pos_args() const { return pos_args_; }

        // begin() and end() for using range-for over positional args.
        std::vector<std::string>::const_iterator begin() const { return pos_args_.cbegin(); }
        std::vector<std::string>::const_iterator end() const { return pos_args_.cend(); }
        size_t size() const { return pos_args_.size(); }

        //////////////////////////////////////////////////////////////////////////
        // Accessors

        // flag (boolean) accessors: return true if the flag appeared, otherwise false.
        bool operator[](std::string const& name) const;

        // multiple flag (boolean) accessors: return true if at least one of the flags appeared, otherwise false.
        bool operator[](std::initializer_list<char const* const> init_list) const;

        // returns positional arg string by order. Like argv[] but without the options.
        // An out-of-range index yields a reference to an empty string.
        std::string const& operator[](size_t ind) const;

        // returns a std::istream that can be used to convert a positional arg to a typed value.
        string_stream operator()(size_t ind) const;

        // same as above, but with a default value in case the arg is missing (index out of range).
        template<typename T>
        string_stream operator()(size_t ind, T&& def_val) const;

        // parameter accessors, give a name get an std::istream that can be used to convert to a typed value.
        // call .str() on result to get as string
        string_stream operator()(std::string const& name) const;

        // accessor for a parameter with multiple names, give a list of names, get an std::istream that can be used to convert to a typed value.
        // call .str() on result to get as string
        // returns the first value in the list to be found.
        string_stream operator()(std::initializer_list<char const* const> init_list) const;

        // same as above, but with a default value in case the param was missing.
        // Non-string def_val types must have an operator<<() (output stream operator)
        // If T only has an input stream operator, pass the string version of the type as in "3" instead of 3.
        template<typename T>
        string_stream operator()(std::string const& name, T&& def_val) const;

        // same as above but for a list of names. returns the first value to be found.
        template<typename T>
        string_stream operator()(std::initializer_list<char const* const> init_list, T&& def_val) const;

    private:
        string_stream bad_stream() const;
        std::string trim_leading_dashes(std::string const& name) const;
        bool is_number(std::string const& arg) const;
        bool is_option(std::string const& arg) const;
        bool got_flag(std::string const& name) const;
        bool is_param(std::string const& name) const;

    private:
        std::vector<std::string> args_;              // raw arguments of the last parse
        std::map<std::string, std::string> params_;  // name -> value (first occurrence wins)
        std::vector<std::string> pos_args_;          // positional arguments, in order
        std::multiset<std::string> flags_;           // flag names, dashes stripped
        std::set<std::string> registeredParams_;     // names registered via add_param(s)
        std::string empty_;                          // returned by operator[](size_t) when out of range
    };

    //////////////////////////////////////////////////////////////////////////
    // Parse a null-terminated argv array: counts entries up to the first null
    // pointer and forwards to the (argc, argv) overload.
    inline void parser::parse(const char * const argv[], int mode)
    {
        int argc = 0;
        for (auto argvp = argv; *argvp; ++argc, ++argvp);
        parse(argc, argv, mode);
    }

    //////////////////////////////////////////////////////////////////////////
    // Parse the first argc entries of argv according to `mode` (a bitwise OR of
    // Mode values). Results of a previous parse on this object are discarded;
    // registered parameter names are kept.
    inline void parser::parse(int argc, const char* const argv[], int mode /*= PREFER_FLAG_FOR_UNREG_OPTION*/)
    {
        // clear out possible previous parsing remnants so that repeated calls
        // do not mix stale and fresh results
        flags_.clear();
        params_.clear();
        pos_args_.clear();

        // convert to strings; a negative argc is treated as an empty command line
        args_.assign(argv, argv + (argc > 0 ? argc : 0));

        // parse line
        for (size_t i = 0; i < args_.size(); ++i)
        {
            std::string const& arg = args_[i];

            if (!is_option(arg))
            {
                pos_args_.emplace_back(arg);
                continue;
            }

            auto name = trim_leading_dashes(arg);

            if (!(mode & NO_SPLIT_ON_EQUALSIGN))
            {
                auto const equalPos = name.find('=');
                if (equalPos != std::string::npos)
                {
                    params_.insert({ name.substr(0, equalPos), name.substr(equalPos + 1) });
                    continue;
                }
            }

            // if the option is unregistered and should be a multi-flag
            if (1 == (arg.size() - name.size()) &&        // single dash
                (mode & SINGLE_DASH_IS_MULTIFLAG) &&      // multi-flag mode
                !is_param(name))                          // unregistered
            {
                std::string keep_param;

                if (!name.empty() && is_param(std::string(1ul, name.back()))) // last char is param
                {
                    keep_param += name.back();
                    name.resize(name.size() - 1);
                }

                for (auto const& c : name)
                {
                    flags_.emplace(std::string{ c });
                }

                if (keep_param.empty())
                {
                    continue; // do not consider other options for this arg
                }
                name = keep_param;
            }

            // any potential option will get as its value the next arg, unless that arg is an option too
            // in that case it will be determined a flag.
            if (i == args_.size() - 1 || is_option(args_[i + 1]))
            {
                flags_.emplace(name);
                continue;
            }

            // if 'name' is a pre-registered option, then the next arg cannot be a free parameter, so it is skipped
            // otherwise we have 2 modes:
            // PREFER_FLAG_FOR_UNREG_OPTION: a non-registered 'name' is determined a flag.
            //                               The following value (the next arg) will be a free parameter.
            //
            // PREFER_PARAM_FOR_UNREG_OPTION: a non-registered 'name' is determined a parameter, the next arg
            //                                will be the value of that option.

            // the two preference modes are mutually exclusive
            assert(!(mode & PREFER_FLAG_FOR_UNREG_OPTION)
                || !(mode & PREFER_PARAM_FOR_UNREG_OPTION));

            bool const preferParam = (mode & PREFER_PARAM_FOR_UNREG_OPTION) != 0;

            if (is_param(name) || preferParam)
            {
                params_.insert({ name, args_[i + 1] });
                ++i; // skip next value, it is not a free parameter
                continue;
            }

            flags_.emplace(name);
        }
    }

    //////////////////////////////////////////////////////////////////////////
    // Returns a stream that is already in the fail state; used to signal a
    // missing parameter or positional argument.
    inline string_stream parser::bad_stream() const
    {
        string_stream bad;
        bad.setstate(std::ios_base::failbit);
        return bad;
    }

    //////////////////////////////////////////////////////////////////////////
    // True when a double can be extracted from the start of `arg`.
    inline bool parser::is_number(std::string const& arg) const
    {
        // inefficient but simple way to determine if a string is a number (which can start with a '-')
        std::istringstream istr(arg);
        double number;
        istr >> number;
        return !(istr.fail() || istr.bad());
    }

    //////////////////////////////////////////////////////////////////////////
    // True when `arg` starts with '-' and is not a number. An empty argument
    // (e.g. `prog ""`) is not an option; it is handled as a positional value
    // rather than tripping an assertion.
    inline bool parser::is_option(std::string const& arg) const
    {
        if (arg.empty() || '-' != arg[0])
            return false;
        return !is_number(arg);
    }

    //////////////////////////////////////////////////////////////////////////
    // Strips all leading '-' characters. A string made only of dashes (or an
    // empty one) is returned unchanged.
    inline std::string parser::trim_leading_dashes(std::string const& name) const
    {
        auto const pos = name.find_first_not_of('-');
        return std::string::npos != pos ? name.substr(pos) : name;
    }

    //////////////////////////////////////////////////////////////////////////
    // True when `name` (leading dashes ignored) was recorded as a flag.
    inline bool parser::got_flag(std::string const& name) const
    {
        return flags_.end() != flags_.find(trim_leading_dashes(name));
    }

    //////////////////////////////////////////////////////////////////////////
    // True when `name` (already dash-stripped) is a registered parameter.
    inline bool parser::is_param(std::string const& name) const
    {
        return registeredParams_.count(name) != 0;
    }

    //////////////////////////////////////////////////////////////////////////
    inline bool parser::operator[](std::string const& name) const
    {
        return got_flag(name);
    }

    //////////////////////////////////////////////////////////////////////////
    inline bool parser::operator[](std::initializer_list<char const* const> init_list) const
    {
        return std::any_of(init_list.begin(), init_list.end(), [&](char const* const name) { return got_flag(name); });
    }

    //////////////////////////////////////////////////////////////////////////
    inline std::string const& parser::operator[](size_t ind) const
    {
        if (ind < pos_args_.size())
            return pos_args_[ind];
        return empty_;
    }

    //////////////////////////////////////////////////////////////////////////
    inline string_stream parser::operator()(std::string const& name) const
    {
        auto const optIt = params_.find(trim_leading_dashes(name));
        if (params_.end() != optIt)
            return string_stream(optIt->second);
        return bad_stream();
    }

    //////////////////////////////////////////////////////////////////////////
    inline string_stream parser::operator()(std::initializer_list<char const* const> init_list) const
    {
        for (auto& name : init_list)
        {
            auto const optIt = params_.find(trim_leading_dashes(name));
            if (params_.end() != optIt)
                return string_stream(optIt->second);
        }
        return bad_stream();
    }

    //////////////////////////////////////////////////////////////////////////
    template<typename T>
    string_stream parser::operator()(std::string const& name, T&& def_val) const
    {
        auto const optIt = params_.find(trim_leading_dashes(name));
        if (params_.end() != optIt)
            return string_stream(optIt->second);

        std::ostringstream ostr;
        ostr << def_val;
        return string_stream(ostr.str()); // use default
    }

    //////////////////////////////////////////////////////////////////////////
    // same as above but for a list of names. returns the first value to be found.
    template<typename T>
    string_stream parser::operator()(std::initializer_list<char const* const> init_list, T&& def_val) const
    {
        for (auto& name : init_list)
        {
            auto const optIt = params_.find(trim_leading_dashes(name));
            if (params_.end() != optIt)
                return string_stream(optIt->second);
        }
        std::ostringstream ostr;
        ostr << def_val;
        return string_stream(ostr.str()); // use default
    }

    //////////////////////////////////////////////////////////////////////////
    inline string_stream parser::operator()(size_t ind) const
    {
        if (pos_args_.size() <= ind)
            return bad_stream();

        return string_stream(pos_args_[ind]);
    }

    //////////////////////////////////////////////////////////////////////////
    template<typename T>
    string_stream parser::operator()(size_t ind, T&& def_val) const
    {
        if (pos_args_.size() <= ind)
        {
            std::ostringstream ostr;
            ostr << def_val;
            return string_stream(ostr.str());
        }

        return string_stream(pos_args_[ind]);
    }

    //////////////////////////////////////////////////////////////////////////
    // Register a single parameter name; leading dashes are stripped.
    inline void parser::add_param(std::string const& name)
    {
        registeredParams_.insert(trim_leading_dashes(name));
    }

    //////////////////////////////////////////////////////////////////////////
    // Register several parameter names; leading dashes are stripped.
    inline void parser::add_params(std::initializer_list<char const* const> init_list)
    {
        for (auto& name : init_list)
            registeredParams_.insert(trim_leading_dashes(name));
    }
}

View File

@ -27,6 +27,16 @@ target_link_libraries(
PUBLIC MPI::MPI_C
)
include(FetchContent)
FetchContent_Declare(
cli11
QUIET
GIT_REPOSITORY https://github.com/CLIUtils/CLI11.git
GIT_TAG v2.4.2
)
FetchContent_MakeAvailable(cli11)
# add_library(poetlib
# Base/Grid.cpp
# Base/SimParams.cpp
@ -75,11 +85,11 @@ file(READ "${PROJECT_SOURCE_DIR}/R_lib/ai_surrogate_model.R" R_AI_SURROGATE_LIB)
configure_file(poet.hpp.in poet.hpp @ONLY)
add_executable(poet poet.cpp)
target_link_libraries(poet PRIVATE POETLib MPI::MPI_C RRuntime)
target_link_libraries(poet PRIVATE POETLib MPI::MPI_C RRuntime CLI11::CLI11)
target_include_directories(poet PRIVATE "${CMAKE_CURRENT_BINARY_DIR}")
add_executable(poet_init initializer.cpp)
target_link_libraries(poet_init PRIVATE POETLib RRuntime)
target_link_libraries(poet_init PRIVATE POETLib RRuntime CLI11::CLI11)
target_include_directories(poet_init PRIVATE "${CMAKE_CURRENT_BINARY_DIR}")
install(TARGETS poet poet_init DESTINATION bin)

View File

@ -1,7 +1,7 @@
#include "Init/InitialList.hpp"
#include "poet.hpp"
#include "Base/argh.hpp"
#include <CLI/CLI.hpp>
#include <cstdlib>
@ -11,32 +11,39 @@
#include <string>
int main(int argc, char **argv) {
// pre-register expected parameters before calling `parse`
argh::parser cmdl({"-o", "--output"});
cmdl.parse(argc, argv);
if (cmdl[{"-h", "--help"}] || cmdl.pos_args().size() != 2) {
std::cout << "Usage: " << argv[0]
<< " [-o, --output output_file]"
<< " [-s, --setwd] "
<< " <PoetScript.R>"
<< std::endl;
return EXIT_SUCCESS;
}
// initialize RInside
RInside R(argc, argv);
R.parseEvalQ(init_r_library);
R.parseEvalQ(kin_r_library);
std::string input_script = cmdl.pos_args()[1];
// parse command line arguments
CLI::App app{"POET Initializer - Poet R script to POET qs/rds converter"};
std::string input_script;
app.add_option("PoetScript.R", input_script, "POET R script to convert")
->required();
std::string output_file;
app.add_option("-n, --name", output_file,
"Name of the output file without file extension");
bool setwd;
app.add_flag("-s, --setwd", setwd,
"Set working directory to the directory of the directory of the "
"input script")
->default_val(false);
bool asRDS;
app.add_flag("-r, --rds", asRDS, "Save output as .rds file instead of .qs")
->default_val(false);
CLI11_PARSE(app, argc, argv);
// source the input script
std::string normalized_path_script;
std::string in_file_name;
std::string curr_path =
Rcpp::as<std::string>(Rcpp::Function("normalizePath")(Rcpp::wrap(".")));
try {
normalized_path_script =
Rcpp::as<std::string>(Rcpp::Function("normalizePath")(input_script));
@ -52,22 +59,20 @@ int main(int argc, char **argv) {
return EXIT_FAILURE;
}
std::string output_file;
// if output file is not specified, use the input file name
if (output_file.empty()) {
std::string curr_path =
Rcpp::as<std::string>(Rcpp::Function("normalizePath")(Rcpp::wrap(".")));
// MDL: some test to understand
// std::string output_ext = ".rds";
// if (cmdl["q"]) output_ext = ".qs";
// std::cout << "Ouptut ext: " << output_ext << " ; infile substr: "
// << in_file_name.substr(0, in_file_name.find_last_of('.')) << std::endl;
output_file = curr_path + "/" +
in_file_name.substr(0, in_file_name.find_last_of('.'));
}
// cmdl({"-o", "--output"},
// curr_path + "/" +
// in_file_name.substr(0, in_file_name.find_last_of('.')) + ".qs") >> output_file;
// append the correct file extension
output_file += asRDS ? ".rds" : ".qs";
cmdl({"-o", "--output"}) >> output_file;
if (cmdl[{"-s", "--setwd"}]) {
// set working directory to the directory of the input script
if (setwd) {
const std::string dir_path = Rcpp::as<std::string>(
Rcpp::Function("dirname")(normalized_path_script));

View File

@ -23,6 +23,7 @@
#include "Base/Macros.hpp"
#include "Base/RInsidePOET.hpp"
#include "CLI/CLI.hpp"
#include "Chemistry/ChemistryModule.hpp"
#include "DataStructures/Field.hpp"
#include "Init/InitialList.hpp"
@ -39,7 +40,7 @@
#include <mpi.h>
#include <string>
#include "Base/argh.hpp"
#include <CLI/CLI.hpp>
#include <poet.hpp>
#include <vector>
@ -86,82 +87,95 @@ static void init_global_functions(RInside &R) {
enum ParseRet { PARSER_OK, PARSER_ERROR, PARSER_HELP };
ParseRet parseInitValues(char **argv, RuntimeParameters &params) {
// initialize argh object
argh::parser cmdl(argv);
int parseInitValues(int argc, char **argv, RuntimeParameters &params) {
// if user asked for help
if (cmdl[{"help", "h"}]) {
if (MY_RANK == 0) {
MSG("Todo");
MSG("See README.md for further information.");
}
CLI::App app{"POET - Potsdam rEactive Transport simulator"};
return ParseRet::PARSER_HELP;
}
// if positional arguments are missing
if (!cmdl(3)) {
if (MY_RANK == 0) {
ERRMSG("POET needs 3 positional arguments: ");
ERRMSG("1) the R script defining your simulation runtime.");
ERRMSG("2) the initial .rds file generated by poet_init.");
ERRMSG("3) the directory prefix where to save results and profiling");
}
return ParseRet::PARSER_ERROR;
}
// parse flags and check for unknown
for (const auto &option : cmdl.flags()) {
if (!(flaglist.find(option) != flaglist.end())) {
if (MY_RANK == 0) {
ERRMSG("Unrecognized option: " + option);
ERRMSG("Make sure to use available options. Exiting!");
}
return ParseRet::PARSER_ERROR;
}
}
// parse parameters and check for unknown
for (const auto &option : cmdl.params()) {
if (!(paramlist.find(option.first) != paramlist.end())) {
if (MY_RANK == 0) {
ERRMSG("Unrecognized option: " + option.first);
ERRMSG("Make sure to use available options. Exiting!");
}
return ParseRet::PARSER_ERROR;
}
}
params.print_progressbar = cmdl[{"P", "progress"}];
app.add_flag("-P,--progress", params.print_progress,
"Print progress bar during chemical simulation");
/*Parse work package size*/
cmdl("work-package-size", CHEM_DEFAULT_WORK_PACKAGE_SIZE) >>
params.work_package_size;
app.add_option(
"-w,--work-package-size", params.work_package_size,
"Work package size to distribute to each worker for chemistry module")
->check(CLI::PositiveNumber)
->default_val(RuntimeParameters::WORK_PACKAGE_SIZE_DEFAULT);
/* Parse DHT arguments */
params.use_dht = cmdl["dht"];
params.use_interp = cmdl["interp"];
auto *dht_group = app.add_option_group("DHT", "DHT related options");
dht_group->add_flag("--dht", params.use_dht, "Enable DHT");
// cout << "CPP: DHT is " << ( dht_enabled ? "ON" : "OFF" ) << '\n';
cmdl("dht-size", CHEM_DHT_SIZE_PER_PROCESS_MB) >> params.dht_size;
dht_group
->add_option("--dht-size", params.dht_size,
"DHT size per process in Megabyte")
->check(CLI::PositiveNumber)
->default_val(RuntimeParameters::DHT_SIZE_DEFAULT);
// cout << "CPP: DHT size per process (Byte) = " << dht_size_per_process <<
// endl;
cmdl("dht-snaps", 0) >> params.dht_snaps;
dht_group->add_option(
"--dht-snaps", params.dht_snaps,
"Save snapshots of DHT to disk: \n0 = disabled (default)\n1 = After "
"simulation\n2 = After each iteration");
params.use_interp = cmdl["interp"];
cmdl("interp-size", 100) >> params.interp_size;
cmdl("interp-min", 5) >> params.interp_min_entries;
cmdl("interp-bucket-entries", 20) >> params.interp_bucket_entries;
auto *interp_group =
app.add_option_group("Interpolation", "Interpolation related options");
params.use_ai_surrogate = cmdl["ai-surrogate"];
interp_group->add_flag("--interp", params.use_interp, "Enable interpolation");
interp_group
->add_option("--interp-size", params.interp_size,
"Size of the interpolation table in Megabyte")
->check(CLI::PositiveNumber)
->default_val(RuntimeParameters::INTERP_SIZE_DEFAULT);
interp_group
->add_option("--interp-min", params.interp_min_entries,
"Minimum number of entries in the interpolation table")
->check(CLI::PositiveNumber)
->default_val(RuntimeParameters::INTERP_MIN_ENTRIES_DEFAULT);
interp_group
->add_option(
"--interp-bucket-entries", params.interp_bucket_entries,
"Maximum number of entries in each bucket of the interpolation table")
->check(CLI::PositiveNumber)
->default_val(RuntimeParameters::INTERP_BUCKET_ENTRIES_DEFAULT);
// MDL: optional flag "--qs" to switch to qsave()
params.out_ext = "rds";
if (cmdl["qs"]) {
params.out_ext = "qs";
app.add_flag("--ai-surrogate", params.use_ai_surrogate,
"Enable AI surrogate for chemistry module");
app.add_flag("--rds", params.as_rds,
"Save output as .rds file instead of .qs");
std::string init_file;
std::string runtime_file;
app.add_option("runtime_file", runtime_file,
"Runtime R script defining the simulation")
->required()
->check(CLI::ExistingFile);
app.add_option(
"init_file", init_file,
"Initial R script defining the simulation, produced by poet_init")
->required()
->check(CLI::ExistingFile);
app.add_option("out_dir", params.out_dir,
"Output directory of the simulation")
->required();
try {
app.parse(argc, argv);
} catch (const CLI::ParseError &e) {
app.exit(e);
return -1;
}
// set the output extension
params.out_ext = params.as_rds ? "rds" : "qs";
if (MY_RANK == 0) {
// MSG("Complete results storage is " + BOOL_PRINT(simparams.store_result));
MSG("Output format/extension is " + params.out_ext);
@ -191,14 +205,6 @@ ParseRet parseInitValues(char **argv, RuntimeParameters &params) {
std::to_string(params.interp_bucket_entries));
}
}
std::string init_file;
std::string runtime_file;
cmdl(1) >> runtime_file;
cmdl(2) >> init_file;
cmdl(3) >> params.out_dir;
// chem_params.dht_outdir = out_dir;
/* distribute information to R runtime */
@ -232,7 +238,7 @@ ParseRet parseInitValues(char **argv, RuntimeParameters &params) {
(*global_rt_setup)["out_ext"] = params.out_ext;
params.timesteps =
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("timesteps"));
Rcpp::as<std::vector<double>>(global_rt_setup->operator[]("timesteps"));
} catch (const std::exception &e) {
ERRMSG("Error while parsing R scripts: " + std::string(e.what()));
@ -269,7 +275,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
* master for the following loop) */
uint32_t maxiter = params.timesteps.size();
if (params.print_progressbar) {
if (params.print_progress) {
chem.setProgressBarPrintout(true);
}
R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps());
@ -297,9 +303,10 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
double ai_start_t = MPI_Wtime();
// Save current values from the tug field as predictor for the ai step
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
R.parseEval(std::string(
"predictors <- setNames(data.frame(matrix(TMP, nrow=" +
std::to_string(chem.getField().GetRequestedVecSize()) + ")), TMP_PROPS)"));
R.parseEval(
std::string("predictors <- setNames(data.frame(matrix(TMP, nrow=" +
std::to_string(chem.getField().GetRequestedVecSize()) +
")), TMP_PROPS)"));
R.parseEval("predictors <- predictors[ai_surrogate_species]");
// Apply preprocessing
@ -308,7 +315,8 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
// Predict
MSG("AI Predict");
R.parseEval("aipreds_scaled <- prediction_step(model, predictors_scaled)");
R.parseEval(
"aipreds_scaled <- prediction_step(model, predictors_scaled)");
// Apply postprocessing
MSG("AI Postprocesing");
@ -316,20 +324,22 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
// Validate prediction and write valid predictions to chem field
MSG("AI Validate");
R.parseEval("validity_vector <- validate_predictions(predictors, aipreds)");
R.parseEval(
"validity_vector <- validate_predictions(predictors, aipreds)");
MSG("AI Marking accepted");
chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector"));
MSG("AI TempField");
std::vector<std::vector<double>> RTempField = R.parseEval("set_valid_predictions(predictors,\
std::vector<std::vector<double>> RTempField =
R.parseEval("set_valid_predictions(predictors,\
aipreds,\
validity_vector)");
MSG("AI Set Field");
Field predictions_field = Field(R.parseEval("nrow(predictors)"),
RTempField,
R.parseEval("colnames(predictors)"));
Field predictions_field =
Field(R.parseEval("nrow(predictors)"), RTempField,
R.parseEval("colnames(predictors)"));
MSG("AI Update");
chem.getField().update(predictions_field);
@ -344,16 +354,18 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters &params,
double ai_start_t = MPI_Wtime();
R["TMP"] = Rcpp::wrap(chem.getField().AsVector());
R.parseEval(std::string(
"targets <- setNames(data.frame(matrix(TMP, nrow=" +
std::to_string(chem.getField().GetRequestedVecSize()) + ")), TMP_PROPS)"));
R.parseEval(
std::string("targets <- setNames(data.frame(matrix(TMP, nrow=" +
std::to_string(chem.getField().GetRequestedVecSize()) +
")), TMP_PROPS)"));
R.parseEval("targets <- targets[ai_surrogate_species]");
// TODO: Check how to get the correct columns
R.parseEval("target_scaled <- preprocess(targets)");
MSG("AI: incremental training");
R.parseEval("model <- training_step(model, predictors_scaled, target_scaled, validity_vector)");
R.parseEval("model <- training_step(model, predictors_scaled, "
"target_scaled, validity_vector)");
double ai_end_t = MPI_Wtime();
R["ai_training_time"] = ai_end_t - ai_start_t;
}
@ -462,7 +474,6 @@ std::vector<std::string> getSpeciesNames(const Field &&field, int root,
return species_names_out;
}
int main(int argc, char *argv[]) {
int world_size;
@ -478,20 +489,24 @@ int main(int argc, char *argv[]) {
MSG("Running POET version " + std::string(poet_version));
}
init_global_functions(R);
RuntimeParameters run_params;
switch (parseInitValues(argv, run_params)) {
case ParseRet::PARSER_ERROR:
case ParseRet::PARSER_HELP:
if (parseInitValues(argc, argv, run_params) != 0) {
MPI_Finalize();
return 0;
case ParseRet::PARSER_OK:
break;
}
// switch (parseInitValues(argc, argv, run_params)) {
// case ParseRet::PARSER_ERROR:
// case ParseRet::PARSER_HELP:
// MPI_Finalize();
// return 0;
// case ParseRet::PARSER_OK:
// break;
// }
InitialList init_list(R);
init_list.importList(run_params.init_params, MY_RANK != 0);
@ -513,8 +528,7 @@ int main(int argc, char *argv[]) {
run_params.interp_bucket_entries,
run_params.interp_size,
run_params.interp_min_entries,
run_params.use_ai_surrogate
};
run_params.use_ai_surrogate};
chemistry.masterEnableSurrogates(surr_setup);
@ -538,15 +552,17 @@ int main(int argc, char *argv[]) {
/* Incorporate ai surrogate from R */
R.parseEvalQ(ai_surrogate_r_library);
/* Use dht species for model input and output */
R["ai_surrogate_species"] = init_list.getChemistryInit().dht_species.getNames();
R["ai_surrogate_species"] =
init_list.getChemistryInit().dht_species.getNames();
const std::string ai_surrogate_input_script = init_list.getChemistryInit().ai_surrogate_input_script;
const std::string ai_surrogate_input_script =
init_list.getChemistryInit().ai_surrogate_input_script;
MSG("AI: sourcing user-provided script");
R.parseEvalQ(ai_surrogate_input_script);
MSG("AI: sourcing user-provided script");
R.parseEvalQ(ai_surrogate_input_script);
MSG("AI: initialize AI model");
R.parseEval("model <- initiate_model()");
R.parseEval("model <- initiate_model()");
R.parseEval("gpu_info()");
}
@ -568,8 +584,8 @@ int main(int argc, char *argv[]) {
R["setup$out_ext"] = run_params.out_ext;
string r_vis_code;
r_vis_code =
"SaveRObj(x = profiling, path = paste0(out_dir, '/timings.', setup$out_ext));";
r_vis_code = "SaveRObj(x = profiling, path = paste0(out_dir, "
"'/timings.', setup$out_ext));";
R.parseEval(r_vis_code);
MSG("Done! Results are stored as R objects into <" + run_params.out_dir +

View File

@ -35,63 +35,37 @@ static const char *poet_version = "@POET_VERSION@";
static const inline std::string kin_r_library = R"(@R_KIN_LIB@)";
static const inline std::string init_r_library = R"(@R_INIT_LIB@)";
static const inline std::string ai_surrogate_r_library = R"(@R_AI_SURROGATE_LIB@)";
static const inline std::string ai_surrogate_r_library =
R"(@R_AI_SURROGATE_LIB@)";
static const inline std::string r_runtime_parameters = "mysetup";
const std::set<std::string> flaglist{"ignore-result", "dht", "P", "progress",
"interp", "ai-surrogate", "qs"};
const std::set<std::string> paramlist{
"work-package-size", "dht-strategy", "dht-size", "dht-snaps",
"dht-file", "interp-size", "interp-min", "interp-bucket-entries"};
constexpr uint32_t CHEM_DEFAULT_WORK_PACKAGE_SIZE = 32;
constexpr uint32_t CHEM_DHT_SIZE_PER_PROCESS_MB = 1.5E3;
struct RuntimeParameters {
std::string out_dir;
std::vector<double> timesteps;
std::string out_ext; // MDL added to accomodate for qs::qsave/qread
bool print_progressbar;
uint32_t work_package_size;
Rcpp::List init_params;
bool use_dht;
bool as_rds = false;
std::string out_ext; // MDL added to accomodate for qs::qsave/qread
bool print_progress = false;
static constexpr std::uint32_t WORK_PACKAGE_SIZE_DEFAULT = 32;
std::uint32_t work_package_size;
bool use_dht = false;
static constexpr std::uint32_t DHT_SIZE_DEFAULT = 1.5E3;
std::uint32_t dht_size;
static constexpr std::uint8_t DHT_SNAPS_DEFAULT = 0;
std::uint8_t dht_snaps;
bool use_interp;
bool use_interp = false;
static constexpr std::uint32_t INTERP_SIZE_DEFAULT = 100;
std::uint32_t interp_size;
static constexpr std::uint32_t INTERP_MIN_ENTRIES_DEFAULT = 5;
std::uint32_t interp_min_entries;
static constexpr std::uint32_t INTERP_BUCKET_ENTRIES_DEFAULT = 20;
std::uint32_t interp_bucket_entries;
bool use_ai_surrogate;
struct ChemistryParams {
// std::string database_path;
// std::string input_script;
// bool use_dht;
// std::uint64_t dht_size;
// int dht_snaps;
// std::string dht_file;
// std::string dht_outdir;
// NamedVector<std::uint32_t> dht_signifs;
// bool use_interp;
// std::uint64_t pht_size;
// std::uint32_t pht_max_entries;
// std::uint32_t interp_min_entries;
// NamedVector<std::uint32_t> pht_signifs;
// struct Chem_Hook_Functions {
// RHookFunction<bool> dht_fill;
// RHookFunction<std::vector<double>> dht_fuzz;
// RHookFunction<std::vector<std::size_t>> interp_pre;
// RHookFunction<bool> interp_post;
// } hooks;
// void initFromR(RInsidePOET &R);
};
bool use_ai_surrogate = false;
};