// Trainer.cpp (author: Franck Dary)
#include "Trainer.hpp"
#include "SubConfig.hpp"

/// Creates a trainer for `machine`.
///
/// @param machine   Reading machine whose classifier is trained and used for example extraction.
/// @param batchSize Mini-batch size given to the train/dev data loaders.
Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize)
{
}

Franck Dary's avatar
Franck Dary committed
void Trainer::makeDataLoader(std::filesystem::path dir)
Franck Dary's avatar
Franck Dary committed
{
  trainDataset.reset(new Dataset(dir));
  dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
Franck Dary's avatar
Franck Dary committed
void Trainer::makeDevDataLoader(std::filesystem::path dir)
{
  devDataset.reset(new Dataset(dir));
  devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}

/// Extracts training examples from `goldConfig` into `dir`.
///
/// Wraps the gold configuration in a SubConfig spanning all of its lines, switches the machine out of
/// train mode, and delegates to extractExamples.
void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle)
{
  SubConfig config(goldConfig, goldConfig.getNbLines());

  machine.trainMode(false);

  extractExamples(config, debug, dir, epoch, dynamicOracle);
}
/// Walks `config` with the oracle and dumps (context, gold class) examples to tensor files in `dir`.
///
/// Examples are buffered per strategy state and flushed to files named
/// `<state>_<first>-<last>.<epoch>.<dynamicOracle>.tensor` once a state has accumulated at least
/// `maxNbExamplesPerFile` new examples; remaining buffers are flushed at the end. A marker file
/// `extracted.<epoch>.<dynamicOracle>` is then created, and its presence makes later calls return
/// immediately.
///
/// With `dynamicOracle`, the transition actually applied is the network's best-scoring appliable
/// transition (except in the "tokenizer" and "segmenter" states). The recorded label is always the
/// oracle's gold transition.
///
/// Throws (via util::myThrow) when context extraction fails, when no transition is appliable, when more
/// than `safetyNbExamplesMax` examples would be extracted, or when the marker file cannot be created.
void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle)
{
  // Only forward passes happen here (dynamic oracle); building an autograd graph would waste memory.
  torch::AutoGradMode useGrad(false);

  int maxNbExamplesPerFile = 50000;
  std::map<std::string, Examples> examplesPerState;
  int totalNbExamples = 0;

  config.addPredicted(machine.getPredicted());
  config.setStrategy(machine.getStrategyDefinition());
  config.setState(config.getStrategy().getInitialState());
  machine.getClassifier()->setState(config.getState());

  auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}.{}", epoch, dynamicOracle);

  // Examples for this (epoch, dynamicOracle) pair were already fully extracted.
  if (std::filesystem::exists(currentEpochAllExtractedFile))
    return;

  fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");

  while (true)
  {
    if (debug)
      config.printForDebug(stderr);

    if (machine.hasSplitWordTransitionSet())
      config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));

    auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
    config.setAppliableTransitions(appliableTransitions);

    std::vector<std::vector<long>> context;
    try
    {
      context = machine.getClassifier()->getNN()->extractContext(config);
    } catch(std::exception & e)
    {
      util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
    }

    Transition * goldTransition = machine.getTransitionSet().getBestAppliableTransition(config, appliableTransitions, dynamicOracle);
    Transition * transition = nullptr;

    // context[0] is fed to the network below, so an empty context falls back to the gold transition.
    if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter" and !context.empty())
    {
      auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
      // Bring the scores to the CPU once instead of synchronizing with the device for every element.
      auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze().to(torch::kCPU);

      // Pick the highest-scoring transition that is appliable in the current configuration.
      int chosenTransition = -1;
      float bestScore = std::numeric_limits<float>::lowest();
      int nbOutputs = prediction.size(0);

      for (int i = 0; i < nbOutputs; i++)
      {
        if (!machine.getTransitionSet().getTransition(i)->appliable(config))
          continue;

        float score = prediction[i].item<float>();
        if (chosenTransition == -1 or score > bestScore)
        {
          chosenTransition = i;
          bestScore = score;
        }
      }

      // When nothing is appliable, leave transition null so the check below reports it,
      // instead of looking up transition index -1.
      if (chosenTransition != -1)
        transition = machine.getTransitionSet().getTransition(chosenTransition);
    }
    else
    {
      transition = goldTransition;
    }

    if (!transition or !goldTransition)
    {
      config.printForDebug(stderr);
      util::myThrow("No transition appliable !");
    }

    // The label is always the oracle's choice, even when the network's choice is applied.
    int goldIndex = machine.getTransitionSet().getTransitionIndex(goldTransition);

    totalNbExamples += context.size();
    if (totalNbExamples >= (int)safetyNbExamplesMax)
      util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));

    auto & stateExamples = examplesPerState[config.getState()];
    stateExamples.addContext(context);
    stateExamples.addClass(goldIndex);
    stateExamples.saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);

    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = config.getStrategy().getMovement(config, transition->getName());
    if (debug)
      fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    machine.getClassifier()->setState(movement.first);
    config.moveWordIndexRelaxed(movement.second);

    if (config.needsUpdate())
      config.update();
  }

  // Flush every remaining buffer (threshold 0 means "save whatever is pending").
  for (auto & it : examplesPerState)
    it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle);

  // Empty marker file signalling that extraction for this epoch completed.
  std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
  if (!f)
    util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
  std::fclose(f);

  fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(totalNbExamples));
}

