// Trainer.cpp
#include "Trainer.hpp"
#include "SubConfig.hpp"

Trainer::Trainer(ReadingMachine & machine) : machine(machine)
{
}

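// Build the training dataset by replaying the machine on the given configuration:
// at each step the best applicable transition is taken as the gold action, the
// extracted feature context and the gold transition index are recorded, and once
// the strategy signals the end, the data loader and the optimizers are created.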
void Trainer::createDataset(SubConfig & config, bool debug)
{
  config.addPredicted(machine.getPredicted());
  config.setState(machine.getStrategy().getInitialState());

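  // Feature contexts and their gold transition indices, one pair per step.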
  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;

  while (true)
  {
    if (debug)
      config.printForDebug(stderr);

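    // Ask the transition set for the best transition applicable to the current configuration.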
    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
    if (!transition)
    {
      config.printForDebug(stderr);
      util::myThrow("No transition appliable !");
    }

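    // Extract the feature context around the current word, encoded with the dictionary
    // of the current state; from_blob does not take ownership, hence the clone().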
    auto context = config.extractContext(5,5,machine.getDict(config.getState()));
    contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());

    int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
    auto gold = torch::zeros(1, at::kLong);
    gold[0] = goldIndex;

    classes.emplace_back(gold);

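    // Apply the gold transition and remember it in the configuration's history.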
    transition->apply(config);
    config.addToHistory(transition->getName());

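    // Ask the strategy for the next state and word movement; stop at the end movement.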
    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    if (debug)
      fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
    if (movement == Strategy::endMovement)
      break;

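    // Switch to the new state and move the word index as requested by the strategy.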
    config.setState(movement.first);
    if (!config.moveWordIndex(movement.second))
    {
      config.printForDebug(stderr);
      util::myThrow(fmt::format("Cannot move word index by {}", movement.second));
    }

    if (config.needsUpdate())
      config.update();
  }

  nbExamples = classes.size();

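  // Stack the collected examples into a dataset and wrap it in a single-threaded data loader.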
  dataLoader = torch::data::make_data_loader(
      Dataset(contexts, classes).map(torch::data::transforms::Stack<>()),
      torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));

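  // Dense and sparse (embedding) parameters are optimized separately:
  // SparseAdam is designed for the sparse gradients produced by embedding lookups.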
  denseOptimizer.reset(new torch::optim::Adam(
      machine.getClassifier()->getNN()->denseParameters(),
      torch::optim::AdamOptions(2e-3).beta1(0.5)));
  sparseOptimizer.reset(new torch::optim::SparseAdam(
      machine.getClassifier()->getNN()->sparseParameters(),
      torch::optim::SparseAdamOptions(2e-3).beta1(0.5)));
}

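// Run one training epoch over the data loader and return the summed loss,
// optionally printing progress to stderr every printInterval examples.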
float Trainer::epoch(bool printAdvancement)
{
  constexpr int printInterval = 2000;
  float totalLoss = 0.0;
  float lossSoFar = 0.0;
  int nbExamplesUntilPrint = printInterval;
  int currentBatchNumber = 0;

  for (auto & batch : *dataLoader)
  {
    denseOptimizer->zero_grad();
    sparseOptimizer->zero_grad();

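    // Stacked feature contexts and their gold labels; squeeze() turns the
    // [batchSize, 1] label tensor into a 1-D vector of class indices.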
    auto data = batch.data;
    auto labels = batch.target.squeeze();

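    // Forward pass: the network outputs a probability distribution over transitions.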
    auto prediction = machine.getClassifier()->getNN()(data);

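    // Negative log-likelihood of the log-probabilities, i.e. cross-entropy on the predicted distribution.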
    auto loss = torch::nll_loss(torch::log(prediction), labels);
    totalLoss += loss.item<float>();
    lossSoFar += loss.item<float>();

    loss.backward();
    denseOptimizer->step();
    sparseOptimizer->step();

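    // Periodically report progress and the loss accumulated since the last report.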
    if (printAdvancement)
    {
      nbExamplesUntilPrint -= labels.size(0);

      ++currentBatchNumber;
      if (nbExamplesUntilPrint <= 0)
      {
        nbExamplesUntilPrint = printInterval;
        fmt::print(stderr, "\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar);
        lossSoFar = 0;
      }
    }
  }

  return totalLoss;
}