// macaon_decode.cpp
#include <boost/program_options.hpp>
#include <filesystem>
#include "util.hpp"
#include "Decoder.hpp"

namespace po = boost::program_options;

// Builds the description of every command-line argument this program accepts.
//
// Two groups are declared and then merged, in this order, into the returned
// description:
//   - "Required": model, inputTSV, inputTXT and mcd. Only model and mcd are
//     flagged as required here; inputTSV and inputTXT are plain string options.
//   - "Optional": help (short form h).
po::options_description getOptionsDescription()
{
  po::options_description requiredGroup("Required");
  auto addRequired = requiredGroup.add_options();
  addRequired("model", po::value<std::string>()->required(),
    "Directory containing the trained machine used to decode");
  addRequired("inputTSV", po::value<std::string>(),
    "File containing the text to decode, TSV file");
  addRequired("inputTXT", po::value<std::string>(),
    "File containing the text to decode, raw text file");
  addRequired("mcd", po::value<std::string>()->required(),
    "Multi Column Description file that describes the input/output format");

  po::options_description optionalGroup("Optional");
  auto addOptional = optionalGroup.add_options();
  addOptional("help,h", "Produce this help message");

  // The groups are appended in the order they should appear in the help text.
  po::options_description allOptions("Command-Line Arguments ");
  allOptions.add(requiredGroup);
  allOptions.add(optionalGroup);

  return allOptions;
}

// Parses the command line against `od` and returns the resulting variables map.
//
// Steps, in order:
//   1. Parse argv and store the result. The message of any std::exception
//      raised here is handed to util::myThrow.
//   2. If "help" was given, print `od` to stderr and exit with status 0. This
//      is done before po::notify, which is the step that enforces options
//      marked as required.
//   3. Run po::notify. The message of any std::exception raised here is
//      likewise handed to util::myThrow.
//   4. Require that exactly one of inputTSV / inputTXT was given; otherwise
//      print an error followed by `od` to stderr and exit with status 1.
//
// @param od    Options description used for parsing and for the usage text.
// @param argc  Argument count, as received by main.
// @param argv  Argument vector, as received by main.
// @return      The populated variables map.
po::variables_map checkOptions(po::options_description & od, int argc, char ** argv)
{
  // Renders the options description the same way for the help and error paths.
  const auto usage = [&od]()
  {
    std::stringstream rendered;
    rendered << od;
    return rendered.str();
  };

  po::variables_map parsed;

  try
  {
    const auto commandLine = po::parse_command_line(argc, argv, od);
    po::store(commandLine, parsed);
  }
  catch (std::exception & error)
  {
    util::myThrow(error.what());
  }

  if (parsed.count("help") != 0)
  {
    fmt::print(stderr, "{}\n", usage());
    exit(0);
  }

  try
  {
    po::notify(parsed);
  }
  catch (std::exception & error)
  {
    util::myThrow(error.what());
  }

  const auto inputFormatCount = parsed.count("inputTSV") + parsed.count("inputTXT");
  if (inputFormatCount != 1)
  {
    fmt::print(stderr, "Error : one and only one input format must be specified.\n{}\n", usage());
    exit(1);
  }

  return parsed;
}

int main(int argc, char * argv[])
{
  auto od = getOptionsDescription();
  auto variables = checkOptions(od, argc, argv);

  std::filesystem::path modelPath(variables["model"].as<std::string>());
  auto machinePath = modelPath / ReadingMachine::defaultMachineFilename;
  auto dictPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultDictFilename, ""));
  auto modelPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultModelFilename, ""));
  auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : "";
  auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : "";
  auto mcdFile = variables["mcd"].as<std::string>();

  if (dictPaths.empty())
    util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultDictFilename, "")));
  if (modelPaths.empty())
    util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, "")));

  try
  {
    ReadingMachine machine(machinePath, modelPaths, dictPaths);
    Decoder decoder(machine);

    BaseConfig config(mcdFile, inputTSV, inputTXT);

    decoder.decode(config, 1);

    config.print(stdout);
  } catch(std::exception & e) {util::error(e);}

  return 0;
}