// MacaonDecode.cpp
#include "MacaonDecode.hpp"
#include <filesystem>
#include "util.hpp"
#include "Decoder.hpp"

/// Builds the description of every command-line option accepted by the decoder.
///
/// Options are split into a "Required" group (model directory and the two
/// input sources) and an "Optional" group (verbosity, column format and
/// beam-search settings). Both groups are merged into a single description.
///
/// @return The combined options description, used for parsing the command
///         line and for printing the help message.
po::options_description MacaonDecode::getOptionsDescription()
{
  po::options_description desc("Command-Line Arguments ");

  po::options_description req("Required");
  req.add_options()
    ("model", po::value<std::string>()->required(),
      "Directory containing the trained machine used to decode")
    // Neither input option is flagged required(): the "exactly one of the
    // two" rule is enforced separately in checkOptions().
    ("inputTSV", po::value<std::string>(),
      "File containing the text to decode, TSV file")
    ("inputTXT", po::value<std::string>(),
      "File containing the text to decode, raw text file");

  po::options_description opt("Optional");
  opt.add_options()
    ("debug,d", "Print debugging infos on stderr")
    ("silent", "Don't print speed and progress")
    ("mcd", po::value<std::string>()->default_value("ID,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL"),
      "Comma separated column names that describes the input/output format")
    ("beamSize", po::value<int>()->default_value(1),
      "Size of the beam during beam search")
    // The explicit textual form keeps the help message showing "0.1" instead
    // of the float's full round-trip representation (0.100000001).
    ("beamThreshold", po::value<float>()->default_value(0.1f, "0.1"),
      "Minimal probability an action must have to be considered in the beam search")
    ("help,h", "Produce this help message");

  desc.add(req).add(opt);

  return desc;
}

/// Parses the stored command line against `od` and validates the result.
///
/// - Parsing and notification errors are reported through util::myThrow.
/// - If `--help` was given, the option descriptions are printed on stderr and
///   the process exits with status 0.
/// - If not exactly one of `inputTSV` / `inputTXT` was given, an error and the
///   option descriptions are printed on stderr and the process exits with
///   status 1.
///
/// @param od Description of the accepted options.
/// @return The parsed and validated option values.
po::variables_map MacaonDecode::checkOptions(po::options_description & od)
{
  // Renders the option descriptions as text, shared by the help and error outputs.
  const auto usage = [&od]()
  {
    std::stringstream buffer;
    buffer << od;
    return buffer.str();
  };

  po::variables_map vm;

  try
  {
    po::store(po::parse_command_line(argc, argv, od), vm);
  }
  catch (const std::exception & parseError)
  {
    util::myThrow(parseError.what());
  }

  // Help is handled before notify() so it can be displayed even when the
  // other options have not been validated yet.
  if (vm.count("help") > 0)
  {
    fmt::print(stderr, "{}\n", usage());
    exit(0);
  }

  try
  {
    po::notify(vm);
  }
  catch (const std::exception & notifyError)
  {
    util::myThrow(notifyError.what());
  }

  // Exactly one input source (TSV or raw text) has to be provided.
  const auto nbInputs = vm.count("inputTSV") + vm.count("inputTXT");
  if (nbInputs != 1)
  {
    fmt::print(stderr, "Error : one and only one input format must be specified.\n{}\n", usage());
    exit(1);
  }

  return vm;
}

int MacaonDecode::main()
{
  auto od = getOptionsDescription();
  auto variables = checkOptions(od);

  std::filesystem::path modelPath(variables["model"].as<std::string>());
  auto machinePath = modelPath / ReadingMachine::defaultMachineFilename;
  auto modelPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultModelFilename, ""));
  auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : "";
  auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : "";
  auto mcd = variables["mcd"].as<std::string>();
  bool debug = variables.count("debug") == 0 ? false : true;
  bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
  auto beamSize = variables["beamSize"].as<int>();
  auto beamThreshold = variables["beamThreshold"].as<float>();

  torch::globalContext().setBenchmarkCuDNN(true);

  if (modelPaths.empty())
    util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, "")));

  fmt::print(stderr, "Decoding using device : {}\n", NeuralNetworkImpl::device.str());

  try
  {
    ReadingMachine machine(machinePath, modelPaths);
    Decoder decoder(machine);

    BaseConfig config(mcd, inputTSV, inputTXT);

    decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement);

    config.print(stdout);
  } catch(std::exception & e) {util::error(e);}

  return 0;
}

MacaonDecode::MacaonDecode(int argc, char ** argv) : argc(argc), argv(argv)
{
}