    // macaon_train.cpp
    #include <boost/program_options.hpp>
    #include <filesystem>
    #include "util.hpp"
    #include "Trainer.hpp"
    #include "Decoder.hpp"
    
    #include "NeuralNetwork.hpp"
    
    
    namespace po = boost::program_options;
    
    /// Build the description of every command-line option accepted by the
    /// training executable.
    ///
    /// Two groups are declared and merged into a single description:
    ///  - "Required": model, mcd, trainTSV (all std::string, marked required()).
    ///  - "Optional": the debug/silent/devScore/help flags, the trainTXT, devTSV
    ///    and devTXT paths (default to the empty string) and nbEpochs (int,
    ///    defaults to 5).
    ///
    /// @return The merged options description, required group first.
    po::options_description getOptionsDescription()
    {
      po::options_description desc("Command-Line Arguments ");

      po::options_description req("Required");
      req.add_options()
        ("model", po::value<std::string>()->required(),
          "Directory containing the machine file to train")
        ("mcd", po::value<std::string>()->required(),
          "Multi Column Description file that describes the input format")
        ("trainTSV", po::value<std::string>()->required(),
          "TSV file of the training corpus, in CONLLU format");

      po::options_description opt("Optional");
      opt.add_options()
        // Value-less switches: only their presence on the command line matters.
        ("debug,d", "Print debuging infos on stderr")
        ("silent", "Don't print speed and progress")
        ("devScore", "Compute score on dev instead of loss (slower)")
        // Optional inputs; an empty string is the default value.
        ("trainTXT", po::value<std::string>()->default_value(""),
          "Raw text file of the training corpus")
        ("devTSV", po::value<std::string>()->default_value(""),
          "TSV file of the development corpus, in CONLLU format")
        ("devTXT", po::value<std::string>()->default_value(""),
          "Raw text file of the development corpus")
        ("nbEpochs,n", po::value<int>()->default_value(5),
          "Number of training epochs")
        ("help,h", "Produce this help message");

      desc.add(req).add(opt);

      return desc;
    }
    
    /// Parse the command line against `od` and return the resulting option map.
    ///
    /// Any std::exception raised while parsing/storing or while notifying is
    /// forwarded, as its what() message, to util::myThrow.
    /// When "help" is present, the description is printed on stderr and the
    /// process exits with status 0 before notify() runs, so asking for help
    /// does not require the mandatory options to be supplied.
    ///
    /// @param od   Description of the accepted options.
    /// @param argc Argument count, as received by main.
    /// @param argv Argument vector, as received by main.
    /// @return The populated variables map, after notify().
    po::variables_map checkOptions(po::options_description & od, int argc, char ** argv)
    {
      po::variables_map parsed;

      try
      {
        const auto rawOptions = po::parse_command_line(argc, argv, od);
        po::store(rawOptions, parsed);
      }
      catch (const std::exception & err)
      {
        util::myThrow(err.what());
      }

      if (parsed.count("help") != 0)
      {
        // The description is streamed into a string first, then handed to fmt.
        std::ostringstream usage;
        usage << od;
        fmt::print(stderr, "{}\n", usage.str());
        exit(0);
      }

      try
      {
        po::notify(parsed);
      }
      catch (const std::exception & err)
      {
        util::myThrow(err.what());
      }

      return parsed;
    }
    
    int main(int argc, char * argv[])
    {
      auto od = getOptionsDescription();
      auto variables = checkOptions(od, argc, argv);
    
      std::filesystem::path modelPath(variables["model"].as<std::string>());
      auto machinePath = modelPath / "machine.rm";
      auto mcdFile = variables["mcd"].as<std::string>();
      auto trainTsvFile = variables["trainTSV"].as<std::string>();
      auto trainRawFile = variables["trainTXT"].as<std::string>();
      auto devTsvFile = variables["devTSV"].as<std::string>();
      auto devRawFile = variables["devTXT"].as<std::string>();
      auto nbEpoch = variables["nbEpochs"].as<int>();
    
    Franck Dary's avatar
    Franck Dary committed
      bool debug = variables.count("debug") == 0 ? false : true;
    
    Franck Dary's avatar
    Franck Dary committed
      bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
    
      bool computeDevScore = variables.count("devScore") == 0 ? false : true;
    
      fmt::print(stderr, "Training using device : {}\n", NeuralNetworkImpl::device.str());
    
    
      ReadingMachine machine(machinePath.string());
    
      BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
    
      BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile);
    
      SubConfig config(goldConfig);
    
      Trainer trainer(machine);
    
    Franck Dary's avatar
    Franck Dary committed
      trainer.createDataset(config, debug);
    
      if (!computeDevScore)
      {
        SubConfig devConfig(devGoldConfig);
        trainer.createDevDataset(devConfig, debug);
      }
    
    
      Decoder decoder(machine);
    
    
      float bestDevScore = computeDevScore ? 0 : 100;
    
      for (int i = 0; i < nbEpoch; i++)
      {
    
    Franck Dary's avatar
    Franck Dary committed
        float loss = trainer.epoch(printAdvancement);
    
        machine.getStrategy().reset();
    
    Franck Dary's avatar
    Franck Dary committed
        if (debug)
          fmt::print(stderr, "Decoding dev :\n");
    
        std::vector<std::pair<float,std::string>> devScores;
        if (computeDevScore)
        {
          auto devConfig = devGoldConfig;
          decoder.decode(devConfig, 1, debug, printAdvancement);
          machine.getStrategy().reset();
          decoder.evaluate(devConfig, modelPath, devTsvFile);
          devScores = decoder.getF1Scores(machine.getPredicted());
        }
        else
        {
          float devLoss = trainer.evalOnDev(printAdvancement);
          devScores.emplace_back(std::make_pair(devLoss, "Loss"));
        }
    
    
        std::string devScoresStr = "";
        float devScoreMean = 0;
        for (auto & score : devScores)
        {
    
          if (computeDevScore)
            devScoresStr += fmt::format("{}({:5.2f}{}),", score.second, score.first, computeDevScore ? "%" : "");
          else
            devScoresStr += fmt::format("{}({:6.1f}{}),", score.second, score.first, computeDevScore ? "%" : "");
    
          devScoreMean += score.first;
    
        }
        if (!devScoresStr.empty())
          devScoresStr.pop_back();
        devScoreMean /= devScores.size();
        bool saved = devScoreMean > bestDevScore;
    
        if (!computeDevScore)
          saved = devScoreMean < bestDevScore;
    
    Franck Dary's avatar
    Franck Dary committed
        if (saved)
        {
    
    Franck Dary's avatar
    Franck Dary committed
          machine.save();
        }
    
    Franck Dary's avatar
    Franck Dary committed
        if (debug)
    
          fmt::print(stderr, "Epoch {:^5} loss = {:6.1f} dev = {} {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
    
    Franck Dary's avatar
    Franck Dary committed
        else
    
          fmt::print(stderr, "\r{:80}\rEpoch {:^5} loss = {:6.1f} dev = {} {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
    
    Franck Dary's avatar
    Franck Dary committed
      }
      catch(std::exception & e) {util::error(e);}