Skip to content
Snippets Groups Projects
macaon_train.cpp 5.09 KiB
Newer Older
#include <boost/program_options.hpp>
#include <filesystem>
#include "util.hpp"
#include "Trainer.hpp"
#include "Decoder.hpp"
#include "NeuralNetwork.hpp"

namespace po = boost::program_options;

/// @brief Build the command-line interface description of the trainer.
///
/// Options are split into two groups so that `--help` lists them separately:
///  - "Required": --model, --mcd and --trainTSV. They are marked required(),
///    so a missing one is reported when checkOptions() calls po::notify().
///  - "Optional": presence-only flags and valued settings that have a default.
///
/// @return The combined description, meant to be passed to checkOptions().
po::options_description getOptionsDescription()
{
  po::options_description desc("Command-Line Arguments ");

  po::options_description req("Required");
  req.add_options()
    ("model", po::value<std::string>()->required(),
      "Directory containing the machine file to train")
    ("mcd", po::value<std::string>()->required(),
      "Multi Column Description file that describes the input format")
    ("trainTSV", po::value<std::string>()->required(),
      "TSV file of the training corpus, in CONLLU format");

  po::options_description opt("Optional");
  opt.add_options()
    // Presence-only flags (they take no value); main() tests them with count().
    // Note: main() also disables the progress display when --debug is given.
    ("debug,d", "Print debugging infos on stderr")
    ("silent", "Don't print speed and progress")
    ("devScore", "Compute score on dev instead of loss (slower)")
    // Corpus paths default to the empty string when not given on the command line.
    ("trainTXT", po::value<std::string>()->default_value(""),
      "Raw text file of the training corpus")
    ("devTSV", po::value<std::string>()->default_value(""),
      "TSV file of the development corpus, in CONLLU format")
    ("devTXT", po::value<std::string>()->default_value(""),
      "Raw text file of the development corpus")
    ("nbEpochs,n", po::value<int>()->default_value(5),
      "Number of training epochs")
    ("help,h", "Produce this help message");

  desc.add(req).add(opt);

  return desc;
}

/// @brief Parse argv against @p od and validate the result.
///
/// Parsing runs in two stages on purpose:
///  1. Raw values are stored into the map.
///  2. `--help` is handled: @p od is printed to stderr and the process exits with status 0.
///  3. Only then does po::notify() enforce the required() options.
/// This order lets `--help` work even when required options are missing.
/// Any std::exception raised while parsing or validating is handed, by message,
/// to util::myThrow() (presumably throws; defined in util.hpp).
///
/// @param od   Option description, typically built by getOptionsDescription().
/// @param argc Argument count as received by main().
/// @param argv Argument vector as received by main().
/// @return The populated and validated variables map.
po::variables_map checkOptions(po::options_description & od, int argc, char ** argv)
{
  po::variables_map result;

  // Stage 1: parse and store, without validating required options yet.
  try
  {
    auto const parsed = po::parse_command_line(argc, argv, od);
    po::store(parsed, result);
  }
  catch (std::exception & err)
  {
    util::myThrow(err.what());
  }

  // Stage 2: --help short-circuits before any required-option check.
  bool const wantsHelp = result.count("help") > 0;
  if (wantsHelp)
  {
    std::stringstream usage;
    usage << od;
    fmt::print(stderr, "{}\n", usage.str());
    exit(0);
  }

  // Stage 3: run notifiers and enforce required().
  try
  {
    po::notify(result);
  }
  catch (std::exception & err)
  {
    util::myThrow(err.what());
  }

  return result;
}

int main(int argc, char * argv[])
{
  auto od = getOptionsDescription();
  auto variables = checkOptions(od, argc, argv);

  std::filesystem::path modelPath(variables["model"].as<std::string>());
  auto machinePath = modelPath / "machine.rm";
  auto mcdFile = variables["mcd"].as<std::string>();
  auto trainTsvFile = variables["trainTSV"].as<std::string>();
  auto trainRawFile = variables["trainTXT"].as<std::string>();
  auto devTsvFile = variables["devTSV"].as<std::string>();
  auto devRawFile = variables["devTXT"].as<std::string>();
  auto nbEpoch = variables["nbEpochs"].as<int>();
Franck Dary's avatar
Franck Dary committed
  bool debug = variables.count("debug") == 0 ? false : true;
Franck Dary's avatar
Franck Dary committed
  bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
  bool computeDevScore = variables.count("devScore") == 0 ? false : true;
  fmt::print(stderr, "Training using device : {}\n", NeuralNetworkImpl::device.str());

  ReadingMachine machine(machinePath.string());

  BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
  BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
  SubConfig config(goldConfig);

  Trainer trainer(machine);
Franck Dary's avatar
Franck Dary committed
  trainer.createDataset(config, debug);
  if (!computeDevScore)
  {
    SubConfig devConfig(devGoldConfig);
    trainer.createDevDataset(devConfig, debug);
  }

  Decoder decoder(machine);

  float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();
  for (int i = 0; i < nbEpoch; i++)
  {
Franck Dary's avatar
Franck Dary committed
    float loss = trainer.epoch(printAdvancement);
    machine.getStrategy().reset();
Franck Dary's avatar
Franck Dary committed
    if (debug)
      fmt::print(stderr, "Decoding dev :\n");
    std::vector<std::pair<float,std::string>> devScores;
    if (computeDevScore)
    {
      auto devConfig = devGoldConfig;
      decoder.decode(devConfig, 1, debug, printAdvancement);
      machine.getStrategy().reset();
      decoder.evaluate(devConfig, modelPath, devTsvFile);
      devScores = decoder.getF1Scores(machine.getPredicted());
    }
    else
    {
      float devLoss = trainer.evalOnDev(printAdvancement);
      devScores.emplace_back(std::make_pair(devLoss, "Loss"));
    }

    std::string devScoresStr = "";
    float devScoreMean = 0;
    for (auto & score : devScores)
    {
      if (computeDevScore)
        devScoresStr += fmt::format("{}({:5.2f}{}),", score.second, score.first, computeDevScore ? "%" : "");
      else
        devScoresStr += fmt::format("{}({:6.1f}{}),", score.second, score.first, computeDevScore ? "%" : "");
      devScoreMean += score.first;
    }
    if (!devScoresStr.empty())
      devScoresStr.pop_back();
    devScoreMean /= devScores.size();
    bool saved = devScoreMean > bestDevScore;
    if (!computeDevScore)
      saved = devScoreMean < bestDevScore;
Franck Dary's avatar
Franck Dary committed
    if (saved)
    {
Franck Dary's avatar
Franck Dary committed
      machine.save();
    }
Franck Dary's avatar
Franck Dary committed
    if (debug)
      fmt::print(stderr, "Epoch {:^5} loss = {:6.1f} dev = {} {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
Franck Dary's avatar
Franck Dary committed
    else
      fmt::print(stderr, "\r{:80}\rEpoch {:^5} loss = {:6.1f} dev = {} {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
Franck Dary's avatar
Franck Dary committed
  }
  catch(std::exception & e) {util::error(e);}