From 0f6ff06c4a73e05d4131a012866b3a32c46a9199 Mon Sep 17 00:00:00 2001 From: Max Luebke Date: Mon, 16 Sep 2024 10:26:13 +0200 Subject: [PATCH] feat: use CLI11 as argument parser feat: improve poet_initializer --- README.md | 10 +- bench/CMakeLists.txt | 9 +- src/Base/argh.hpp | 459 ------------------------------------------- src/CMakeLists.txt | 14 +- src/initializer.cpp | 69 ++++--- src/poet.cpp | 216 ++++++++++---------- src/poet.hpp.in | 64 ++---- 7 files changed, 194 insertions(+), 647 deletions(-) delete mode 100644 src/Base/argh.hpp diff --git a/README.md b/README.md index 4078193fb..598117675 100644 --- a/README.md +++ b/README.md @@ -15,11 +15,11 @@ pages](https://naaice.git-pages.gfz-potsdam.de/poet). The following external libraries are shipped with POET: -- **argh** - https://github.com/adishavit/argh (BSD license) -- **IPhreeqc** with patches from GFZ - - https://github.com/usgs-coupled/iphreeqc - - https://git.gfz-potsdam.de/naaice/iphreeqc -- **tug** - https://git.gfz-potsdam.de/naaice/tug +- **CLI11** - +- **IPhreeqc** with patches from GFZ/UP - + - + +- **tug** - ## Installation diff --git a/bench/CMakeLists.txt b/bench/CMakeLists.txt index 20581b512..01dc43caf 100644 --- a/bench/CMakeLists.txt +++ b/bench/CMakeLists.txt @@ -6,18 +6,19 @@ function(ADD_BENCH_TARGET TARGET POET_BENCH_LIST RT_FILES OUT_PATH) foreach(BENCH_FILE ${${POET_BENCH_LIST}}) get_filename_component(BENCH_NAME ${BENCH_FILE} NAME_WE) - set(OUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/${BENCH_NAME}.qs) + set(OUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/${BENCH_NAME}) + set(OUT_FILE_EXT ${OUT_FILE}.qs) add_custom_command( - OUTPUT ${OUT_FILE} - COMMAND $ -o ${OUT_FILE} -s ${CMAKE_CURRENT_SOURCE_DIR}/${BENCH_FILE} + OUTPUT ${OUT_FILE_EXT} + COMMAND $ -n ${OUT_FILE} -s ${CMAKE_CURRENT_SOURCE_DIR}/${BENCH_FILE} COMMENT "Running poet_init on ${BENCH_FILE}" DEPENDS poet_init ${CMAKE_CURRENT_SOURCE_DIR}/${BENCH_FILE} VERBATIM COMMAND_EXPAND_LISTS ) - list(APPEND OUT_FILES_LIST ${OUT_FILE}) + list(APPEND 
OUT_FILES_LIST ${OUT_FILE_EXT}) endforeach(BENCH_FILE ${${POET_BENCH_LIST}}) diff --git a/src/Base/argh.hpp b/src/Base/argh.hpp deleted file mode 100644 index 6b5f1f9c0..000000000 --- a/src/Base/argh.hpp +++ /dev/null @@ -1,459 +0,0 @@ -/* -** Copyright (c) 2016, Adi Shavit All rights reserved. -** -** Redistribution and use in source and binary forms, with or without -** modification, are permitted provided that the following conditions are met: -** -** * Redistributions of source code must retain the above copyright notice, this -** list of conditions and the following disclaimer. * Redistributions in -** binary form must reproduce the above copyright notice, this list of -** conditions and the following disclaimer in the documentation and/or other -** materials provided with the distribution. * Neither the name of nor the -** names of its contributors may be used to endorse or promote products -** derived from this software without specific prior written permission. -** -** THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -** AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -** IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -** ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -** LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -** CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -** SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -** INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -** CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -** ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -** POSSIBILITY OF SUCH DAMAGE. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -namespace argh -{ - // Terminology: - // A command line is composed of 2 types of args: - // 1. 
Positional args, i.e. free standing values - // 2. Options: args beginning with '-'. We identify two kinds: - // 2.1: Flags: boolean options => (exist ? true : false) - // 2.2: Parameters: a name followed by a non-option value - -#if !defined(__GNUC__) || (__GNUC__ >= 5) - using string_stream = std::istringstream; -#else - // Until GCC 5, istringstream did not have a move constructor. - // stringstream_proxy is used instead, as a workaround. - class stringstream_proxy - { - public: - stringstream_proxy() = default; - - // Construct with a value. - stringstream_proxy(std::string const& value) : - stream_(value) - {} - - // Copy constructor. - stringstream_proxy(const stringstream_proxy& other) : - stream_(other.stream_.str()) - { - stream_.setstate(other.stream_.rdstate()); - } - - void setstate(std::ios_base::iostate state) { stream_.setstate(state); } - - // Stream out the value of the parameter. - // If the conversion was not possible, the stream will enter the fail state, - // and operator bool will return false. - template - stringstream_proxy& operator >> (T& thing) - { - stream_ >> thing; - return *this; - } - - - // Get the string value. - std::string str() const { return stream_.str(); } - - std::stringbuf* rdbuf() const { return stream_.rdbuf(); } - - // Check the state of the stream. 
- // False when the most recent stream operation failed - operator bool() const { return !!stream_; } - - ~stringstream_proxy() = default; - private: - std::istringstream stream_; - }; - using string_stream = stringstream_proxy; -#endif - - class parser - { - public: - enum Mode { PREFER_FLAG_FOR_UNREG_OPTION = 1 << 0, - PREFER_PARAM_FOR_UNREG_OPTION = 1 << 1, - NO_SPLIT_ON_EQUALSIGN = 1 << 2, - SINGLE_DASH_IS_MULTIFLAG = 1 << 3, - }; - - parser() = default; - - parser(std::initializer_list pre_reg_names) - { add_params(pre_reg_names); } - - parser(const char* const argv[], int mode = PREFER_FLAG_FOR_UNREG_OPTION) - { parse(argv, mode); } - - parser(int argc, const char* const argv[], int mode = PREFER_FLAG_FOR_UNREG_OPTION) - { parse(argc, argv, mode); } - - void add_param(std::string const& name); - void add_params(std::initializer_list init_list); - - void parse(const char* const argv[], int mode = PREFER_FLAG_FOR_UNREG_OPTION); - void parse(int argc, const char* const argv[], int mode = PREFER_FLAG_FOR_UNREG_OPTION); - - std::multiset const& flags() const { return flags_; } - std::map const& params() const { return params_; } - std::vector const& pos_args() const { return pos_args_; } - - // begin() and end() for using range-for over positional args. - std::vector::const_iterator begin() const { return pos_args_.cbegin(); } - std::vector::const_iterator end() const { return pos_args_.cend(); } - size_t size() const { return pos_args_.size(); } - - ////////////////////////////////////////////////////////////////////////// - // Accessors - - // flag (boolean) accessors: return true if the flag appeared, otherwise false. - bool operator[](std::string const& name) const; - - // multiple flag (boolean) accessors: return true if at least one of the flag appeared, otherwise false. - bool operator[](std::initializer_list init_list) const; - - // returns positional arg string by order. 
Like argv[] but without the options - std::string const& operator[](size_t ind) const; - - // returns a std::istream that can be used to convert a positional arg to a typed value. - string_stream operator()(size_t ind) const; - - // same as above, but with a default value in case the arg is missing (index out of range). - template - string_stream operator()(size_t ind, T&& def_val) const; - - // parameter accessors, give a name get an std::istream that can be used to convert to a typed value. - // call .str() on result to get as string - string_stream operator()(std::string const& name) const; - - // accessor for a parameter with multiple names, give a list of names, get an std::istream that can be used to convert to a typed value. - // call .str() on result to get as string - // returns the first value in the list to be found. - string_stream operator()(std::initializer_list init_list) const; - - // same as above, but with a default value in case the param was missing. - // Non-string def_val types must have an operator<<() (output stream operator) - // If T only has an input stream operator, pass the string version of the type as in "3" instead of 3. - template - string_stream operator()(std::string const& name, T&& def_val) const; - - // same as above but for a list of names. returns the first value to be found. 
- template - string_stream operator()(std::initializer_list init_list, T&& def_val) const; - - private: - string_stream bad_stream() const; - std::string trim_leading_dashes(std::string const& name) const; - bool is_number(std::string const& arg) const; - bool is_option(std::string const& arg) const; - bool got_flag(std::string const& name) const; - bool is_param(std::string const& name) const; - - private: - std::vector args_; - std::map params_; - std::vector pos_args_; - std::multiset flags_; - std::set registeredParams_; - std::string empty_; - }; - - - ////////////////////////////////////////////////////////////////////////// - - inline void parser::parse(const char * const argv[], int mode) - { - int argc = 0; - for (auto argvp = argv; *argvp; ++argc, ++argvp); - parse(argc, argv, mode); - } - - ////////////////////////////////////////////////////////////////////////// - - inline void parser::parse(int argc, const char* const argv[], int mode /*= PREFER_FLAG_FOR_UNREG_OPTION*/) - { - // convert to strings - args_.resize(argc); - std::transform(argv, argv + argc, args_.begin(), [](const char* const arg) { return arg; }); - - // parse line - for (auto i = 0u; i < args_.size(); ++i) - { - if (!is_option(args_[i])) - { - pos_args_.emplace_back(args_[i]); - continue; - } - - auto name = trim_leading_dashes(args_[i]); - - if (!(mode & NO_SPLIT_ON_EQUALSIGN)) - { - auto equalPos = name.find('='); - if (equalPos != std::string::npos) - { - params_.insert({ name.substr(0, equalPos), name.substr(equalPos + 1) }); - continue; - } - } - - // if the option is unregistered and should be a multi-flag - if (1 == (args_[i].size() - name.size()) && // single dash - argh::parser::SINGLE_DASH_IS_MULTIFLAG & mode && // multi-flag mode - !is_param(name)) // unregistered - { - std::string keep_param; - - if (!name.empty() && is_param(std::string(1ul, name.back()))) // last char is param - { - keep_param += name.back(); - name.resize(name.size() - 1); - } - - for (auto const& c : 
name) - { - flags_.emplace(std::string{ c }); - } - - if (!keep_param.empty()) - { - name = keep_param; - } - else - { - continue; // do not consider other options for this arg - } - } - - // any potential option will get as its value the next arg, unless that arg is an option too - // in that case it will be determined a flag. - if (i == args_.size() - 1 || is_option(args_[i + 1])) - { - flags_.emplace(name); - continue; - } - - // if 'name' is a pre-registered option, then the next arg cannot be a free parameter to it is skipped - // otherwise we have 2 modes: - // PREFER_FLAG_FOR_UNREG_OPTION: a non-registered 'name' is determined a flag. - // The following value (the next arg) will be a free parameter. - // - // PREFER_PARAM_FOR_UNREG_OPTION: a non-registered 'name' is determined a parameter, the next arg - // will be the value of that option. - - assert(!(mode & argh::parser::PREFER_FLAG_FOR_UNREG_OPTION) - || !(mode & argh::parser::PREFER_PARAM_FOR_UNREG_OPTION)); - - bool preferParam = mode & argh::parser::PREFER_PARAM_FOR_UNREG_OPTION; - - if (is_param(name) || preferParam) - { - params_.insert({ name, args_[i + 1] }); - ++i; // skip next value, it is not a free parameter - continue; - } - else - { - flags_.emplace(name); - } - }; - } - - ////////////////////////////////////////////////////////////////////////// - - inline string_stream parser::bad_stream() const - { - string_stream bad; - bad.setstate(std::ios_base::failbit); - return bad; - } - - ////////////////////////////////////////////////////////////////////////// - - inline bool parser::is_number(std::string const& arg) const - { - // inefficient but simple way to determine if a string is a number (which can start with a '-') - std::istringstream istr(arg); - double number; - istr >> number; - return !(istr.fail() || istr.bad()); - } - - ////////////////////////////////////////////////////////////////////////// - - inline bool parser::is_option(std::string const& arg) const - { - assert(0 != 
arg.size()); - if (is_number(arg)) - return false; - return '-' == arg[0]; - } - - ////////////////////////////////////////////////////////////////////////// - - inline std::string parser::trim_leading_dashes(std::string const& name) const - { - auto pos = name.find_first_not_of('-'); - return std::string::npos != pos ? name.substr(pos) : name; - } - - ////////////////////////////////////////////////////////////////////////// - - inline bool argh::parser::got_flag(std::string const& name) const - { - return flags_.end() != flags_.find(trim_leading_dashes(name)); - } - - ////////////////////////////////////////////////////////////////////////// - - inline bool argh::parser::is_param(std::string const& name) const - { - return registeredParams_.count(name); - } - - ////////////////////////////////////////////////////////////////////////// - - inline bool parser::operator[](std::string const& name) const - { - return got_flag(name); - } - - ////////////////////////////////////////////////////////////////////////// - - inline bool parser::operator[](std::initializer_list init_list) const - { - return std::any_of(init_list.begin(), init_list.end(), [&](char const* const name) { return got_flag(name); }); - } - - ////////////////////////////////////////////////////////////////////////// - - inline std::string const& parser::operator[](size_t ind) const - { - if (ind < pos_args_.size()) - return pos_args_[ind]; - return empty_; - } - - ////////////////////////////////////////////////////////////////////////// - - inline string_stream parser::operator()(std::string const& name) const - { - auto optIt = params_.find(trim_leading_dashes(name)); - if (params_.end() != optIt) - return string_stream(optIt->second); - return bad_stream(); - } - - ////////////////////////////////////////////////////////////////////////// - - inline string_stream parser::operator()(std::initializer_list init_list) const - { - for (auto& name : init_list) - { - auto optIt = 
params_.find(trim_leading_dashes(name)); - if (params_.end() != optIt) - return string_stream(optIt->second); - } - return bad_stream(); - } - - ////////////////////////////////////////////////////////////////////////// - - template - string_stream parser::operator()(std::string const& name, T&& def_val) const - { - auto optIt = params_.find(trim_leading_dashes(name)); - if (params_.end() != optIt) - return string_stream(optIt->second); - - std::ostringstream ostr; - ostr << def_val; - return string_stream(ostr.str()); // use default - } - - ////////////////////////////////////////////////////////////////////////// - - // same as above but for a list of names. returns the first value to be found. - template - string_stream parser::operator()(std::initializer_list init_list, T&& def_val) const - { - for (auto& name : init_list) - { - auto optIt = params_.find(trim_leading_dashes(name)); - if (params_.end() != optIt) - return string_stream(optIt->second); - } - std::ostringstream ostr; - ostr << def_val; - return string_stream(ostr.str()); // use default - } - - ////////////////////////////////////////////////////////////////////////// - - inline string_stream parser::operator()(size_t ind) const - { - if (pos_args_.size() <= ind) - return bad_stream(); - - return string_stream(pos_args_[ind]); - } - - ////////////////////////////////////////////////////////////////////////// - - template - string_stream parser::operator()(size_t ind, T&& def_val) const - { - if (pos_args_.size() <= ind) - { - std::ostringstream ostr; - ostr << def_val; - return string_stream(ostr.str()); - } - - return string_stream(pos_args_[ind]); - } - - ////////////////////////////////////////////////////////////////////////// - - inline void parser::add_param(std::string const& name) - { - registeredParams_.insert(trim_leading_dashes(name)); - } - - ////////////////////////////////////////////////////////////////////////// - - inline void parser::add_params(std::initializer_list init_list) - { 
- for (auto& name : init_list) - registeredParams_.insert(trim_leading_dashes(name)); - } -} - - diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 24c10c834..4b07f355a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -27,6 +27,16 @@ target_link_libraries( PUBLIC MPI::MPI_C ) +include(FetchContent) +FetchContent_Declare( + cli11 + QUIET + GIT_REPOSITORY https://github.com/CLIUtils/CLI11.git + GIT_TAG v2.4.2 +) + +FetchContent_MakeAvailable(cli11) + # add_library(poetlib # Base/Grid.cpp # Base/SimParams.cpp @@ -75,11 +85,11 @@ file(READ "${PROJECT_SOURCE_DIR}/R_lib/ai_surrogate_model.R" R_AI_SURROGATE_LIB) configure_file(poet.hpp.in poet.hpp @ONLY) add_executable(poet poet.cpp) -target_link_libraries(poet PRIVATE POETLib MPI::MPI_C RRuntime) +target_link_libraries(poet PRIVATE POETLib MPI::MPI_C RRuntime CLI11::CLI11) target_include_directories(poet PRIVATE "${CMAKE_CURRENT_BINARY_DIR}") add_executable(poet_init initializer.cpp) -target_link_libraries(poet_init PRIVATE POETLib RRuntime) +target_link_libraries(poet_init PRIVATE POETLib RRuntime CLI11::CLI11) target_include_directories(poet_init PRIVATE "${CMAKE_CURRENT_BINARY_DIR}") install(TARGETS poet poet_init DESTINATION bin) diff --git a/src/initializer.cpp b/src/initializer.cpp index 2c76e5420..d2e663931 100644 --- a/src/initializer.cpp +++ b/src/initializer.cpp @@ -1,7 +1,7 @@ #include "Init/InitialList.hpp" #include "poet.hpp" -#include "Base/argh.hpp" +#include #include @@ -11,32 +11,39 @@ #include int main(int argc, char **argv) { - - // pre-register expected parameters before calling `parse` - argh::parser cmdl({"-o", "--output"}); - cmdl.parse(argc, argv); - - if (cmdl[{"-h", "--help"}] || cmdl.pos_args().size() != 2) { - std::cout << "Usage: " << argv[0] - << " [-o, --output output_file]" - << " [-s, --setwd] " - << " " - << std::endl; - return EXIT_SUCCESS; - } - + // initialize RIinside RInside R(argc, argv); R.parseEvalQ(init_r_library); R.parseEvalQ(kin_r_library); - std::string 
input_script = cmdl.pos_args()[1]; + // parse command line arguments + CLI::App app{"POET Initializer - Poet R script to POET qs/rds converter"}; + + std::string input_script; + app.add_option("PoetScript.R", input_script, "POET R script to convert") + ->required(); + + std::string output_file; + app.add_option("-n, --name", output_file, + "Name of the output file without file extension"); + + bool setwd; + app.add_flag("-s, --setwd", setwd, + "Set working directory to the directory of the directory of the " + "input script") + ->default_val(false); + + bool asRDS; + app.add_flag("-r, --rds", asRDS, "Save output as .rds file instead of .qs") + ->default_val(false); + + CLI11_PARSE(app, argc, argv); + + // source the input script std::string normalized_path_script; std::string in_file_name; - std::string curr_path = - Rcpp::as(Rcpp::Function("normalizePath")(Rcpp::wrap("."))); - try { normalized_path_script = Rcpp::as(Rcpp::Function("normalizePath")(input_script)); @@ -52,22 +59,20 @@ int main(int argc, char **argv) { return EXIT_FAILURE; } - std::string output_file; + // if output file is not specified, use the input file name + if (output_file.empty()) { + std::string curr_path = + Rcpp::as(Rcpp::Function("normalizePath")(Rcpp::wrap("."))); - // MDL: some test to understand - // std::string output_ext = ".rds"; - // if (cmdl["q"]) output_ext = ".qs"; - // std::cout << "Ouptut ext: " << output_ext << " ; infile substr: " - // << in_file_name.substr(0, in_file_name.find_last_of('.')) << std::endl; + output_file = curr_path + "/" + + in_file_name.substr(0, in_file_name.find_last_of('.')); + } - // cmdl({"-o", "--output"}, - // curr_path + "/" + - // in_file_name.substr(0, in_file_name.find_last_of('.')) + ".qs") >> output_file; + // append the correct file extension + output_file += asRDS ? 
".rds" : ".qs"; - cmdl({"-o", "--output"}) >> output_file; - - - if (cmdl[{"-s", "--setwd"}]) { + // set working directory to the directory of the input script + if (setwd) { const std::string dir_path = Rcpp::as( Rcpp::Function("dirname")(normalized_path_script)); diff --git a/src/poet.cpp b/src/poet.cpp index 4d15a8371..9fbf94c18 100644 --- a/src/poet.cpp +++ b/src/poet.cpp @@ -23,6 +23,7 @@ #include "Base/Macros.hpp" #include "Base/RInsidePOET.hpp" +#include "CLI/CLI.hpp" #include "Chemistry/ChemistryModule.hpp" #include "DataStructures/Field.hpp" #include "Init/InitialList.hpp" @@ -39,7 +40,7 @@ #include #include -#include "Base/argh.hpp" +#include #include #include @@ -86,82 +87,95 @@ static void init_global_functions(RInside &R) { enum ParseRet { PARSER_OK, PARSER_ERROR, PARSER_HELP }; -ParseRet parseInitValues(char **argv, RuntimeParameters ¶ms) { - // initialize argh object - argh::parser cmdl(argv); +int parseInitValues(int argc, char **argv, RuntimeParameters ¶ms) { - // if user asked for help - if (cmdl[{"help", "h"}]) { - if (MY_RANK == 0) { - MSG("Todo"); - MSG("See README.md for further information."); - } + CLI::App app{"POET - Potsdam rEactive Transport simulator"}; - return ParseRet::PARSER_HELP; - } - // if positional arguments are missing - if (!cmdl(3)) { - if (MY_RANK == 0) { - ERRMSG("POET needs 3 positional arguments: "); - ERRMSG("1) the R script defining your simulation runtime."); - ERRMSG("2) the initial .rds file generated by poet_init."); - ERRMSG("3) the directory prefix where to save results and profiling"); - } - return ParseRet::PARSER_ERROR; - } - - // parse flags and check for unknown - for (const auto &option : cmdl.flags()) { - if (!(flaglist.find(option) != flaglist.end())) { - if (MY_RANK == 0) { - ERRMSG("Unrecognized option: " + option); - ERRMSG("Make sure to use available options. 
Exiting!"); - } - return ParseRet::PARSER_ERROR; - } - } - - // parse parameters and check for unknown - for (const auto &option : cmdl.params()) { - if (!(paramlist.find(option.first) != paramlist.end())) { - if (MY_RANK == 0) { - ERRMSG("Unrecognized option: " + option.first); - ERRMSG("Make sure to use available options. Exiting!"); - } - return ParseRet::PARSER_ERROR; - } - } - - params.print_progressbar = cmdl[{"P", "progress"}]; + app.add_flag("-P,--progress", params.print_progress, + "Print progress bar during chemical simulation"); /*Parse work package size*/ - cmdl("work-package-size", CHEM_DEFAULT_WORK_PACKAGE_SIZE) >> - params.work_package_size; + app.add_option( + "-w,--work-package-size", params.work_package_size, + "Work package size to distribute to each worker for chemistry module") + ->check(CLI::PositiveNumber) + ->default_val(RuntimeParameters::WORK_PACKAGE_SIZE_DEFAULT); /* Parse DHT arguments */ - params.use_dht = cmdl["dht"]; - params.use_interp = cmdl["interp"]; + auto *dht_group = app.add_option_group("DHT", "DHT related options"); + + dht_group->add_flag("--dht", params.use_dht, "Enable DHT"); + // cout << "CPP: DHT is " << ( dht_enabled ? 
"ON" : "OFF" ) << '\n'; - cmdl("dht-size", CHEM_DHT_SIZE_PER_PROCESS_MB) >> params.dht_size; + dht_group + ->add_option("--dht-size", params.dht_size, + "DHT size per process in Megabyte") + ->check(CLI::PositiveNumber) + ->default_val(RuntimeParameters::DHT_SIZE_DEFAULT); // cout << "CPP: DHT size per process (Byte) = " << dht_size_per_process << // endl; - cmdl("dht-snaps", 0) >> params.dht_snaps; + dht_group->add_option( + "--dht-snaps", params.dht_snaps, + "Save snapshots of DHT to disk: \n0 = disabled (default)\n1 = After " + "simulation\n2 = After each iteration"); - params.use_interp = cmdl["interp"]; - cmdl("interp-size", 100) >> params.interp_size; - cmdl("interp-min", 5) >> params.interp_min_entries; - cmdl("interp-bucket-entries", 20) >> params.interp_bucket_entries; + auto *interp_group = + app.add_option_group("Interpolation", "Interpolation related options"); - params.use_ai_surrogate = cmdl["ai-surrogate"]; + interp_group->add_flag("--interp", params.use_interp, "Enable interpolation"); + interp_group + ->add_option("--interp-size", params.interp_size, + "Size of the interpolation table in Megabyte") + ->check(CLI::PositiveNumber) + ->default_val(RuntimeParameters::INTERP_SIZE_DEFAULT); + interp_group + ->add_option("--interp-min", params.interp_min_entries, + "Minimum number of entries in the interpolation table") + ->check(CLI::PositiveNumber) + ->default_val(RuntimeParameters::INTERP_MIN_ENTRIES_DEFAULT); + interp_group + ->add_option( + "--interp-bucket-entries", params.interp_bucket_entries, + "Maximum number of entries in each bucket of the interpolation table") + ->check(CLI::PositiveNumber) + ->default_val(RuntimeParameters::INTERP_BUCKET_ENTRIES_DEFAULT); - // MDL: optional flag "--qs" to switch to qsave() - params.out_ext = "rds"; - if (cmdl["qs"]) { - params.out_ext = "qs"; + app.add_flag("--ai-surrogate", params.use_ai_surrogate, + "Enable AI surrogate for chemistry module"); + + app.add_flag("--rds", params.as_rds, + "Save output as .rds 
file instead of .qs"); + + std::string init_file; + std::string runtime_file; + + app.add_option("runtime_file", runtime_file, + "Runtime R script defining the simulation") + ->required() + ->check(CLI::ExistingFile); + + app.add_option( + "init_file", init_file, + "Initial R script defining the simulation, produced by poet_init") + ->required() + ->check(CLI::ExistingFile); + + app.add_option("out_dir", params.out_dir, + "Output directory of the simulation") + ->required(); + + try { + app.parse(argc, argv); + } catch (const CLI::ParseError &e) { + app.exit(e); + return -1; } + // set the output extension + params.out_ext = params.as_rds ? "rds" : "qs"; + if (MY_RANK == 0) { // MSG("Complete results storage is " + BOOL_PRINT(simparams.store_result)); MSG("Output format/extension is " + params.out_ext); @@ -191,14 +205,6 @@ ParseRet parseInitValues(char **argv, RuntimeParameters ¶ms) { std::to_string(params.interp_bucket_entries)); } } - - std::string init_file; - std::string runtime_file; - - cmdl(1) >> runtime_file; - cmdl(2) >> init_file; - cmdl(3) >> params.out_dir; - // chem_params.dht_outdir = out_dir; /* distribute information to R runtime */ @@ -232,7 +238,7 @@ ParseRet parseInitValues(char **argv, RuntimeParameters ¶ms) { (*global_rt_setup)["out_ext"] = params.out_ext; params.timesteps = - Rcpp::as>(global_rt_setup->operator[]("timesteps")); + Rcpp::as>(global_rt_setup->operator[]("timesteps")); } catch (const std::exception &e) { ERRMSG("Error while parsing R scripts: " + std::string(e.what())); @@ -269,7 +275,7 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms, * master for the following loop) */ uint32_t maxiter = params.timesteps.size(); - if (params.print_progressbar) { + if (params.print_progress) { chem.setProgressBarPrintout(true); } R["TMP_PROPS"] = Rcpp::wrap(chem.getField().GetProps()); @@ -297,9 +303,10 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms, double ai_start_t = MPI_Wtime(); 
// Save current values from the tug field as predictor for the ai step R["TMP"] = Rcpp::wrap(chem.getField().AsVector()); - R.parseEval(std::string( - "predictors <- setNames(data.frame(matrix(TMP, nrow=" + - std::to_string(chem.getField().GetRequestedVecSize()) + ")), TMP_PROPS)")); + R.parseEval( + std::string("predictors <- setNames(data.frame(matrix(TMP, nrow=" + + std::to_string(chem.getField().GetRequestedVecSize()) + + ")), TMP_PROPS)")); R.parseEval("predictors <- predictors[ai_surrogate_species]"); // Apply preprocessing @@ -308,7 +315,8 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms, // Predict MSG("AI Predict"); - R.parseEval("aipreds_scaled <- prediction_step(model, predictors_scaled)"); + R.parseEval( + "aipreds_scaled <- prediction_step(model, predictors_scaled)"); // Apply postprocessing MSG("AI Postprocesing"); @@ -316,20 +324,22 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms, // Validate prediction and write valid predictions to chem field MSG("AI Validate"); - R.parseEval("validity_vector <- validate_predictions(predictors, aipreds)"); + R.parseEval( + "validity_vector <- validate_predictions(predictors, aipreds)"); MSG("AI Marking accepted"); chem.set_ai_surrogate_validity_vector(R.parseEval("validity_vector")); MSG("AI TempField"); - std::vector> RTempField = R.parseEval("set_valid_predictions(predictors,\ + std::vector> RTempField = + R.parseEval("set_valid_predictions(predictors,\ aipreds,\ validity_vector)"); MSG("AI Set Field"); - Field predictions_field = Field(R.parseEval("nrow(predictors)"), - RTempField, - R.parseEval("colnames(predictors)")); + Field predictions_field = + Field(R.parseEval("nrow(predictors)"), RTempField, + R.parseEval("colnames(predictors)")); MSG("AI Update"); chem.getField().update(predictions_field); @@ -344,16 +354,18 @@ static Rcpp::List RunMasterLoop(RInsidePOET &R, const RuntimeParameters ¶ms, double ai_start_t = MPI_Wtime(); R["TMP"] = 
Rcpp::wrap(chem.getField().AsVector()); - R.parseEval(std::string( - "targets <- setNames(data.frame(matrix(TMP, nrow=" + - std::to_string(chem.getField().GetRequestedVecSize()) + ")), TMP_PROPS)")); + R.parseEval( + std::string("targets <- setNames(data.frame(matrix(TMP, nrow=" + + std::to_string(chem.getField().GetRequestedVecSize()) + + ")), TMP_PROPS)")); R.parseEval("targets <- targets[ai_surrogate_species]"); // TODO: Check how to get the correct columns R.parseEval("target_scaled <- preprocess(targets)"); MSG("AI: incremental training"); - R.parseEval("model <- training_step(model, predictors_scaled, target_scaled, validity_vector)"); + R.parseEval("model <- training_step(model, predictors_scaled, " + "target_scaled, validity_vector)"); double ai_end_t = MPI_Wtime(); R["ai_training_time"] = ai_end_t - ai_start_t; } @@ -462,7 +474,6 @@ std::vector getSpeciesNames(const Field &&field, int root, return species_names_out; } - int main(int argc, char *argv[]) { int world_size; @@ -478,20 +489,24 @@ int main(int argc, char *argv[]) { MSG("Running POET version " + std::string(poet_version)); } - init_global_functions(R); RuntimeParameters run_params; - switch (parseInitValues(argv, run_params)) { - case ParseRet::PARSER_ERROR: - case ParseRet::PARSER_HELP: + if (parseInitValues(argc, argv, run_params) != 0) { MPI_Finalize(); return 0; - case ParseRet::PARSER_OK: - break; } + // switch (parseInitValues(argc, argv, run_params)) { + // case ParseRet::PARSER_ERROR: + // case ParseRet::PARSER_HELP: + // MPI_Finalize(); + // return 0; + // case ParseRet::PARSER_OK: + // break; + // } + InitialList init_list(R); init_list.importList(run_params.init_params, MY_RANK != 0); @@ -513,8 +528,7 @@ int main(int argc, char *argv[]) { run_params.interp_bucket_entries, run_params.interp_size, run_params.interp_min_entries, - run_params.use_ai_surrogate - }; + run_params.use_ai_surrogate}; chemistry.masterEnableSurrogates(surr_setup); @@ -538,15 +552,17 @@ int main(int argc, char 
*argv[]) { /* Incorporate ai surrogate from R */ R.parseEvalQ(ai_surrogate_r_library); /* Use dht species for model input and output */ - R["ai_surrogate_species"] = init_list.getChemistryInit().dht_species.getNames(); + R["ai_surrogate_species"] = + init_list.getChemistryInit().dht_species.getNames(); - const std::string ai_surrogate_input_script = init_list.getChemistryInit().ai_surrogate_input_script; + const std::string ai_surrogate_input_script = + init_list.getChemistryInit().ai_surrogate_input_script; - MSG("AI: sourcing user-provided script"); - R.parseEvalQ(ai_surrogate_input_script); + MSG("AI: sourcing user-provided script"); + R.parseEvalQ(ai_surrogate_input_script); MSG("AI: initialize AI model"); - R.parseEval("model <- initiate_model()"); + R.parseEval("model <- initiate_model()"); R.parseEval("gpu_info()"); } @@ -568,8 +584,8 @@ int main(int argc, char *argv[]) { R["setup$out_ext"] = run_params.out_ext; string r_vis_code; - r_vis_code = - "SaveRObj(x = profiling, path = paste0(out_dir, '/timings.', setup$out_ext));"; + r_vis_code = "SaveRObj(x = profiling, path = paste0(out_dir, " + "'/timings.', setup$out_ext));"; R.parseEval(r_vis_code); MSG("Done! Results are stored as R objects into <" + run_params.out_dir + diff --git a/src/poet.hpp.in b/src/poet.hpp.in index 660a9e074..0e2409f87 100644 --- a/src/poet.hpp.in +++ b/src/poet.hpp.in @@ -20,7 +20,7 @@ ** Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. 
*/ -#pragma once +#pragma once #include #include @@ -35,63 +35,37 @@ static const char *poet_version = "@POET_VERSION@"; static const inline std::string kin_r_library = R"(@R_KIN_LIB@)"; static const inline std::string init_r_library = R"(@R_INIT_LIB@)"; -static const inline std::string ai_surrogate_r_library = R"(@R_AI_SURROGATE_LIB@)"; +static const inline std::string ai_surrogate_r_library = + R"(@R_AI_SURROGATE_LIB@)"; static const inline std::string r_runtime_parameters = "mysetup"; -const std::set flaglist{"ignore-result", "dht", "P", "progress", - "interp", "ai-surrogate", "qs"}; -const std::set paramlist{ - "work-package-size", "dht-strategy", "dht-size", "dht-snaps", - "dht-file", "interp-size", "interp-min", "interp-bucket-entries"}; - -constexpr uint32_t CHEM_DEFAULT_WORK_PACKAGE_SIZE = 32; - -constexpr uint32_t CHEM_DHT_SIZE_PER_PROCESS_MB = 1.5E3; - struct RuntimeParameters { std::string out_dir; std::vector timesteps; - std::string out_ext; // MDL added to accomodate for qs::qsave/qread - - bool print_progressbar; - uint32_t work_package_size; Rcpp::List init_params; - bool use_dht; + bool as_rds = false; + std::string out_ext; // MDL added to accommodate for qs::qsave/qread + + bool print_progress = false; + + static constexpr std::uint32_t WORK_PACKAGE_SIZE_DEFAULT = 32; + std::uint32_t work_package_size; + + bool use_dht = false; + static constexpr std::uint32_t DHT_SIZE_DEFAULT = 1.5E3; std::uint32_t dht_size; + static constexpr std::uint8_t DHT_SNAPS_DEFAULT = 0; - std::uint8_t dht_snaps; + std::uint8_t dht_snaps = DHT_SNAPS_DEFAULT; // --dht-snaps registers no default_val, so the default must be applied here - bool use_interp; + bool use_interp = false; + static constexpr std::uint32_t INTERP_SIZE_DEFAULT = 100; std::uint32_t interp_size; + static constexpr std::uint32_t INTERP_MIN_ENTRIES_DEFAULT = 5; std::uint32_t interp_min_entries; + static constexpr std::uint32_t INTERP_BUCKET_ENTRIES_DEFAULT = 20; std::uint32_t interp_bucket_entries; - bool use_ai_surrogate; - struct ChemistryParams { - // std::string database_path; - // std::string input_script; - - // bool 
use_dht; - // std::uint64_t dht_size; - // int dht_snaps; - // std::string dht_file; - // std::string dht_outdir; - // NamedVector dht_signifs; - - // bool use_interp; - // std::uint64_t pht_size; - // std::uint32_t pht_max_entries; - // std::uint32_t interp_min_entries; - // NamedVector pht_signifs; - - // struct Chem_Hook_Functions { - // RHookFunction dht_fill; - // RHookFunction> dht_fuzz; - // RHookFunction> interp_pre; - // RHookFunction interp_post; - // } hooks; - - // void initFromR(RInsidePOET &R); - }; + bool use_ai_surrogate = false; };