    #include "Trainer.hpp"
    #include "SubConfig.hpp"
    
    Trainer::Trainer(ReadingMachine & machine) : machine(machine)
    {
    }
    
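// Builds the training set by replaying the oracle on the given configuration,
// then wraps the collected examples in a DataLoader and creates the optimizers.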
void Trainer::createDataset(SubConfig & config)
{
  config.setState(machine.getStrategy().getInitialState());

  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;
    
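  // Follow the oracle: at each step, record the current feature context and the
  // index of the gold transition, then apply that transition to the configuration.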
  while (true)
  {
    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
    if (!transition)
      util::myThrow("No transition appliable !");

    // Store the current context as a tensor of dictionary indices.
    // from_blob does not own the underlying buffer, hence the clone().
    auto context = config.extractContext(5,5,machine.getDict(config.getState()));
    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());

    // Store the index of the gold transition as the target class.
    int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
    auto gold = torch::zeros(1, at::kLong);
    gold[0] = goldIndex;

    classes.emplace_back(gold);

    transition->apply(config);
    config.addToHistory(transition->getName());

    // Ask the strategy for the next state and word index movement; stop at the end movement.
    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    if (!config.moveWordIndex(movement.second))
      util::myThrow("Cannot move word index !");

    if (config.needsUpdate())
      config.update();
  }
    
    
  nbExamples = classes.size();
    
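  // Wrap the collected examples in the Dataset and batch them with a DataLoader;
  // Stack<> collates individual (context, class) pairs into batched tensors.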
  dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
    
    
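  // The network exposes its dense and sparse parameters separately: SparseAdam only
  // accepts parameters with sparse gradients, so each group gets its own optimizer.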
  denseOptimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5)));
  sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5)));
}
    
    
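// Runs one training pass over the dataset and returns the accumulated loss.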
float Trainer::epoch()
{
  constexpr int printInterval = 2000;
  float totalLoss = 0.0;
  float lossSoFar = 0.0;
  int nbExamplesUntilPrint = printInterval;
  int currentBatchNumber = 0;
    
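  // For each batch: forward pass, cross-entropy loss (log + nll_loss), backward
  // pass, then one step of both optimizers.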
  for (auto & batch : *dataLoader)
  {
    denseOptimizer->zero_grad();
    sparseOptimizer->zero_grad();

    auto data = batch.data;
    auto labels = batch.target.squeeze();

    auto prediction = machine.getClassifier()->getNN()(data);

    // nll_loss expects log-probabilities, hence the log applied to the network's output.
    auto loss = torch::nll_loss(torch::log(prediction), labels);
    totalLoss += loss.item<float>();
    lossSoFar += loss.item<float>();

    loss.backward();
    denseOptimizer->step();
    sparseOptimizer->step();

    nbExamplesUntilPrint -= labels.size(0);

    ++currentBatchNumber;
    // Periodically print the progress and the loss accumulated since the last print.
    if (nbExamplesUntilPrint <= 0)
    {
      nbExamplesUntilPrint = printInterval;
      fmt::print("\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar);
      std::fflush(stdout);
      lossSoFar = 0;
    }
  }
    
  return totalLoss;
}