// MacaonDecode.cpp: command-line front end that decodes input text with a trained reading machine.
#include "MacaonDecode.hpp"
// Standard library and project headers.
#include <filesystem>
#include "util.hpp"
#include "Decoder.hpp"

/// Build the description of every command-line option the decoder accepts.
///
/// The options are declared in two groups, "Required" and "Optional", which
/// are then merged into a single description.
///
/// Note that inputTSV and inputTXT live in the "Required" group without being
/// flagged as required(): only "model" and "mcd" are individually mandatory.
///
/// @return The merged options description (required group first).
po::options_description MacaonDecode::getOptionsDescription()
{
  po::options_description desc("Command-Line Arguments ");

  po::options_description req("Required");
  req.add_options()
    ("model", po::value<std::string>()->required(),
      "Directory containing the trained machine used to decode")
    ("inputTSV", po::value<std::string>(),
      "File containing the text to decode, TSV file")
    ("inputTXT", po::value<std::string>(),
      "File containing the text to decode, raw text file")
    ("mcd", po::value<std::string>()->required(),
      "Multi Column Description file that describes the input/output format");

  po::options_description opt("Optional");
  opt.add_options()
    ("debug,d", "Print debuging infos on stderr")
    ("silent", "Don't print speed and progress")
    ("help,h", "Produce this help message");

  desc.add(req).add(opt);

  return desc;
}

/// Parse the command line stored in this object and validate it against `od`.
///
/// Steps, in order:
///  1. parse argc/argv into a variables map (parse errors are forwarded to
///     util::myThrow with the exception message);
///  2. if "help" was given, print the usage text on stderr and exit(0);
///  3. run po::notify (errors are forwarded to util::myThrow as well);
///  4. check that exactly one of "inputTSV" / "inputTXT" was given, otherwise
///     print an error followed by the usage text on stderr and exit(1).
///
/// @param od Description of the accepted options.
/// @return The parsed and validated variables map.
po::variables_map MacaonDecode::checkOptions(po::options_description & od)
{
  po::variables_map vm;

  // Renders the usage text once per call site instead of duplicating the
  // stringstream boilerplate.
  auto usage = [&od]()
  {
    std::stringstream ss;
    ss << od;
    return ss.str();
  };

  try {po::store(po::parse_command_line(argc, argv, od), vm);}
  catch(const std::exception & e) {util::myThrow(e.what());}

  // "help" is handled before po::notify so that asking for help does not
  // trip the checks performed by notify.
  if (vm.count("help"))
  {
    fmt::print(stderr, "{}\n", usage());
    exit(0);
  }

  try {po::notify(vm);}
  catch(const std::exception & e) {util::myThrow(e.what());}

  // Exactly one input format is accepted: neither zero nor both.
  if (vm.count("inputTSV") + vm.count("inputTXT") != 1)
  {
    fmt::print(stderr, "Error : one and only one input format must be specified.\n{}\n", usage());
    exit(1);
  }

  return vm;
}

int MacaonDecode::main()
Franck Dary's avatar
Franck Dary committed
{
  auto od = getOptionsDescription();
  auto variables = checkOptions(od);
Franck Dary's avatar
Franck Dary committed

  std::filesystem::path modelPath(variables["model"].as<std::string>());
Franck Dary's avatar
Franck Dary committed
  auto machinePath = modelPath / ReadingMachine::defaultMachineFilename;
  auto dictPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultDictFilename, ""));
  auto modelPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultModelFilename, ""));
Franck Dary's avatar
Franck Dary committed
  auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : "";
  auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : "";
  auto mcdFile = variables["mcd"].as<std::string>();
Franck Dary's avatar
Franck Dary committed
  bool debug = variables.count("debug") == 0 ? false : true;
Franck Dary's avatar
Franck Dary committed
  bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
  torch::globalContext().setBenchmarkCuDNN(true);

Franck Dary's avatar
Franck Dary committed
  if (dictPaths.empty())
    util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultDictFilename, "")));
  if (modelPaths.empty())
    util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, "")));
  fmt::print(stderr, "Decoding using device : {}\n", NeuralNetworkImpl::device.str());

Franck Dary's avatar
Franck Dary committed
  try
  {
    ReadingMachine machine(machinePath, modelPaths, dictPaths);
    Decoder decoder(machine);

    BaseConfig config(mcdFile, inputTSV, inputTXT);
Franck Dary's avatar
Franck Dary committed
    decoder.decode(config, 1, debug, printAdvancement);
Franck Dary's avatar
Franck Dary committed
    config.print(stdout);
  } catch(std::exception & e) {util::error(e);}
/// Store the raw command-line arguments; they are consumed later when
/// checkOptions() parses the command line.
MacaonDecode::MacaonDecode(int argc, char ** argv)
  : argc(argc)
  , argv(argv)
{}