Skip to content
Snippets Groups Projects
Select Git revision
  • master default protected
  • loss
  • producer
3 results

MacaonTrain.cpp

Blame
  • MacaonTrain.cpp 11.40 KiB
    #include "MacaonTrain.hpp"
    #include <filesystem>
    #include "util.hpp"
    #include "NeuralNetwork.hpp"
    #include "WordEmbeddings.hpp"
    
    namespace po = boost::program_options;
    
    // Build the description of every command-line argument accepted by
    // macaon_train, split into a "Required" and an "Optional" group.
    // Fixes: the --seed help text was a copy-paste of --batchSize's, and
    // --debug's text misspelled "debugging".
    po::options_description MacaonTrain::getOptionsDescription()
    {
      po::options_description desc("Command-Line Arguments ");
    
      // Arguments that must be present on the command line.
      po::options_description req("Required");
      req.add_options()
        ("model", po::value<std::string>()->required(),
          "Directory containing the machine file to train")
        ("trainTSV", po::value<std::string>()->required(),
          "TSV file of the training corpus, in CONLLU format");
    
      // Arguments with a default value.
      po::options_description opt("Optional");
      opt.add_options()
        ("debug,d", "Print debugging infos on stderr")
        ("silent", "Don't print speed and progress")
        ("devScore", "Compute score on dev instead of loss (slower)")
        ("mcd", po::value<std::string>()->default_value("ID,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL"),
          "Comma separated column names that describes the input/output format")
        ("trainTXT", po::value<std::string>()->default_value(""),
          "Raw text file of the training corpus")
        ("devTSV", po::value<std::string>()->default_value(""),
          "TSV file of the development corpus, in CONLLU format")
        ("devTXT", po::value<std::string>()->default_value(""),
          "Raw text file of the development corpus")
        ("nbEpochs,n", po::value<int>()->default_value(5),
          "Number of training epochs")
        ("batchSize", po::value<int>()->default_value(64),
          "Number of examples per batch")
        ("explorationThreshold", po::value<float>()->default_value(0.1),
          "Maximum probability difference with the best scoring transition, for a transition to be explored during dynamic extraction of dataset")
        ("machine", po::value<std::string>()->default_value(""),
          "Reading machine file content")
        ("trainStrategy", po::value<std::string>()->default_value("0,ExtractGold,ResetParameters"),
          "Description of what should happen during training")
        ("seed", po::value<int>()->default_value(100),
          "Seed of the pseudo-random number generators")
        ("scaleGrad", "Scale embedding's gradient with its frequence in the minibatch")
        ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()),
          "Max norm for the embeddings")
        ("lockPretrained", "Disable fine tuning of all pretrained word embeddings.")
        ("help,h", "Produce this help message")
        ("oracleMode", "Don't train a model, transform the corpus into a sequence of transitions.");
    
      desc.add(req).add(opt);
    
      return desc;
    }
    
    // Parse argc/argv against the given description, handle --help, and
    // run the registered validation (required options etc.).
    // Any boost::program_options error is rethrown through util::myThrow.
    po::variables_map MacaonTrain::checkOptions(po::options_description & od)
    {
      po::variables_map parsedOptions;
    
      try
      {
        po::store(po::parse_command_line(argc, argv, od), parsedOptions);
      }
      catch (std::exception & e)
      {
        util::myThrow(e.what());
      }
    
      // --help short-circuits everything else : print the usage and leave.
      if (parsedOptions.count("help"))
      {
        std::stringstream ss;
        ss << od;
        fmt::print(stderr, "{}\n", ss.str());
        exit(0);
      }
    
      // notify() triggers the required-option checks registered on od.
      try
      {
        po::notify(parsedOptions);
      }
      catch (std::exception & e)
      {
        util::myThrow(e.what());
      }
    
      return parsedOptions;
    }
    
    // Parse a training strategy description of the form
    // "epoch,Action[,Action...]" with groups separated by ':'
    // (e.g. "0,ExtractGold,ResetParameters:3,ExtractDynamic").
    // Throws (via util::myThrow) on any parse error.
    Trainer::TrainStrategy MacaonTrain::parseTrainStrategy(std::string s)
    {
      Trainer::TrainStrategy ts;
    
      try
      {
        auto splited = util::split(s, ':');
        for (auto & ss : splited)
        {
          auto elements = util::split(ss, ',');
    
          // Guard against an empty group (e.g. a stray "::" in the input) :
          // indexing elements[0] on an empty vector would be undefined
          // behavior that the catch below could not intercept.
          if (elements.empty())
            continue;
    
          int epoch = std::stoi(elements[0]);
    
          // Every remaining field names one action to run at this epoch.
          for (unsigned int i = 1; i < elements.size(); i++)
            ts[epoch].insert(Trainer::str2TrainAction(elements[i]));
        }
      } catch (std::exception & e) {util::myThrow(fmt::format("caught '{}' parsing '{}'", e.what(), s));}
    
      return ts;
    }
    
    // Multiply the learning rate of every parameter group that received
    // gradients by (1.0 - rate).
    // Bug fix : the previous version applied the multiplication once per
    // parameter holding a defined gradient, so a group with N such
    // parameters saw its learning rate shrunk by (1-rate)^N per call
    // instead of a single decay step.
    template 
    <
      typename Optimizer = torch::optim::Adam,
      typename OptimizerOptions = torch::optim::AdamOptions
    >
    inline auto decay(Optimizer &optimizer, double rate) -> void
    {
      for (auto &group : optimizer.param_groups())
      {
        // Skip groups in which no parameter has a gradient, mirroring the
        // original per-parameter grad().defined() filter.
        bool hasGrad = false;
        for (auto &param : group.params())
          if (param.grad().defined())
          {
            hasGrad = true;
            break;
          }
        if (!hasGrad)
          continue;
    
        auto &options = static_cast<OptimizerOptions &>(group.options());
        options.lr(options.lr() * (1.0 - rate));
      }
    }
    
    // Entry point of macaon_train : parses the command line, loads (or
    // creates) the reading machine, then runs the training loop driven by
    // the --trainStrategy description, saving the best/last models and
    // appending one line per epoch to <model>/train.info.
    // Fixes : the std::fopen results for train.info (read at resume time,
    // append at end of epoch) were used without a null check — undefined
    // behavior if the open fails.
    int MacaonTrain::main()
    {
      auto od = getOptionsDescription();
      auto variables = checkOptions(od);
    
      // Collect every command-line value up front.
      std::filesystem::path modelPath(variables["model"].as<std::string>());
      auto machinePath = modelPath / "machine.rm";
      auto mcd = variables["mcd"].as<std::string>();
      auto trainTsvFile = variables["trainTSV"].as<std::string>();
      auto trainRawFile = variables["trainTXT"].as<std::string>();
      auto devTsvFile = variables["devTSV"].as<std::string>();
      auto devRawFile = variables["devTXT"].as<std::string>();
      auto nbEpoch = variables["nbEpochs"].as<int>();
      auto batchSize = variables["batchSize"].as<int>();
      bool debug = variables.count("debug") == 0 ? false : true;
      bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
      bool computeDevScore = variables.count("devScore") == 0 ? false : true;
      auto machineContent = variables["machine"].as<std::string>();
      auto trainStrategyStr = variables["trainStrategy"].as<std::string>();
      auto explorationThreshold = variables["explorationThreshold"].as<float>();
      auto seed = variables["seed"].as<int>();
      auto oracleMode = variables.count("oracleMode") == 0 ? false : true;
      WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>());
      WordEmbeddingsImpl::setScaleGradByFreq(variables.count("scaleGrad") != 0);
      WordEmbeddingsImpl::setCanTrainPretrained(variables.count("lockPretrained") == 0);
    
      // Seed both the C library and torch RNGs for reproducibility.
      std::srand(seed);
      torch::manual_seed(seed);
    
      auto trainStrategy = parseTrainStrategy(trainStrategyStr);
    
      torch::globalContext().setBenchmarkCuDNN(true);
    
      // If the machine description was given inline, write it to disk so
      // that ReadingMachine can load it like a regular file.
      if (!machineContent.empty())
      {
        std::FILE * file = fopen(machinePath.c_str(), "w");
        if (!file)
          util::error(fmt::format("can't open file '{}'\n", machinePath.c_str()));
        fmt::print(file, "{}", machineContent);
        std::fclose(file);
      }
    
      fmt::print(stderr, "[{}] Training using device : {}\n", util::getTime(), NeuralNetworkImpl::device.str());
    
      try
      {
    
      ReadingMachine machine(machinePath.string(), true);
    
      // Load the optional raw-text inputs (first line of each file).
      util::utf8string trainRawInput;
      if (!trainRawFile.empty())
      {
        auto input = util::readFileAsUtf8(trainRawFile, false);
        trainRawInput = input[0];
      }
      BaseConfig goldConfig(mcd, trainTsvFile, trainRawInput);
      util::utf8string devRawInput;
      if (!devRawFile.empty())
      {
        auto input = util::readFileAsUtf8(devRawFile, false);
        devRawInput = input[0];
      }
      BaseConfig devGoldConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInput);
    
      Trainer trainer(machine, batchSize);
      Decoder decoder(machine);
    
      // --oracleMode : dump the gold transition sequence and stop.
      if (oracleMode)
      {
        trainer.extractActionSequence(goldConfig);
        exit(0);
      }
    
      // Higher is better for a dev score, lower is better for a dev loss.
      float bestDevScore = computeDevScore ? -std::numeric_limits<float>::max() : std::numeric_limits<float>::max();
    
      auto trainInfos = machinePath.parent_path() / "train.info";
    
      int currentEpoch = 0;
    
      // Resume support : replay train.info to recover the epoch counter and
      // the best dev score reached so far.
      if (std::filesystem::exists(trainInfos))
      {
        std::FILE * f = std::fopen(trainInfos.c_str(), "r");
        if (!f)
          util::error(fmt::format("can't open file '{}'\n", trainInfos.c_str()));
        char buffer[1024];
        while (!std::feof(f))
        {
          if (buffer != std::fgets(buffer, 1024, f))
            break;
          // A line ends its first tab-separated field with "SAVED" when the
          // model of that epoch was kept as the best one.
          bool saved = util::split(util::split(buffer, '\t')[0], ' ').back() == "SAVED";
          float devScoreMean = std::stof(util::split(buffer, '\t').back());
          if (saved)
            bestDevScore = devScoreMean;
          currentEpoch++;
        }
        std::fclose(f);
      }
    
      for (; currentEpoch < nbEpoch; currentEpoch++)
      {
        bool saved = false;
    
        // Strategy action : wipe previously extracted examples.
        if (trainStrategy[currentEpoch].count(Trainer::TrainAction::DeleteExamples))
        {
          for (auto & entry : std::filesystem::directory_iterator(modelPath/"examples/train"))
            if (entry.is_regular_file())
              std::filesystem::remove(entry.path());
    
          if (!computeDevScore)
            for (auto & entry : std::filesystem::directory_iterator(modelPath/"examples/dev"))
              if (entry.is_regular_file())
                std::filesystem::remove(entry.path());
        }
        // Strategy action : (re)extract the training examples, statically
        // from the gold or dynamically by following the current model.
        if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic))
        {
          machine.setDictsState(trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic) ? Dict::State::Closed : Dict::State::Open);
          trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold);
          if (!computeDevScore)
          {
            machine.setDictsState(Dict::State::Closed);
            trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold);
          }
        }
        // Strategy action : reinitialize the classifiers and/or optimizers.
        if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetOptimizer))
        {
          if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters))
          {
            machine.resetClassifiers();
            machine.trainMode(currentEpoch == 0);
            fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getNbParameters()));
          }
    
          machine.resetOptimizers();
        }
        // Strategy action : force a checkpoint at the end of this epoch.
        if (trainStrategy[currentEpoch].count(Trainer::TrainAction::Save))
        {
          saved = true;
        }
    
        trainer.makeDataLoader(modelPath/"examples/train");
        if (!computeDevScore)
          trainer.makeDevDataLoader(modelPath/"examples/dev");
    
        float loss = trainer.epoch(printAdvancement);
        if (debug)
          fmt::print(stderr, "Decoding dev :\n");
        // Dev evaluation : either decode and score against the gold dev
        // corpus, or simply compute the dev loss.
        std::vector<std::pair<float,std::string>> devScores;
        if (computeDevScore)
        {
          BaseConfig devConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInput);
          decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);
          decoder.evaluate(devConfig, modelPath, devTsvFile, machine.getPredicted());
          devScores = decoder.getF1Scores(machine.getPredicted());
        }
        else
        {
          float devLoss = trainer.evalOnDev(printAdvancement);
          devScores.emplace_back(std::make_pair(devLoss, "Loss"));
        }
    
        // Format the dev metrics into a fixed-width, comma separated string
        // and average them into a single comparable number.
        std::string devScoresStr = "";
        float devScoreMean = 0;
        int totalLen = 0;
        std::string toAdd;
        for (auto & score : devScores)
        {
          if (computeDevScore)
            toAdd = fmt::format("{}({}{}),", score.second, util::shrink(fmt::format("{:.2f}", std::abs(score.first)),7), score.first >= 0 ? "%" : "");
          else
            toAdd = fmt::format("{}({}),", score.second, util::shrink(fmt::format("{}", score.first),7));
          devScoreMean += score.first;
    
          devScoresStr += toAdd;
          totalLen += util::printedLength(score.second) + 3;
        }
        if (!devScoresStr.empty())
          devScoresStr.pop_back();
        devScoresStr = fmt::format("{:{}}", devScoresStr, totalLen+7*devScores.size());
        devScoreMean /= devScores.size();
    
        if (computeDevScore)
          saved = saved or devScoreMean >= bestDevScore;
        else
          saved = saved or devScoreMean <= bestDevScore;
    
        if (saved)
        {
          bestDevScore = devScoreMean;
          machine.saveBest();
        }
    
        machine.saveLast();
    
        // Log the epoch summary to stderr and append it to train.info.
        if (printAdvancement)
          fmt::print(stderr, "\r{:80}\r", "");
        std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:7} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), util::shrink(fmt::format("{}",loss), 7), devScoresStr, saved ? "SAVED" : "");
        fmt::print(stderr, "{}\n", iterStr);
        std::FILE * f = std::fopen(trainInfos.c_str(), "a");
        if (!f)
          util::error(fmt::format("can't open file '{}'\n", trainInfos.c_str()));
        fmt::print(f, "{}\t{}\n", iterStr, devScoreMean);
        std::fclose(f);
      }
    
      }
      catch(std::exception & e) {util::error(e);}
    
      return 0;
    }
    
    // Stores the raw command-line arguments; all parsing is deferred to main().
    MacaonTrain::MacaonTrain(int argc, char ** argv) : argc(argc), argv(argv)
    {
    }