Skip to content
Snippets Groups Projects
Trainer.cpp 3.44 KiB
Newer Older
  • Learn to ignore specific revisions
  • Franck Dary's avatar
    Franck Dary committed
    #include "Trainer.hpp"
    #include "SubConfig.hpp"
    
    // Binds this trainer to `machine`; createDataset() and epoch() in this
    // file reach the classifier, strategy, transition set and dictionaries
    // through that member.
    Trainer::Trainer(ReadingMachine & machine)
      : machine(machine) {}
    
    
    Franck Dary's avatar
    Franck Dary committed
    /// Builds the training dataset by walking `config` with the machine's
    /// transition set.
    ///
    /// Starting from the strategy's initial state, each step:
    ///   - picks the best appliable transition;
    ///   - records (extracted context, index of that transition) as one
    ///     (input, gold label) example;
    ///   - applies the transition and moves as the strategy dictates.
    /// The walk stops when the strategy returns Strategy::endMovement. The
    /// collected examples are then wrapped in `dataLoader`, and a fresh Adam
    /// optimizer is created over the classifier's network parameters.
    ///
    /// @param config  configuration to walk through; it is mutated in place.
    /// @param debug   when true, the configuration before each step and every
    ///                (transition, new state, movement) triple go to stderr.
    /// Errors are reported through util::myThrow when no transition is
    /// appliable or when the word index cannot be moved.
    void Trainer::createDataset(SubConfig & config, bool debug)
    {
      config.addPredicted(machine.getPredicted());
      config.setState(machine.getStrategy().getInitialState());

      std::vector<torch::Tensor> contexts;
      std::vector<torch::Tensor> classes;

      while (true)
      {
        if (debug)
          config.printForDebug(stderr);

        auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
        if (!transition)
        {
          config.printForDebug(stderr);
          util::myThrow("No transition appliable !");
        }

        auto context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
        // from_blob only views `context`'s buffer, which is destroyed at the
        // end of this iteration, so the tensor must be cloned to own its data.
        // NOTE(review): assumes the elements of `context` are 64-bit integers
        // matching at::kLong — confirm against extractContext's return type.
        contexts.emplace_back(torch::from_blob(context.data(), {static_cast<int64_t>(context.size())}, at::kLong).clone());

        // The index of the chosen transition is the gold class for this context.
        int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
        auto gold = torch::zeros(1, at::kLong);
        gold[0] = goldIndex;
        classes.emplace_back(gold);

        transition->apply(config);
        config.addToHistory(transition->getName());

        auto movement = machine.getStrategy().getMovement(config, transition->getName());

        if (debug)
          fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);

        if (movement == Strategy::endMovement)
          break;

        config.setState(movement.first);
        if (!config.moveWordIndex(movement.second))
        {
          config.printForDebug(stderr);
          util::myThrow(fmt::format("Cannot move word index by {}", movement.second));
        }

        if (config.needsUpdate())
          config.update();
      }

      nbExamples = classes.size();

      dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));

      optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.001).beta1(0.9).beta2(0.999)));
    }
    
    
    Franck Dary's avatar
    Franck Dary committed
    float Trainer::epoch(bool printAdvancement)
    
    Franck Dary's avatar
    Franck Dary committed
    {
      constexpr int printInterval = 2000;
    
      int nbExamplesProcessed = 0;
    
    Franck Dary's avatar
    Franck Dary committed
      float totalLoss = 0.0;
      float lossSoFar = 0.0;
      int currentBatchNumber = 0;
    
    
      auto lossFct = torch::nn::CrossEntropyLoss();
    
    
      auto pastTime = std::chrono::high_resolution_clock::now();
    
    
    Franck Dary's avatar
    Franck Dary committed
      for (auto & batch : *dataLoader)
      {
    
    Franck Dary's avatar
    Franck Dary committed
    
        auto data = batch.data;
        auto labels = batch.target.squeeze();
    
        auto prediction = machine.getClassifier()->getNN()(data);
    
    
    Franck Dary's avatar
    Franck Dary committed
        labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));
    
    
        auto loss = lossFct(prediction, labels);
    
    Franck Dary's avatar
    Franck Dary committed
        try
        {
          totalLoss += loss.item<float>();
          lossSoFar += loss.item<float>();
        } catch(std::exception & e) {util::myThrow(e.what());}
    
    Franck Dary's avatar
    Franck Dary committed
    
        loss.backward();
    
    Franck Dary's avatar
    Franck Dary committed
    
    
    Franck Dary's avatar
    Franck Dary committed
        if (printAdvancement)
    
    Franck Dary's avatar
    Franck Dary committed
        {
    
          nbExamplesProcessed += labels.size(0);
    
    Franck Dary's avatar
    Franck Dary committed
    
          ++currentBatchNumber;
    
          if (nbExamplesProcessed >= printInterval)
    
    Franck Dary's avatar
    Franck Dary committed
          {
    
            auto actualTime = std::chrono::high_resolution_clock::now();
            double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
            pastTime = actualTime;
            fmt::print(stderr, "\rcurrent epoch : {:6.2f}% loss={:<7.3f} speed={:<5}ex/s", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar, (int)(nbExamplesProcessed/seconds));
    
    Franck Dary's avatar
    Franck Dary committed
            lossSoFar = 0;
    
            nbExamplesProcessed = 0;
    
    Franck Dary's avatar
    Franck Dary committed
          }
    
    Franck Dary's avatar
    Franck Dary committed
        }
      }
    
      return totalLoss;
    }