/// Runs the classifier over every batch of `loader`, training it when `train` is true.
///
/// @param loader           Data loader whose batches are (data, labels, state) tuples.
/// @param train            When true, gradients are enabled and the optimizer updates the parameters.
/// @param printAdvancement When true, prints progress/loss/speed to stderr roughly every 50 examples.
/// @param nbExamples       Total number of examples in the loader's dataset (used for progress and averaging).
/// @return The sum of the per-batch losses divided by `nbExamples`.
float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples)
{
  constexpr int printInterval = 50;
  int nbExamplesProcessed = 0;
  int totalNbExamplesProcessed = 0;

  float totalLoss = 0.0;
  float lossSoFar = 0.0;

  torch::AutoGradMode useGrad(train);
  machine.trainMode(train);
  machine.setDictsState(Dict::State::Closed);
  auto lossFct = torch::nn::CrossEntropyLoss();

  auto pastTime = std::chrono::high_resolution_clock::now();

  for (auto & batch : *loader)
  {
    if (train)
      machine.getClassifier()->getOptimizer().zero_grad();

    auto data = std::get<0>(batch);
    auto labels = std::get<1>(batch);
    auto state = std::get<2>(batch);

    machine.getClassifier()->setState(state);

    auto prediction = machine.getClassifier()->getNN()(data);
    // A batch of one may come back without its batch dimension.
    if (prediction.dim() == 1)
      prediction = prediction.unsqueeze(0);

    labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));

    auto loss = machine.getClassifier()->getLossMultiplier()*lossFct(prediction, labels);

    // Read the scalar once: every item() call forces a device synchronization.
    float lossValue = 0.0;
    try
    {
      lossValue = loss.item<float>();
    } catch(std::exception & e) {util::myThrow(e.what());}
    totalLoss += lossValue;
    lossSoFar += lossValue;

    if (train)
    {
      loss.backward();
      machine.getClassifier()->getOptimizer().step();
    }

    totalNbExamplesProcessed += torch::numel(labels);

    if (printAdvancement)
    {
      nbExamplesProcessed += torch::numel(labels);

      if (nbExamplesProcessed >= printInterval)
      {
        auto actualTime = std::chrono::high_resolution_clock::now();
        double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
        pastTime = actualTime;
        // Converting an infinite rate (zero elapsed time) to int is undefined behavior.
        int speed = seconds > 0.0 ? (int)(nbExamplesProcessed/seconds) : 0;
        auto progression = 100.0*totalNbExamplesProcessed / nbExamples;
        auto statusStr = fmt::format("{:6.2f}% loss={:<7.3f} speed={:<6}ex/s", progression, lossSoFar, speed);
        if (train)
          fmt::print(stderr, "\r{:80}\rtraining : {}", "", statusStr);
        else
          fmt::print(stderr, "\r{:80}\reval on dev : {}", "", statusStr);
        lossSoFar = 0;
        nbExamplesProcessed = 0;
      }
    }
  }

  return totalLoss / nbExamples;
}

