// Trainer.cpp — builds training/dev datasets by extracting (context, gold
// transition) examples from a ReadingMachine run, and runs training/eval
// epochs over them with libtorch data loaders.
    #include "Trainer.hpp"
    #include "SubConfig.hpp"
    
    
    Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize)
    
    Franck Dary's avatar
    Franck Dary committed
    {
    }
    
    
    void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
    
    Franck Dary's avatar
    Franck Dary committed
    {
    
      SubConfig config(goldConfig, goldConfig.getNbLines());
    
      machine.trainMode(true);
    
      extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
      trainDataset.reset(new Dataset(dir));
    
    Franck Dary's avatar
    Franck Dary committed
    
    
      dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
    
    void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
    
      SubConfig config(goldConfig, goldConfig.getNbLines());
    
    Franck Dary's avatar
    Franck Dary committed
    
    
      machine.trainMode(false);
    
      extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
      devDataset.reset(new Dataset(dir));
    
      devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
    
    void Trainer::saveExamples(std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes, int & lastSavedIndex, int & currentExampleIndex, std::filesystem::path dir)
    
      auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
      auto filename = fmt::format("{}-{}.tensor", lastSavedIndex, currentExampleIndex-1);
      torch::save(tensorToSave, dir/filename);
      lastSavedIndex = currentExampleIndex;
      contexts.clear();
      classes.clear();
    }
    
    void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
    {
      torch::AutoGradMode useGrad(false);
    
      machine.setDictsState(Dict::State::Open);
    
    
      int maxNbExamplesPerFile = 250000;
      int currentExampleIndex = 0;
      int lastSavedIndex = 0;
      std::vector<torch::Tensor> contexts;
      std::vector<torch::Tensor> classes;
    
      std::filesystem::create_directories(dir);
    
      config.addPredicted(machine.getPredicted());
      config.setState(machine.getStrategy().getInitialState());
    
      machine.getStrategy().reset();
    
      auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}", epoch);
      bool mustExtract = !std::filesystem::exists(currentEpochAllExtractedFile);
      if (epoch != 0 and (dynamicOracleInterval == -1 or epoch % dynamicOracleInterval))
        mustExtract = false;
    
      if (!mustExtract)
        return;
    
      bool dynamicOracle = epoch != 0;
    
      fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");
    
      for (auto & entry : std::filesystem::directory_iterator(dir))
        if (entry.is_regular_file())
          std::filesystem::remove(entry.path());
    
    Franck Dary's avatar
    Franck Dary committed
      while (true)
      {
    
    Franck Dary's avatar
    Franck Dary committed
        if (debug)
          config.printForDebug(stderr);
    
    
        if (machine.hasSplitWordTransitionSet())
          config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
    
        std::vector<std::vector<long>> context;
    
    
          context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
          for (auto & element : context)
    
            contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone());
    
        } catch(std::exception & e)
        {
          util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
        }
    
    Franck Dary's avatar
    Franck Dary committed
    
    
        Transition * transition = nullptr;
          
        if (dynamicOracle and config.getState() != "tokenizer")
        {
          auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
          auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
      
          int chosenTransition = -1;
          float bestScore = std::numeric_limits<float>::min();
    
          for (unsigned int i = 0; i < prediction.size(0); i++)
          {
            float score = prediction[i].item<float>();
            if ((chosenTransition == -1 or score > bestScore) and machine.getTransitionSet().getTransition(i)->appliable(config))
            {
              chosenTransition = i;
              bestScore = score;
            }
          }
    
          transition = machine.getTransitionSet().getTransition(chosenTransition);
        }
        else
        {
          transition = machine.getTransitionSet().getBestAppliableTransition(config);
        }
    
        if (!transition)
        {
          config.printForDebug(stderr);
          util::myThrow("No transition appliable !");
        }
    
    
    Franck Dary's avatar
    Franck Dary committed
        int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
    
        auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
    
    Franck Dary's avatar
    Franck Dary committed
        gold[0] = goldIndex;
    
    
        currentExampleIndex += context.size();
        classes.insert(classes.end(), context.size(), gold);
    
    
        if (currentExampleIndex-lastSavedIndex >= maxNbExamplesPerFile)
          saveExamples(contexts, classes, lastSavedIndex, currentExampleIndex, dir);
    
    Franck Dary's avatar
    Franck Dary committed
    
        transition->apply(config);
        config.addToHistory(transition->getName());
    
        auto movement = machine.getStrategy().getMovement(config, transition->getName());
    
        if (debug)
          fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
    
    Franck Dary's avatar
    Franck Dary committed
        if (movement == Strategy::endMovement)
          break;
    
        config.setState(movement.first);
    
    Franck Dary's avatar
    Franck Dary committed
        config.moveWordIndexRelaxed(movement.second);
    
    Franck Dary's avatar
    Franck Dary committed
    
        if (config.needsUpdate())
          config.update();
      }
    
      if (!contexts.empty())
        saveExamples(contexts, classes, lastSavedIndex, currentExampleIndex, dir);
    
      std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
      if (!f)
        util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
      std::fclose(f);
    
    
      machine.saveDicts();
    
    
      fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(currentExampleIndex));
    
    Franck Dary's avatar
    Franck Dary committed
    }
    
    
    float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples)
    
    Franck Dary's avatar
    Franck Dary committed
    {
    
    Franck Dary's avatar
    Franck Dary committed
      constexpr int printInterval = 50;
    
      int nbExamplesProcessed = 0;
    
      int totalNbExamplesProcessed = 0;
    
    Franck Dary's avatar
    Franck Dary committed
      float totalLoss = 0.0;
      float lossSoFar = 0.0;
    
    
      torch::AutoGradMode useGrad(train);
    
    Franck Dary's avatar
    Franck Dary committed
      machine.trainMode(train);
    
      machine.setDictsState(Dict::State::Closed);
    
      auto lossFct = torch::nn::CrossEntropyLoss();
    
    
      auto pastTime = std::chrono::high_resolution_clock::now();
    
    
      for (auto & batch : *loader)
    
    Franck Dary's avatar
    Franck Dary committed
      {
    
          machine.getClassifier()->getOptimizer().zero_grad();
    
    Franck Dary's avatar
    Franck Dary committed
    
    
    Franck Dary's avatar
    Franck Dary committed
    
        auto prediction = machine.getClassifier()->getNN()(data);
    
    Franck Dary's avatar
    Franck Dary committed
        if (prediction.dim() == 1)
          prediction = prediction.unsqueeze(0);
    
    Franck Dary's avatar
    Franck Dary committed
    
    
    Franck Dary's avatar
    Franck Dary committed
        labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));
    
    
        auto loss = lossFct(prediction, labels);
    
    Franck Dary's avatar
    Franck Dary committed
        try
        {
          totalLoss += loss.item<float>();
          lossSoFar += loss.item<float>();
        } catch(std::exception & e) {util::myThrow(e.what());}
    
    Franck Dary's avatar
    Franck Dary committed
    
    
          machine.getClassifier()->getOptimizer().step();
    
    Franck Dary's avatar
    Franck Dary committed
    
    
        totalNbExamplesProcessed += torch::numel(labels);
    
    
    Franck Dary's avatar
    Franck Dary committed
        if (printAdvancement)
    
    Franck Dary's avatar
    Franck Dary committed
        {
    
          nbExamplesProcessed += torch::numel(labels);
    
    Franck Dary's avatar
    Franck Dary committed
    
    
          if (nbExamplesProcessed >= printInterval)
    
    Franck Dary's avatar
    Franck Dary committed
          {
    
            auto actualTime = std::chrono::high_resolution_clock::now();
            double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
            pastTime = actualTime;
    
            auto speed = (int)(nbExamplesProcessed/seconds);
            auto progression = 100.0*totalNbExamplesProcessed / nbExamples;
            auto statusStr = fmt::format("{:6.2f}% loss={:<7.3f} speed={:<6}ex/s", progression, lossSoFar, speed);
    
              fmt::print(stderr, "\r{:80}\rtraining : {}", "", statusStr);
    
              fmt::print(stderr, "\r{:80}\reval on dev : {}", "", statusStr);
    
    Franck Dary's avatar
    Franck Dary committed
            lossSoFar = 0;
    
            nbExamplesProcessed = 0;
    
    Franck Dary's avatar
    Franck Dary committed
          }
    
      return totalLoss / nbExamples;
    
    Franck Dary's avatar
    Franck Dary committed
    }
    
    
    float Trainer::epoch(bool printAdvancement)
    {
    
      return processDataset(dataLoader, true, printAdvancement, trainDataset->size().value());
    
    }
    
    float Trainer::evalOnDev(bool printAdvancement)
    {
    
      return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());