Skip to content
Snippets Groups Projects
Trainer.cpp 5.51 KiB
Newer Older
  • Learn to ignore specific revisions
  • Franck Dary's avatar
    Franck Dary committed
    #include "Trainer.hpp"
    #include "SubConfig.hpp"
    
    // Constructs a trainer; the `machine` member is initialized from the given reading machine.
    Trainer::Trainer(ReadingMachine & machine)
      : machine(machine)
    {}
    
    
    Franck Dary's avatar
    Franck Dary committed
    /// Builds the training data loader and (re)creates the optimizer.
    ///
    /// Switches the machine to training mode, extracts the (context, class)
    /// examples from `config`, stores their count in `nbExamples`, wraps them
    /// in a batched data loader and resets `optimizer` to a new Adam instance
    /// over the parameters of the classifier's network.
    ///
    /// @param config Configuration the examples are extracted from; extractExamples mutates it.
    /// @param debug  Forwarded to extractExamples, where it enables debug output on stderr.
    void Trainer::createDataset(SubConfig & config, bool debug)
    {
      machine.trainMode(true);

      std::vector<torch::Tensor> contexts;
      std::vector<torch::Tensor> classes;

      extractExamples(config, debug, contexts, classes);

      // One class tensor is pushed per extracted context, so this is the number of examples.
      nbExamples = classes.size();

      // Stack<> collates the individual examples of a batch into single batched tensors.
      dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));

      // A fresh optimizer is created each time the training dataset is (re)built.
      optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.0005).amsgrad(true).beta1(0.9).beta2(0.999)));
    }
    
    /// Builds the development (evaluation) data loader.
    ///
    /// Switches the machine out of training mode, extracts the (context, class)
    /// examples from `config` and wraps them in a batched data loader stored in
    /// `devDataLoader`. Unlike createDataset, it touches neither `nbExamples`
    /// nor `optimizer`.
    ///
    /// @param config Configuration the examples are extracted from; extractExamples mutates it.
    /// @param debug  Forwarded to extractExamples, where it enables debug output on stderr.
    void Trainer::createDevDataset(SubConfig & config, bool debug)
    {
      machine.trainMode(false);

      std::vector<torch::Tensor> contexts;
      std::vector<torch::Tensor> classes;

      extractExamples(config, debug, contexts, classes);

      // Stack<> collates the individual examples of a batch into single batched tensors.
      devDataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
    }
    
    /// Walks `config` with the machine's best appliable transitions and records one
    /// training example per extracted context.
    ///
    /// At every step the best appliable transition is looked up, the network's
    /// contexts for the current configuration are appended to `contexts`, and the
    /// index of that transition is appended to `classes` once per context. The
    /// transition is then applied and the strategy decides the next state and
    /// word-index movement, until it returns Strategy::endMovement.
    ///
    /// @param config   Configuration to process; it is mutated (state, history, word index, ...).
    /// @param debug    When true, prints the configuration and each decision on stderr.
    /// @param contexts Output: one 1-D kLong tensor appended per extracted context.
    /// @param classes  Output: one 1-element kLong tensor (transition index) appended per context.
    ///
    /// Errors are reported through util::myThrow (presumably throws — confirm in util):
    /// no appliable transition, failure to extract the context, or impossible word-index move.
    void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes)
    {
      fmt::print(stderr, "[{}] Starting to extract examples\n", util::getTime());

      config.addPredicted(machine.getPredicted());
      config.setState(machine.getStrategy().getInitialState());

      while (true)
      {
        if (debug)
          config.printForDebug(stderr);

        config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));

        auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
        if (!transition)
        {
          config.printForDebug(stderr);
          util::myThrow("No transition appliable !");
        }

        std::vector<std::vector<long>> context;

        try
        {
          context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
          // from_blob does not own the memory of `element`, hence the clone() before storing.
          for (auto & element : context)
            contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone());
        } catch(const std::exception & e)
        {
          util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
        }

        int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);

        auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
        gold[0] = goldIndex;

        // Every context extracted at this step is labelled with the same gold transition.
        classes.insert(classes.end(), context.size(), gold);

        transition->apply(config);
        config.addToHistory(transition->getName());

        // movement.first is the next state, movement.second the word-index displacement.
        auto movement = machine.getStrategy().getMovement(config, transition->getName());
        if (debug)
          fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
        if (movement == Strategy::endMovement)
          break;

        config.setState(movement.first);
        if (!config.moveWordIndex(movement.second))
        {
          config.printForDebug(stderr);
          util::myThrow(fmt::format("Cannot move word index by {}", movement.second));
        }

        if (config.needsUpdate())
          config.update();
      }

      fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(classes.size()));
    }
    
    
    /// Runs one full pass over `loader`, optionally updating the network.
    ///
    /// For each batch the classifier's network is evaluated and the cross-entropy
    /// loss against the batch targets is computed. When `train` is true, gradients
    /// are enabled and an optimizer step is performed per batch; otherwise the
    /// pass is evaluation only.
    ///
    /// @param loader           Data loader to iterate over (dereferenced to obtain the batches).
    /// @param train            True to back-propagate and step the optimizer, false to only evaluate.
    /// @param printAdvancement True to print a progress line on stderr roughly every `printInterval` examples.
    /// @return Sum of the per-batch loss values over the whole pass.
    float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement)
    {
      // Minimum number of examples processed between two progress prints.
      constexpr int printInterval = 50;
      int nbExamplesProcessed = 0;
      float totalLoss = 0.0;
      float lossSoFar = 0.0;
      int currentBatchNumber = 0;

      // Gradient tracking is only enabled for training passes.
      torch::AutoGradMode useGrad(train);
      machine.trainMode(train);

      auto lossFct = torch::nn::CrossEntropyLoss();

      auto pastTime = std::chrono::high_resolution_clock::now();

      for (auto & batch : *loader)
      {
        if (train)
          optimizer->zero_grad();

        auto data = batch.data;
        auto labels = batch.target.squeeze();

        auto prediction = machine.getClassifier()->getNN()(data);

        // Make sure the prediction keeps a batch dimension even for a single example.
        if (prediction.dim() == 1)
          prediction = prediction.unsqueeze(0);

        // squeeze() may have collapsed a single-example batch to a 0-dim tensor; restore a 1-D shape.
        labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));

        auto loss = lossFct(prediction, labels);

        try
        {
          // Read the scalar once and reuse it for both accumulators.
          const float batchLoss = loss.item<float>();
          totalLoss += batchLoss;
          lossSoFar += batchLoss;
        } catch(const std::exception & e) {util::myThrow(e.what());}

        if (train)
        {
          loss.backward();
          optimizer->step();
        }

        if (printAdvancement)
        {
          nbExamplesProcessed += labels.size(0);

          ++currentBatchNumber;

          if (nbExamplesProcessed >= printInterval)
          {
            auto actualTime = std::chrono::high_resolution_clock::now();
            double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
            pastTime = actualTime;
            // Training reports progress through the epoch; evaluation only reports loss and speed.
            if (train)
              fmt::print(stderr, "\r{:80}\rcurrent epoch : {:6.2f}% loss={:<7.3f} speed={:<6}ex/s", "", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar, (int)(nbExamplesProcessed/seconds));
            else
              fmt::print(stderr, "\r{:80}\reval on dev : loss={:<7.3f} speed={:<6}ex/s", "", lossSoFar, (int)(nbExamplesProcessed/seconds));
            // Both counters are per-print-window, so they restart after each progress line.
            lossSoFar = 0;
            nbExamplesProcessed = 0;
          }
        }
      }

      return totalLoss;
    }
    
    
    // Runs one training pass over the training data loader and returns
    // the loss value reported by processDataset.
    float Trainer::epoch(bool printAdvancement)
    {
      constexpr bool isTrainingPass = true;
      const float epochLoss = processDataset(dataLoader, isTrainingPass, printAdvancement);

      return epochLoss;
    }
    
    // Runs one evaluation-only pass over the dev data loader and returns
    // the loss value reported by processDataset.
    float Trainer::evalOnDev(bool printAdvancement)
    {
      constexpr bool isTrainingPass = false;
      const float devLoss = processDataset(devDataLoader, isTrainingPass, printAdvancement);

      return devLoss;
    }
    
    
    // Restores the state of the current optimizer from the file at `path` via torch::load.
    // NOTE(review): dereferences `optimizer`, which this file only sets in createDataset — confirm call order.
    void Trainer::loadOptimizer(std::filesystem::path path)
    {
      auto & target = *optimizer;
      torch::load(target, path);
    }
    
    // Writes the state of the current optimizer to the file at `path` via torch::save.
    // NOTE(review): dereferences `optimizer`, which this file only sets in createDataset — confirm call order.
    void Trainer::saveOptimizer(std::filesystem::path path)
    {
      auto & source = *optimizer;
      torch::save(source, path);
    }