Skip to content
Snippets Groups Projects
Trainer.cpp 3.16 KiB
Newer Older
Franck Dary's avatar
Franck Dary committed
#include "Trainer.hpp"
#include "SubConfig.hpp"

// Initialises the `machine` member from the given ReadingMachine. Nothing else
// is set up here; the data loader and optimizers are (re)created by
// createDataset().
Trainer::Trainer(ReadingMachine & machine) : machine(machine) {}

Franck Dary's avatar
Franck Dary committed
void Trainer::createDataset(SubConfig & config, bool debug)
{
  // Builds the training set by walking `config` from the strategy's initial
  // state until the strategy returns Strategy::endMovement.
  //
  // At every step:
  //   - the transition returned by getBestAppliableTransition is used as the
  //     gold label;
  //   - the current feature context is recorded with that label's index;
  //   - the transition is applied and the word index / state are moved as the
  //     strategy dictates.
  //
  // Afterwards:
  //   - nbExamples is set to the number of collected examples;
  //   - dataLoader is rebuilt over them;
  //   - fresh dense and sparse optimizers are created for the classifier.
  //
  // config : configuration to walk; it is mutated in place.
  // debug  : when true, the configuration and each
  //          (transition, new state, movement) triple are printed to stderr.
  //
  // Errors (no appliable transition, impossible word-index move) are reported
  // through util::myThrow after dumping the configuration to stderr.
  config.addPredicted(machine.getPredicted());
  config.setState(machine.getStrategy().getInitialState());

  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;

  while (true)
  {
    if (debug)
      config.printForDebug(stderr);

    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
    if (!transition)
    {
      config.printForDebug(stderr);
      util::myThrow("No transition appliable !");
    }

    // The two 5s are forwarded unchanged to extractContext.
    // NOTE(review): their meaning is not visible in this file.
    auto context = config.extractContext(5,5,machine.getDict(config.getState()));
    // from_blob does not copy and `context` is a local that dies at the end of
    // this iteration, so clone() gives the tensor its own storage.
    // NOTE(review): assumes extractContext yields 64-bit integers matching
    // at::kLong — confirm.
    contexts.emplace_back(torch::from_blob(context.data(), {static_cast<long>(context.size())}, at::kLong).clone());

    // The gold class is the position of the chosen transition in the transition set.
    int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
    auto gold = torch::zeros(1, at::kLong);
    gold[0] = goldIndex;

    classes.emplace_back(gold);

    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    if (debug)
      fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    if (!config.moveWordIndex(movement.second))
    {
      config.printForDebug(stderr);
      util::myThrow(fmt::format("Cannot move word index by {}", movement.second));
    }

    if (config.needsUpdate())
      config.update();
  }

  nbExamples = classes.size();

  // Each example is stacked into batches of `batchSize`; loading happens on
  // the calling thread (no worker threads).
  dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));

  // Dense and sparse parameters need separate optimizers (SparseAdam for the
  // sparse ones); both use lr = 2e-3 and beta1 = 0.5.
  denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5)));
  sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5)));
}

Franck Dary's avatar
Franck Dary committed
float Trainer::epoch(bool printAdvancement)
Franck Dary's avatar
Franck Dary committed
{
  constexpr int printInterval = 2000;
  float totalLoss = 0.0;
  float lossSoFar = 0.0;
  int nbExamplesUntilPrint = printInterval;
  int currentBatchNumber = 0;

  for (auto & batch : *dataLoader)
  {
    denseOptimizer->zero_grad();
    sparseOptimizer->zero_grad();

    auto data = batch.data;
    auto labels = batch.target.squeeze();

    auto prediction = machine.getClassifier()->getNN()(data);

    auto loss = torch::nll_loss(torch::log(prediction), labels);
    totalLoss += loss.item<float>();
    lossSoFar += loss.item<float>();

    loss.backward();
    denseOptimizer->step();
    sparseOptimizer->step();

Franck Dary's avatar
Franck Dary committed
    if (printAdvancement)
Franck Dary's avatar
Franck Dary committed
    {
Franck Dary's avatar
Franck Dary committed
      nbExamplesUntilPrint -= labels.size(0);

      ++currentBatchNumber;
      if (nbExamplesUntilPrint <= 0)
      {
        nbExamplesUntilPrint = printInterval;
        fmt::print(stderr, "\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar);
        lossSoFar = 0;
      }
Franck Dary's avatar
Franck Dary committed
    }
  }

  return totalLoss;
}