Trainer.cpp
#include "Trainer.hpp"
#include "SubConfig.hpp"

Trainer::Trainer(ReadingMachine & machine) : machine(machine)
{
}

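// Builds the training set by replaying the oracle: starting from the strategy's
// initial state, repeatedly apply the best applicable transition, recording at
// each step the extracted feature context (input) and the gold transition
// index (target).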
void Trainer::createDataset(SubConfig & config)
{
  config.setState(machine.getStrategy().getInitialState());

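  // Parallel vectors: contexts[i] is the input tensor of example i,
  // classes[i] its gold label.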
  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;

  while (true)
  {
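    // Ask the oracle for the best transition applicable to the current configuration.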
    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
    if (!transition)
    {
      config.printForDebug(stderr);
      util::myThrow("No applicable transition!");
    }

    // TODO: check if clone is mandatory
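    // Extract the feature context around the current word (the 5,5 arguments
    // presumably being the left/right window sizes), encoded through the current
    // state's dictionary; from_blob does not take ownership, hence the copy.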
    auto context = config.extractContext(5,5,machine.getDict(config.getState()));
    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());

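    // The target is the index of the gold transition within the transition set.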
    int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
    auto gold = torch::zeros(1, at::kLong);
    gold[0] = goldIndex;

    classes.emplace_back(gold);

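    // Apply the gold transition and record it in the configuration's history.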
    transition->apply(config);
    config.addToHistory(transition->getName());

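    // The strategy decides the next state and how far to move the word index;
    // a special end movement terminates dataset construction.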
    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    if (!config.moveWordIndex(movement.second))
    {
      config.printForDebug(stderr);
      util::myThrow(fmt::format("Cannot move word index by {}", movement.second));
    }

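    // Perform any pending update the configuration asks for.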
    if (config.needsUpdate())
      config.update();
  }

  nbExamples = classes.size();

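  // Wrap the examples in a data loader; Stack<> collates the per-example
  // tensors into batches.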
  dataLoader = torch::data::make_data_loader(
    Dataset(contexts, classes).map(torch::data::transforms::Stack<>()),
    torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));

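  // Dense and sparse parameters get separate optimizers: SparseAdam handles the
  // sparse gradients produced by embedding lookups, which plain Adam rejects.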
  denseOptimizer.reset(new torch::optim::Adam(
    machine.getClassifier()->getNN()->denseParameters(),
    torch::optim::AdamOptions(2e-3).beta1(0.5)));
  sparseOptimizer.reset(new torch::optim::SparseAdam(
    machine.getClassifier()->getNN()->sparseParameters(),
    torch::optim::SparseAdamOptions(2e-3).beta1(0.5)));
}

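// Runs one full pass over the dataset, printing progress every printInterval
// examples, and returns the loss accumulated over the epoch.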
float Trainer::epoch()
{
  constexpr int printInterval = 2000;
  float totalLoss = 0.0;
  float lossSoFar = 0.0;
  int nbExamplesUntilPrint = printInterval;
  int currentBatchNumber = 0;

  for (auto & batch : *dataLoader)
  {
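    // Clear the gradients accumulated by the previous batch.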
    denseOptimizer->zero_grad();
    sparseOptimizer->zero_grad();

    auto data = batch.data;
    // squeeze(1) turns the [batchSize, 1] target into [batchSize] without
    // collapsing a size-1 batch to a 0-dim tensor, as a bare squeeze() would.
    auto labels = batch.target.squeeze(1);

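    // Forward pass: the network outputs probabilities, so take the log before
    // computing the negative log-likelihood loss.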
    auto prediction = machine.getClassifier()->getNN()(data);

    auto loss = torch::nll_loss(torch::log(prediction), labels);
    totalLoss += loss.item<float>();
    lossSoFar += loss.item<float>();

    loss.backward();
    denseOptimizer->step();
    sparseOptimizer->step();

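    // Count down in examples; when the counter is spent, report progress and
    // the loss accumulated since the last report.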
    nbExamplesUntilPrint -= labels.size(0);

    ++currentBatchNumber;
    if (nbExamplesUntilPrint <= 0)
    {
      nbExamplesUntilPrint = printInterval;
      fmt::print(stderr, "\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar);
      lossSoFar = 0;
    }
  }

  return totalLoss;
}