Skip to content
Snippets Groups Projects
Select Git revision
  • 15d915ca7b8798c14da8177c5ee72e9bdb423472
  • master default protected
  • loss
  • producer
4 results

MacaonDecode.cpp

Blame
  • MacaonDecode.cpp 3.28 KiB
    #include "MacaonDecode.hpp"
    #include <filesystem>
    #include "util.hpp"
    #include "Decoder.hpp"
    
    // Builds the description of every command-line option understood by the
    // decoder. Options are split in two groups, "Required" and "Optional", which
    // are then gathered (in that order) under one top-level description.
    // Note that within the "Required" group only "model" and "mcd" are flagged
    // required(); "inputTSV" and "inputTXT" are declared without that flag.
    po::options_description MacaonDecode::getOptionsDescription()
    {
      po::options_description requiredGroup("Required");
      requiredGroup.add_options()("model", po::value<std::string>()->required(),
        "Directory containing the trained machine used to decode");
      requiredGroup.add_options()("inputTSV", po::value<std::string>(),
        "File containing the text to decode, TSV file");
      requiredGroup.add_options()("inputTXT", po::value<std::string>(),
        "File containing the text to decode, raw text file");
      requiredGroup.add_options()("mcd", po::value<std::string>()->required(),
        "Multi Column Description file that describes the input/output format");

      po::options_description optionalGroup("Optional");
      optionalGroup.add_options()("debug,d", "Print debuging infos on stderr");
      optionalGroup.add_options()("silent", "Don't print speed and progress");
      optionalGroup.add_options()("help,h", "Produce this help message");

      po::options_description allOptions("Command-Line Arguments ");
      allOptions.add(requiredGroup);
      allOptions.add(optionalGroup);

      return allOptions;
    }
    
    // Parses the stored argc/argv against `od` and returns the resulting map.
    //
    // Steps, in order:
    //  1. parse + store; any std::exception is forwarded to util::myThrow;
    //  2. if "help" was given, print the usage text on stderr and exit(0)
    //     (done before notify so that missing required options do not matter);
    //  3. notify; any std::exception is forwarded to util::myThrow;
    //  4. if the number of given input options (inputTSV, inputTXT) is not
    //     exactly one, print an error and the usage text on stderr and exit(1).
    po::variables_map MacaonDecode::checkOptions(po::options_description & od)
    {
      // Renders the usage text of `od` as a string, only when actually needed.
      const auto renderUsage = [&od]()
      {
        std::stringstream usage;
        usage << od;
        return usage.str();
      };

      po::variables_map parsed;

      try
      {
        po::store(po::parse_command_line(argc, argv, od), parsed);
      }
      catch (const std::exception & err)
      {
        util::myThrow(err.what());
      }

      if (parsed.count("help"))
      {
        fmt::print(stderr, "{}\n", renderUsage());
        exit(0);
      }

      try
      {
        po::notify(parsed);
      }
      catch (const std::exception & err)
      {
        util::myThrow(err.what());
      }

      const auto nbInputFormats = parsed.count("inputTSV") + parsed.count("inputTXT");
      if (nbInputFormats != 1)
      {
        fmt::print(stderr, "Error : one and only one input format must be specified.\n{}\n", renderUsage());
        exit(1);
      }

      return parsed;
    }
    
    int MacaonDecode::main()
    {
      auto od = getOptionsDescription();
      auto variables = checkOptions(od);
    
      std::filesystem::path modelPath(variables["model"].as<std::string>());
      auto machinePath = modelPath / ReadingMachine::defaultMachineFilename;
      auto dictPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultDictFilename, ""));
      auto modelPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultModelFilename, ""));
      auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : "";
      auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : "";
      auto mcdFile = variables["mcd"].as<std::string>();
      bool debug = variables.count("debug") == 0 ? false : true;
      bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
    
      if (dictPaths.empty())
        util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultDictFilename, "")));
      if (modelPaths.empty())
        util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, "")));
    
      fmt::print(stderr, "Decoding using device : {}\n", NeuralNetworkImpl::device.str());
    
      try
      {
        ReadingMachine machine(machinePath, modelPaths, dictPaths);
        Decoder decoder(machine);
    
        BaseConfig config(mcdFile, inputTSV, inputTXT);
    
        decoder.decode(config, 1, debug, printAdvancement);
    
        config.print(stdout);
      } catch(std::exception & e) {util::error(e);}
    
      return 0;
    }
    
    // Keeps the raw command-line arguments in the argc/argv members; no parsing
    // happens at construction time.
    MacaonDecode::MacaonDecode(int argc, char ** argv)
      : argc(argc)
      , argv(argv)
    {
    }