diff --git a/decoder/CMakeLists.txt b/decoder/CMakeLists.txt index f77044dc910294ee531b4db1698b2044947ea522..3b73f807d7a85442492cecb21c2036c04715b45f 100644 --- a/decoder/CMakeLists.txt +++ b/decoder/CMakeLists.txt @@ -3,3 +3,8 @@ FILE(GLOB SOURCES src/*.cpp) add_library(decoder STATIC ${SOURCES}) target_link_libraries(decoder reading_machine) +add_executable(macaon_decode src/macaon_decode.cpp) +target_link_libraries(macaon_decode Boost) +target_link_libraries(macaon_decode decoder) +target_link_libraries(macaon_decode common) +install(TARGETS macaon_decode DESTINATION bin) diff --git a/decoder/src/macaon_decode.cpp b/decoder/src/macaon_decode.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c3c3cbf58786389c27342c7bef7f3743b9580730 --- /dev/null +++ b/decoder/src/macaon_decode.cpp @@ -0,0 +1,83 @@ +#include <boost/program_options.hpp> +#include <filesystem> +#include "util.hpp" +#include "Decoder.hpp" + +namespace po = boost::program_options; + +po::options_description getOptionsDescription() +{ + po::options_description desc("Command-Line Arguments "); + + po::options_description req("Required"); + req.add_options() + ("model", po::value<std::string>()->required(), + "Directory containing the trained machine used to decode") + ("inputTSV", po::value<std::string>(), + "File containing the text to decode, TSV file") + ("inputTXT", po::value<std::string>(), + "File containing the text to decode, raw text file") + ("mcd", po::value<std::string>()->required(), + "Multi Column Description file that describes the input/output format"); + + po::options_description opt("Optional"); + opt.add_options() + ("help,h", "Produce this help message"); + + desc.add(req).add(opt); + + return desc; +} + +po::variables_map checkOptions(po::options_description & od, int argc, char ** argv) +{ + po::variables_map vm; + + try {po::store(po::parse_command_line(argc, argv, od), vm);} + catch(std::exception & e) {util::myThrow(e.what());} + + if (vm.count("help")) + { + std::stringstream ss; + ss << od; + fmt::print(stderr, "{}\n", ss.str()); + exit(0); + } + + try {po::notify(vm);} + catch(std::exception& e) {util::myThrow(e.what());} + + if (vm.count("inputTSV") + vm.count("inputTXT") != 1) + { + std::stringstream ss; + ss << od; + fmt::print(stderr, "Error : one and only one input format must be specified.\n{}\n", ss.str()); + exit(1); + } + + return vm; +} + +int main(int argc, char * argv[]) +{ + auto od = getOptionsDescription(); + auto variables = checkOptions(od, argc, argv); + + std::filesystem::path modelPath(variables["model"].as<std::string>()); + auto machinePath = modelPath / ReadingMachine::defaultMachineName; + auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : ""; + auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : ""; + auto mcdFile = variables["mcd"].as<std::string>(); + + ReadingMachine machine(machinePath.string()); + Decoder decoder(machine); + + BaseConfig config(mcdFile, inputTSV, inputTXT); + + decoder.decode(config, 1); + + config.print(stdout); + + return 0; +} + diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index 41cb82617311e0dfc7448d3c8bb09686257ee045..8db2de13e37832beb7e189ed9e7ba87753d79295 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -1,6 +1,7 @@ #ifndef READING_MACHINE__H #define READING_MACHINE__H +#include <filesystem> #include <memory> #include "Classifier.hpp" #include "Strategy.hpp" @@ -9,9 +10,16 @@ class ReadingMachine { + public : + + static inline const std::string defaultMachineName = "machine.rm"; + static inline const std::string defaultModelName = "{}.pt"; + static inline const std::string defaultDictName = "{}.dict"; + private : std::string name; + std::filesystem::path path; std::unique_ptr<Classifier> classifier; std::unique_ptr<Strategy> strategy; std::unique_ptr<FeatureFunction> featureFunction; @@ -19,7 +27,8 @@ class ReadingMachine public : - ReadingMachine(const std::string & filename); + ReadingMachine(std::filesystem::path path); + ReadingMachine(const std::string & filename, const std::vector<std::string> & models, const std::vector<std::string> & dicts); TransitionSet & getTransitionSet(); Strategy & getStrategy(); Dict & getDict(const std::string & state); diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 334f8bf327acec84adb9091c58f55fc2c0d05d4b..d1d9da646df70544fce596312807e11dc4f2db54 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -1,11 +1,11 @@ #include "ReadingMachine.hpp" #include "util.hpp" -ReadingMachine::ReadingMachine(const std::string & filename) +ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path) { dicts.emplace(std::make_pair("", Dict::State::Open)); - std::FILE * file = std::fopen(filename.c_str(), "r"); + std::FILE * file = std::fopen(path.c_str(), "r"); char buffer[1024]; std::string fileContent; @@ -46,7 +46,12 @@ ReadingMachine::ReadingMachine(const std::string & filename) strategy.reset(new Strategy(restOfFile)); - } catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", filename, e.what()));} + } catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", path.string(), e.what()));} +} + +ReadingMachine::ReadingMachine(const std::string & filename, const std::vector<std::string> & models, const std::vector<std::string> & dicts) +{ + } TransitionSet & ReadingMachine::getTransitionSet()