/// Performs one training pass over the whole training set.
///
/// @return The value returned by processDataset for the training loader.
float Trainer::epoch(bool printAdvancement)
{
  int nbTrainExamples = trainDataset->size().value();
  return processDataset(dataLoader, true, printAdvancement, nbTrainExamples);
}

/// Evaluates the classifier on the dev set without updating its parameters.
///
/// @return The value returned by processDataset for the dev loader (with train == false).
float Trainer::evalOnDev(bool printAdvancement)
{
  return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
}
void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int epoch, bool dynamicOracle)
{
  if (currentExampleIndex-lastSavedIndex < (int)threshold)
    return;
  if (contexts.empty())
    return;

  auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
Franck Dary's avatar
Franck Dary committed
  auto filename = fmt::format("{}_{}-{}.{}.{}.tensor", state, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle);
  torch::save(tensorToSave, dir/filename);
  lastSavedIndex = currentExampleIndex;
  contexts.clear();
  classes.clear();
}

/// Buffers each context vector as its own 1-D long tensor and advances the example counter.
void Trainer::Examples::addContext(std::vector<std::vector<long>> & context)
{
  for (std::size_t i = 0; i < context.size(); ++i)
  {
    auto & values = context[i];
    // from_blob only views the vector's memory; clone() gives the tensor its own storage.
    auto view = torch::from_blob(values.data(), {(long)values.size()}, torch::TensorOptions(torch::kLong));
    contexts.emplace_back(view.clone());
  }

  currentExampleIndex += context.size();
}

/// Labels every buffered context that has no class yet with `goldIndex` (a 1-element long tensor).
void Trainer::Examples::addClass(int goldIndex)
{
  auto gold = torch::full({1}, goldIndex, torch::TensorOptions(torch::kLong));

  if (classes.size() < contexts.size())
    classes.insert(classes.end(), contexts.size() - classes.size(), gold);
}

Franck Dary's avatar
Franck Dary committed
void Trainer::fillDicts(SubConfig & config, bool debug)
{
  torch::AutoGradMode useGrad(false);

  config.addPredicted(machine.getPredicted());
  config.setStrategy(machine.getStrategyDefinition());
  config.setState(config.getStrategy().getInitialState());
  machine.getClassifier()->setState(config.getState());
Franck Dary's avatar
Franck Dary committed
    if (debug)
      config.printForDebug(stderr);

    if (machine.hasSplitWordTransitionSet())
      config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
Franck Dary's avatar
Franck Dary committed
    auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
    config.setAppliableTransitions(appliableTransitions);
      machine.getClassifier()->getNN()->extractContext(config);
    } catch(std::exception & e)
    {
      util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
    }

    Transition * goldTransition = nullptr;
    goldTransition = machine.getTransitionSet().getBestAppliableTransition(config, appliableTransitions);
      
    if (!goldTransition)
    {
      config.printForDebug(stderr);
      util::myThrow("No transition appliable !");
    }

    goldTransition->apply(config);
    config.addToHistory(goldTransition->getName());

    auto movement = config.getStrategy().getMovement(config, goldTransition->getName());
Franck Dary's avatar
Franck Dary committed
    if (debug)
      fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", goldTransition->getName(), movement.first, movement.second);
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    machine.getClassifier()->setState(movement.first);
    config.moveWordIndexRelaxed(movement.second);

    if (config.needsUpdate())
      config.update();
  }
}

Franck Dary's avatar
Franck Dary committed
Trainer::TrainAction Trainer::str2TrainAction(const std::string & s)
{
  if (s == "ExtractGold")
    return TrainAction::ExtractGold;
  if (s == "ExtractDynamic")
    return TrainAction::ExtractDynamic;
  if (s == "DeleteExamples")
    return TrainAction::DeleteExamples;
  if (s == "ResetOptimizer")
    return TrainAction::ResetOptimizer;
  if (s == "ResetParameters")
    return TrainAction::ResetParameters;
  if (s == "Save")
    return TrainAction::Save;

  util::myThrow(fmt::format("unknown TrainAction '{}'", s));

  return TrainAction::ExtractGold;
}