Skip to content
Snippets Groups Projects
Trainer.cpp 3.52 KiB
Newer Older
Franck Dary's avatar
Franck Dary committed
#include "Trainer.hpp"
#include "SubConfig.hpp"

Trainer::Trainer(ReadingMachine & machine) : machine(machine)
{
}

Franck Dary's avatar
Franck Dary committed
/// Build the training examples by replaying transitions on `config`.
///
/// Starting from the strategy's initial state, the transition returned by
/// `getBestAppliableTransition` is repeatedly applied to `config`. Before each
/// application, the network's extracted context (input) and the index of the
/// chosen transition (gold class) are recorded as one example. The loop ends
/// when the strategy returns `Strategy::endMovement`.
///
/// Afterwards `nbExamples`, `dataLoader` (batches of `batchSize`) and
/// `optimizer` (Adam with amsgrad, lr 0.001) are (re)initialized.
///
/// @param config sub-configuration to replay on; it is mutated in place.
/// @param debug  when true, the configuration and every chosen
///               (transition, new state, movement) are printed to stderr.
/// Errors are reported through util::myThrow when no transition is appliable
/// or when the word index cannot be moved.
void Trainer::createDataset(SubConfig & config, bool debug)
{
  config.addPredicted(machine.getPredicted());
  config.setState(machine.getStrategy().getInitialState());

  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;

  while (true)
  {
    if (debug)
      config.printForDebug(stderr);

    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
    if (!transition)
    {
      config.printForDebug(stderr);
      util::myThrow("No transition appliable !");
    }

    auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState()));
    // from_blob does not copy: clone() so the stored tensor owns its data after
    // `context` is destroyed. NOTE(review): assumes the context elements are
    // 64-bit integers matching at::kLong — confirm extractContext's return type.
    contexts.emplace_back(torch::from_blob(context.data(), {static_cast<long>(context.size())}, at::kLong).clone());

    // Gold label: a 1-element long tensor holding the chosen transition's index.
    int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
    auto gold = torch::zeros(1, at::kLong);
    gold[0] = goldIndex;

    classes.emplace_back(gold);

    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    if (debug)
      fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    if (!config.moveWordIndex(movement.second))
    {
      config.printForDebug(stderr);
      util::myThrow(fmt::format("Cannot move word index by {}", movement.second));
    }

    if (config.needsUpdate())
      config.update();
  }

  nbExamples = classes.size();

  // Stack<> collates the per-example tensors into batch tensors.
  dataLoader = torch::data::make_data_loader(
    Dataset(contexts, classes).map(torch::data::transforms::Stack<>()),
    torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));

  optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.001).amsgrad(true).beta1(0.9).beta2(0.999)));
}

Franck Dary's avatar
Franck Dary committed
float Trainer::epoch(bool printAdvancement)
Franck Dary's avatar
Franck Dary committed
{
Franck Dary's avatar
Franck Dary committed
  constexpr int printInterval = 50;
  int nbExamplesProcessed = 0;
Franck Dary's avatar
Franck Dary committed
  float totalLoss = 0.0;
  float lossSoFar = 0.0;
  int currentBatchNumber = 0;

  auto lossFct = torch::nn::CrossEntropyLoss();

  auto pastTime = std::chrono::high_resolution_clock::now();

Franck Dary's avatar
Franck Dary committed
  for (auto & batch : *dataLoader)
  {
Franck Dary's avatar
Franck Dary committed

    auto data = batch.data;
    auto labels = batch.target.squeeze();

    auto prediction = machine.getClassifier()->getNN()(data);
Franck Dary's avatar
Franck Dary committed
    if (prediction.dim() == 1)
      prediction = prediction.unsqueeze(0);
Franck Dary's avatar
Franck Dary committed

Franck Dary's avatar
Franck Dary committed
    labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));

    auto loss = lossFct(prediction, labels);
Franck Dary's avatar
Franck Dary committed
    try
    {
      totalLoss += loss.item<float>();
      lossSoFar += loss.item<float>();
    } catch(std::exception & e) {util::myThrow(e.what());}
Franck Dary's avatar
Franck Dary committed

    loss.backward();
Franck Dary's avatar
Franck Dary committed

Franck Dary's avatar
Franck Dary committed
    if (printAdvancement)
Franck Dary's avatar
Franck Dary committed
    {
      nbExamplesProcessed += labels.size(0);
Franck Dary's avatar
Franck Dary committed

      ++currentBatchNumber;
      if (nbExamplesProcessed >= printInterval)
Franck Dary's avatar
Franck Dary committed
      {
        auto actualTime = std::chrono::high_resolution_clock::now();
        double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
        pastTime = actualTime;
        fmt::print(stderr, "\rcurrent epoch : {:6.2f}% loss={:<7.3f} speed={:<5}ex/s", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar, (int)(nbExamplesProcessed/seconds));
Franck Dary's avatar
Franck Dary committed
        lossSoFar = 0;
        nbExamplesProcessed = 0;
Franck Dary's avatar
Franck Dary committed
      }
Franck Dary's avatar
Franck Dary committed
    }
  }

  return totalLoss;
}