Skip to content
Snippets Groups Projects
macaon_train.cpp 2.91 KiB
Newer Older
#include <filesystem>
#include <optional>
#include <boost/program_options.hpp>
#include "util.hpp"
#include "Trainer.hpp"
#include "Decoder.hpp"

namespace po = boost::program_options;

/**
 * Build the description of every command-line option accepted by the
 * training program.
 *
 * Options are split into two captioned groups for the help text:
 *  - "Required": model directory, MCD file and training TSV, each marked
 *    with required() so that a missing one is reported by po::notify;
 *  - "Optional": extra corpora and the epoch count, each with a default
 *    value, plus the help switch.
 *
 * @return The combined options description (required group first).
 */
po::options_description getOptionsDescription()
{
  po::options_description requiredArgs("Required");
  auto addRequired = requiredArgs.add_options();
  addRequired("model", po::value<std::string>()->required(),
    "Directory containing the machine file to train");
  addRequired("mcd", po::value<std::string>()->required(),
    "Multi Column Description file that describes the input format");
  addRequired("trainTSV", po::value<std::string>()->required(),
    "TSV file of the training corpus, in CONLLU format");

  po::options_description optionalArgs("Optional");
  auto addOptional = optionalArgs.add_options();
  addOptional("trainTXT", po::value<std::string>()->default_value(""),
    "Raw text file of the training corpus");
  addOptional("devTSV", po::value<std::string>()->default_value(""),
    "TSV file of the development corpus, in CONLLU format");
  addOptional("devTXT", po::value<std::string>()->default_value(""),
    "Raw text file of the development corpus");
  addOptional("nbEpochs,n", po::value<int>()->default_value(5),
    "Number of training epochs");
  addOptional("help,h", "Produce this help message");

  // The groups are appended in display order: required first, then optional.
  po::options_description allArgs("Command-Line Arguments ");
  allArgs.add(requiredArgs);
  allArgs.add(optionalArgs);

  return allArgs;
}

/**
 * Parse the command line against @p od and validate the result.
 *
 * Errors raised while parsing/storing or while notifying the parsed values
 * are forwarded to util::myThrow with their original message. When the
 * "help" option is present, the options description is printed to stderr
 * and the process exits with status 0.
 *
 * @param od   Description of the accepted options.
 * @param argc Argument count, as received by main.
 * @param argv Argument vector, as received by main.
 * @return The parsed and notified variables map.
 */
po::variables_map checkOptions(po::options_description & od, int argc, char ** argv)
{
  po::variables_map parsed;

  try
  {
    auto cmdLine = po::parse_command_line(argc, argv, od);
    po::store(cmdLine, parsed);
  }
  catch (const std::exception & error)
  {
    util::myThrow(error.what());
  }

  // Help is handled before notify() so that asking for help does not fail
  // on missing required options.
  const bool helpRequested = parsed.count("help") != 0;
  if (helpRequested)
  {
    std::stringstream helpText;
    helpText << od;
    fmt::print(stderr, "{}\n", helpText.str());
    exit(0);
  }

  try
  {
    po::notify(parsed);
  }
  catch (const std::exception & error)
  {
    util::myThrow(error.what());
  }

  return parsed;
}

int main(int argc, char * argv[])
{
  auto od = getOptionsDescription();
  auto variables = checkOptions(od, argc, argv);

  std::filesystem::path modelPath(variables["model"].as<std::string>());
  auto machinePath = modelPath / "machine.rm";
  auto mcdFile = variables["mcd"].as<std::string>();
  auto trainTsvFile = variables["trainTSV"].as<std::string>();
  auto trainRawFile = variables["trainTXT"].as<std::string>();
  auto devTsvFile = variables["devTSV"].as<std::string>();
  auto devRawFile = variables["devTXT"].as<std::string>();
  auto nbEpoch = variables["nbEpochs"].as<int>();

  ReadingMachine machine(machinePath.string());

  BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
  SubConfig config(goldConfig);

  Trainer trainer(machine);
  trainer.createDataset(config);

  Decoder decoder(machine);
  BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile);

  for (int i = 0; i < nbEpoch; i++)
  {
    float loss = trainer.epoch();
    auto devConfig = devGoldConfig;
    decoder.decode(devConfig, 1);
    decoder.evaluate(devConfig, modelPath, devTsvFile);
    fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {}%\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, decoder.getF1Score("UPOS"));
  }

  return 0;
}