Skip to content
Snippets Groups Projects
macaon_decode.cpp 3.27 KiB
Newer Older
  • Learn to ignore specific revisions
  • Franck Dary's avatar
    Franck Dary committed
    #include <boost/program_options.hpp>
    #include <filesystem>
    #include "util.hpp"
    #include "Decoder.hpp"
    
    namespace po = boost::program_options;
    
    // Build the command-line option table for the decoder.
    //
    // "Required" holds --model and --mcd, which are enforced through required(),
    // plus the two input options --inputTSV / --inputTXT. Neither input option is
    // marked required() on its own; the rule that exactly one of them must be given
    // is enforced after parsing by checkOptions().
    // "Optional" holds the boolean switches --debug/-d, --silent and --help/-h.
    po::options_description getOptionsDescription()
    {
      po::options_description desc("Command-Line Arguments ");
    
      po::options_description req("Required");
      req.add_options()
        ("model", po::value<std::string>()->required(),
          "Directory containing the trained machine used to decode")
        ("inputTSV", po::value<std::string>(),
          "File containing the text to decode, TSV file")
        ("inputTXT", po::value<std::string>(),
          "File containing the text to decode, raw text file")
        ("mcd", po::value<std::string>()->required(),
          "Multi Column Description file that describes the input/output format");
    
      po::options_description opt("Optional");
      opt.add_options()
        ("debug,d", "Print debugging infos on stderr")
        ("silent", "Don't print speed and progress")
        ("help,h", "Produce this help message");
    
      desc.add(req).add(opt);
    
      return desc;
    }
    
    // Parse argv against the option table `od` and validate the result.
    //
    // - Parse errors from Boost are rethrown through util::myThrow.
    // - If --help is present, the option table is printed to stderr and the
    //   process exits with status 0. This happens before notify(), so missing
    //   required options do not prevent the help from being shown.
    // - Errors raised by notify(), such as missing required options, are also
    //   rethrown through util::myThrow.
    // - Unless exactly one of --inputTSV / --inputTXT was supplied, an error and
    //   the option table are printed to stderr and the process exits with status 1.
    // Returns the populated variables_map.
    po::variables_map checkOptions(po::options_description & od, int argc, char ** argv)
    {
      // options_description only exposes operator<<, so render it through a stream.
      auto describe = [&od]()
      {
        std::stringstream buffer;
        buffer << od;
        return buffer.str();
      };
    
      po::variables_map parsed;
    
      try
      {
        po::store(po::parse_command_line(argc, argv, od), parsed);
      }
      catch (std::exception & err)
      {
        util::myThrow(err.what());
      }
    
      if (parsed.count("help") != 0)
      {
        fmt::print(stderr, "{}\n", describe());
        exit(0);
      }
    
      try
      {
        po::notify(parsed);
      }
      catch (std::exception & err)
      {
        util::myThrow(err.what());
      }
    
      // Each count is 0 or 1, so "both or neither" is exactly the invalid case.
      const bool hasTSV = parsed.count("inputTSV") != 0;
      const bool hasTXT = parsed.count("inputTXT") != 0;
      if (hasTSV == hasTXT)
      {
        fmt::print(stderr, "Error : one and only one input format must be specified.\n{}\n", describe());
        exit(1);
      }
    
      return parsed;
    }
    
    int main(int argc, char * argv[])
    {
      auto od = getOptionsDescription();
      auto variables = checkOptions(od, argc, argv);
    
      std::filesystem::path modelPath(variables["model"].as<std::string>());
    
    Franck Dary's avatar
    Franck Dary committed
      auto machinePath = modelPath / ReadingMachine::defaultMachineFilename;
      auto dictPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultDictFilename, ""));
      auto modelPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultModelFilename, ""));
    
    Franck Dary's avatar
    Franck Dary committed
      auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : "";
      auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : "";
      auto mcdFile = variables["mcd"].as<std::string>();
    
    Franck Dary's avatar
    Franck Dary committed
      bool debug = variables.count("debug") == 0 ? false : true;
    
    Franck Dary's avatar
    Franck Dary committed
      bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
    
    Franck Dary's avatar
    Franck Dary committed
      if (dictPaths.empty())
        util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultDictFilename, "")));
      if (modelPaths.empty())
        util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, "")));
    
      fmt::print(stderr, "Decoding using device : {}\n", NeuralNetworkImpl::device.str());
    
    
    Franck Dary's avatar
    Franck Dary committed
      try
      {
        ReadingMachine machine(machinePath, modelPaths, dictPaths);
        Decoder decoder(machine);
    
        BaseConfig config(mcdFile, inputTSV, inputTXT);
    
    Franck Dary's avatar
    Franck Dary committed
        decoder.decode(config, 1, debug, printAdvancement);
    
    Franck Dary's avatar
    Franck Dary committed
        config.print(stdout);
      } catch(std::exception & e) {util::error(e);}