Skip to content
Snippets Groups Projects
macaon_decode.cpp 2.92 KiB
Newer Older
  • Learn to ignore specific revisions
  • Franck Dary's avatar
    Franck Dary committed
    #include <boost/program_options.hpp>
    #include <filesystem>
    #include "util.hpp"
    #include "Decoder.hpp"
    
    namespace po = boost::program_options;
    
    // Builds the command-line option schema for the decoder.
    //
    // Groups:
    //  - Required: --model (trained machine directory) and --mcd (column format).
    //  - Input: --inputTSV / --inputTXT. Neither is mandatory by itself;
    //    checkOptions() enforces that exactly one of them is given. They used to
    //    be listed under "Required", which made the --help output misleading.
    //  - Optional: --help / -h.
    //
    // Returns the combined description, suitable for parsing and for printing
    // as the help message.
    po::options_description getOptionsDescription()
    {
      po::options_description desc("Command-Line Arguments ");
    
      po::options_description req("Required");
      req.add_options()
        ("model", po::value<std::string>()->required(),
          "Directory containing the trained machine used to decode")
        ("mcd", po::value<std::string>()->required(),
          "Multi Column Description file that describes the input/output format");
    
      // Not marked ->required(): the "exactly one of the two" rule cannot be
      // expressed per-option, so it is checked after parsing instead.
      po::options_description input("Input (exactly one required)");
      input.add_options()
        ("inputTSV", po::value<std::string>(),
          "File containing the text to decode, TSV file")
        ("inputTXT", po::value<std::string>(),
          "File containing the text to decode, raw text file");
    
      po::options_description opt("Optional");
      opt.add_options()
        ("help,h", "Produce this help message");
    
      desc.add(req).add(input).add(opt);
    
      return desc;
    }
    
    // Parses and validates the command line against the schema `od`.
    //
    // - Parse errors and missing required options are reported through
    //   util::myThrow with the underlying exception's message.
    // - --help prints the option descriptions to stderr and exits with status 0.
    // - If the number of input options given (--inputTSV, --inputTXT) is not
    //   exactly one, prints an error plus the option descriptions to stderr and
    //   exits with status 1.
    //
    // Returns the populated variables map on success.
    po::variables_map checkOptions(po::options_description & od, int argc, char ** argv)
    {
      // Renders the option descriptions into a string for the help/usage output.
      auto const usageText = [&od]()
      {
        std::stringstream buffer;
        buffer << od;
        return buffer.str();
      };
    
      po::variables_map parsed;
    
      // Raw parse of argv; syntax errors are reported via util::myThrow.
      try
      {
        po::store(po::parse_command_line(argc, argv, od), parsed);
      }
      catch (std::exception & err)
      {
        util::myThrow(err.what());
      }
    
      // Handled before po::notify so that --help works even when the options
      // marked ->required() are absent.
      if (parsed.count("help") != 0)
      {
        fmt::print(stderr, "{}\n", usageText());
        exit(0);
      }
    
      // po::notify validates the options declared with ->required().
      try
      {
        po::notify(parsed);
      }
      catch (std::exception & err)
      {
        util::myThrow(err.what());
      }
    
      // Exactly one input source must be selected.
      auto const inputCount = parsed.count("inputTSV") + parsed.count("inputTXT");
      if (inputCount != 1)
      {
        fmt::print(stderr, "Error : one and only one input format must be specified.\n{}\n", usageText());
        exit(1);
      }
    
      return parsed;
    }
    
    int main(int argc, char * argv[])
    {
      auto od = getOptionsDescription();
      auto variables = checkOptions(od, argc, argv);
    
      std::filesystem::path modelPath(variables["model"].as<std::string>());
    
    Franck Dary's avatar
    Franck Dary committed
      auto machinePath = modelPath / ReadingMachine::defaultMachineFilename;
      auto dictPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultDictFilename, ""));
      auto modelPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultModelFilename, ""));
    
    Franck Dary's avatar
    Franck Dary committed
      auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : "";
      auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : "";
      auto mcdFile = variables["mcd"].as<std::string>();
    
    
    Franck Dary's avatar
    Franck Dary committed
      if (dictPaths.empty())
        util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultDictFilename, "")));
      if (modelPaths.empty())
        util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, "")));
    
    Franck Dary's avatar
    Franck Dary committed
      try
      {
        ReadingMachine machine(machinePath, modelPaths, dictPaths);
        Decoder decoder(machine);
    
        BaseConfig config(mcdFile, inputTSV, inputTXT);
    
    Franck Dary's avatar
    Franck Dary committed
        decoder.decode(config, 1);
    
    Franck Dary's avatar
    Franck Dary committed
        config.print(stdout);
      } catch(std::exception & e) {util::error(e);}