Skip to content
Snippets Groups Projects
macaon_train.cpp 3.92 KiB
Newer Older
  • Learn to ignore specific revisions
  • #include <boost/program_options.hpp>
    #include <filesystem>
    #include "util.hpp"
    #include "Trainer.hpp"
    #include "Decoder.hpp"
    
    namespace po = boost::program_options;
    
    /// Builds the description of every command-line option this program accepts.
    ///
    /// The options are split into a "Required" group and an "Optional" group,
    /// both nested under a single top-level description so that streaming the
    /// returned object prints them as two sections.
    ///
    /// @return The combined options description (required group first).
    po::options_description getOptionsDescription()
    {
      po::options_description desc("Command-Line Arguments ");

      // Options declared with required(): they carry no default value.
      po::options_description req("Required");
      req.add_options()
        ("model", po::value<std::string>()->required(),
          "Directory containing the machine file to train")
        ("mcd", po::value<std::string>()->required(),
          "Multi Column Description file that describes the input format")
        ("trainTSV", po::value<std::string>()->required(),
          "TSV file of the training corpus, in CONLLU format");

      // Options that are either plain flags or carry a default value.
      po::options_description opt("Optional");
      opt.add_options()
        // Flag only: main() tests for its presence with count("debug").
        ("debug,d", "Print debuging infos on stderr")
        ("trainTXT", po::value<std::string>()->default_value(""),
          "Raw text file of the training corpus")
        ("devTSV", po::value<std::string>()->default_value(""),
          "TSV file of the development corpus, in CONLLU format")
        ("devTXT", po::value<std::string>()->default_value(""),
          "Raw text file of the development corpus")
        ("nbEpochs,n", po::value<int>()->default_value(5),
          "Number of training epochs")
        // Flag only: checkOptions() prints the description and exits when set.
        ("help,h", "Produce this help message");

      desc.add(req).add(opt);

      return desc;
    }
    
    /// Parses the command line against `od` and returns the resulting variables map.
    ///
    /// Any std::exception raised while parsing/storing, or later while notifying,
    /// is forwarded to util::myThrow with the exception's message.
    /// When the "help" option is present, the option description is printed on
    /// stderr and the process exits with status 0 before po::notify is called.
    ///
    /// @param od   Description of the accepted options.
    /// @param argc Argument count, as received by main.
    /// @param argv Argument vector, as received by main.
    /// @return The populated variables map.
    po::variables_map checkOptions(po::options_description & od, int argc, char ** argv)
    {
      po::variables_map parsed;

      try
      {
        const auto commandLine = po::parse_command_line(argc, argv, od);
        po::store(commandLine, parsed);
      }
      catch (const std::exception & error)
      {
        util::myThrow(error.what());
      }

      // The help request is handled before notify() is reached.
      if (parsed.count("help") != 0)
      {
        std::stringstream usage;
        usage << od;
        fmt::print(stderr, "{}\n", usage.str());
        exit(0);
      }

      try
      {
        po::notify(parsed);
      }
      catch (const std::exception & error)
      {
        util::myThrow(error.what());
      }

      return parsed;
    }
    
    int main(int argc, char * argv[])
    {
      auto od = getOptionsDescription();
      auto variables = checkOptions(od, argc, argv);
    
      std::filesystem::path modelPath(variables["model"].as<std::string>());
      auto machinePath = modelPath / "machine.rm";
      auto mcdFile = variables["mcd"].as<std::string>();
      auto trainTsvFile = variables["trainTSV"].as<std::string>();
      auto trainRawFile = variables["trainTXT"].as<std::string>();
      auto devTsvFile = variables["devTSV"].as<std::string>();
      auto devRawFile = variables["devTXT"].as<std::string>();
      auto nbEpoch = variables["nbEpochs"].as<int>();
    
    Franck Dary's avatar
    Franck Dary committed
      bool debug = variables.count("debug") == 0 ? false : true;
    
      ReadingMachine machine(machinePath.string());
    
      BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
      SubConfig config(goldConfig);
    
      Trainer trainer(machine);
    
    Franck Dary's avatar
    Franck Dary committed
      trainer.createDataset(config, debug);
    
    
      Decoder decoder(machine);
      BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile);
    
    
    Franck Dary's avatar
    Franck Dary committed
      float bestDevScore = 0;
    
    
      for (int i = 0; i < nbEpoch; i++)
      {
    
    Franck Dary's avatar
    Franck Dary committed
        float loss = trainer.epoch(!debug);
    
        auto devConfig = devGoldConfig;
    
    Franck Dary's avatar
    Franck Dary committed
        if (debug)
          fmt::print(stderr, "Decoding dev :\n");
        else
          fmt::print(stderr, "\r{:80}\rDecoding dev...", " ");
        decoder.decode(devConfig, 1, debug);
    
        decoder.evaluate(devConfig, modelPath, devTsvFile);
    
        std::vector<float> devScores = decoder.getF1Scores(machine.getPredicted());
        std::string devScoresStr = "";
        float devScoreMean = 0;
        for (auto & score : devScores)
        {
          devScoresStr += fmt::format("{:5.2f}%,", score);
          devScoreMean += score;
        }
        if (!devScoresStr.empty())
          devScoresStr.pop_back();
        devScoreMean /= devScores.size();
        bool saved = devScoreMean > bestDevScore;
    
    Franck Dary's avatar
    Franck Dary committed
        if (saved)
        {
    
    Franck Dary's avatar
    Franck Dary committed
          machine.save();
        }
    
    Franck Dary's avatar
    Franck Dary committed
        if (debug)
    
          fmt::print(stderr, "Epoch {:^9} loss = {:7.2f} dev = {} {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
    
    Franck Dary's avatar
    Franck Dary committed
        else
    
          fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {} {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
    
    Franck Dary's avatar
    Franck Dary committed
      }
      catch(std::exception & e) {util::error(e);}