Skip to content
Snippets Groups Projects
Trainer.cpp 11 KiB
Newer Older
  • Learn to ignore specific revisions
  • Franck Dary's avatar
    Franck Dary committed
    #include "Trainer.hpp"
    #include "SubConfig.hpp"
    
    
    // Binds the trainer to the ReadingMachine it will drive and records the
    // batch size later handed to the data loader options.
    Trainer::Trainer(ReadingMachine & machine, int batchSize)
      : machine(machine),
        batchSize(batchSize)
    {
    }
    
    
    Franck Dary's avatar
    Franck Dary committed
    void Trainer::makeDataLoader(std::filesystem::path dir)
    
    Franck Dary's avatar
    Franck Dary committed
    {
    
      trainDataset.reset(new Dataset(dir));
      dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
    
    Franck Dary's avatar
    Franck Dary committed
    // Creates the dev Dataset from `dir` and builds the data loader that
    // iterates over it in batches of `batchSize`.
    void Trainer::makeDevDataLoader(std::filesystem::path dir)
    {
      devDataset.reset(new Dataset(dir));

      // Same options as the training loader: workers(0) and max_jobs(0).
      auto loaderOptions = torch::data::DataLoaderOptions(batchSize);
      loaderOptions.workers(0);
      loaderOptions.max_jobs(0);

      devDataLoader = torch::data::make_data_loader(*devDataset, loaderOptions);
    }
    
    // Prepares the machine for example extraction and delegates to extractExamples.
    //
    // A SubConfig is constructed from `goldConfig` and goldConfig.getNbLines(); the
    // machine is switched out of train mode and its dicts are put in the Open state
    // before extraction starts. `debug`, `dir`, `epoch` and `dynamicOracle` are
    // forwarded unchanged to extractExamples.
    void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle)
    {
      SubConfig config(goldConfig, goldConfig.getNbLines());

      machine.trainMode(false);
      machine.setDictsState(Dict::State::Open);

      extractExamples(config, debug, dir, epoch, dynamicOracle);
    }
    
    Franck Dary's avatar
    Franck Dary committed
    // Walks `config` through the machine's strategy and records, for each step, the
    // extracted context together with the index of the gold transition. Examples are
    // grouped by the config's current state and written under `dir` through
    // Examples::saveIfNeeded.
    //
    // The function returns immediately when the marker file
    // "extracted.{epoch}.{dynamicOracle}" already exists in `dir`, and creates that
    // marker once extraction is complete.
    //
    // When `dynamicOracle` is set and the state is neither "tokenizer" nor
    // "segmenter", the transition that is applied is the appliable one with the best
    // network score; otherwise the gold transition is applied. The class that is
    // recorded is the gold index in both cases.
    //
    // Errors are reported through util::myThrow: context extraction failure, no
    // appliable transition, reaching safetyNbExamplesMax, or failure to create the
    // marker file.
    //
    // NOTE(review): this listing appears to be missing lines (the opening brace of the
    // body, the `try {` before extractContext, the declarations of `transition` and
    // `totalNbExamples`, and the closing brace of the `else`) — confirm against the
    // original file.
    void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle)
    
      // Passed to saveIfNeeded as the per-state threshold inside the loop.
      int maxNbExamplesPerFile = 50000;
      std::map<std::string, Examples> examplesPerState;
    
      config.addPredicted(machine.getPredicted());
    
      config.setStrategy(machine.getStrategyDefinition());
      config.setState(config.getStrategy().getInitialState());
      machine.getClassifier()->setState(config.getState());
    
    Franck Dary's avatar
    Franck Dary committed
      // Marker file whose presence means this (epoch, dynamicOracle) pair was already extracted.
      auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}.{}", epoch, dynamicOracle);
    
    Franck Dary's avatar
    Franck Dary committed
      if (std::filesystem::exists(currentEpochAllExtractedFile))
    
        return;
    
      fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");
    
    
    Franck Dary's avatar
    Franck Dary committed
      // One iteration per applied transition; left when the strategy yields endMovement.
      while (true)
      {
    
    Franck Dary's avatar
    Franck Dary committed
        if (debug)
          config.printForDebug(stderr);
    
    
        if (machine.hasSplitWordTransitionSet())
          config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
    
    Franck Dary's avatar
    Franck Dary committed
        auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
        config.setAppliableTransitions(appliableTransitions);
    
        std::vector<std::vector<long>> context;
    
    
          context = machine.getClassifier()->getNN()->extractContext(config);
    
        } catch(std::exception & e)
        {
          util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
        }
    
    Franck Dary's avatar
    Franck Dary committed
    
    
    Franck Dary's avatar
    Franck Dary committed
        Transition * goldTransition = nullptr;
    
    
        goldTransition = machine.getTransitionSet().getBestAppliableTransition(config, appliableTransitions, dynamicOracle);
    
    Franck Dary's avatar
    Franck Dary committed
        // NOTE(review): choiceWithProbability(1.0) presumably always succeeds, which would make
        // this term a no-op — confirm whether the probability was meant to be configurable.
        if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
    
        {
          // Only the first extracted context (context[0]) is scored by the network.
          auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
          auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
      
          int chosenTransition = -1;
          // NOTE(review): numeric_limits<float>::min() is the smallest positive normalized float,
          // not the most negative value (that is lowest()). It has no effect here because the
          // first appliable candidate is accepted through the `chosenTransition == -1` test.
          float bestScore = std::numeric_limits<float>::min();
    
          for (unsigned int i = 0; i < prediction.size(0); i++)
          {
            float score = prediction[i].item<float>();
            if ((chosenTransition == -1 or score > bestScore) and machine.getTransitionSet().getTransition(i)->appliable(config))
            {
              chosenTransition = i;
              bestScore = score;
            }
          }
    
          // NOTE(review): if no transition was appliable, chosenTransition is still -1 here
          // and getTransition(-1) is called — verify how getTransition handles that index.
          transition = machine.getTransitionSet().getTransition(chosenTransition);
        }
        else
        {
    
    Franck Dary's avatar
    Franck Dary committed
          transition = goldTransition;
    
    Franck Dary's avatar
    Franck Dary committed
        if (!transition or !goldTransition)
    
        {
          config.printForDebug(stderr);
          util::myThrow("No transition appliable !");
        }
    
    
    Franck Dary's avatar
    Franck Dary committed
        // The stored class is the gold transition's index, even when another transition is applied.
        int goldIndex = machine.getTransitionSet().getTransitionIndex(goldTransition);
    
    Franck Dary's avatar
    Franck Dary committed
    
    
        totalNbExamples += context.size();
    
        if (totalNbExamples >= (int)safetyNbExamplesMax)
          util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));
    
        examplesPerState[config.getState()].addContext(context);
        examplesPerState[config.getState()].addClass(goldIndex);
    
    Franck Dary's avatar
    Franck Dary committed
        examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);
    
    Franck Dary's avatar
    Franck Dary committed
    
        transition->apply(config);
        config.addToHistory(transition->getName());
    
    
        // movement.first is the next state, movement.second the word index movement (see the uses below).
        auto movement = config.getStrategy().getMovement(config, transition->getName());
    
        if (debug)
          fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
    
    Franck Dary's avatar
    Franck Dary committed
        if (movement == Strategy::endMovement)
          break;
    
        config.setState(movement.first);
    
        machine.getClassifier()->setState(movement.first);
    
    Franck Dary's avatar
    Franck Dary committed
        config.moveWordIndexRelaxed(movement.second);
    
    Franck Dary's avatar
    Franck Dary committed
    
        if (config.needsUpdate())
          config.update();
      }
    
      // Threshold 0: write out whatever each state still holds (saveIfNeeded skips empty buckets).
      for (auto & it : examplesPerState)
    
    Franck Dary's avatar
    Franck Dary committed
        it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle);
    
    
      // Create the empty marker file that the existence check at the top looks for.
      std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
      if (!f)
        util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
      std::fclose(f);
    
    
      fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(totalNbExamples));
    
    Franck Dary's avatar
    Franck Dary committed
    }
    
    
    // Iterates once over `loader`, computing for each batch the cross-entropy loss
    // scaled by the classifier's loss multiplier, and returns totalLoss / nbExamples.
    //
    // `train` is used to construct the torch::AutoGradMode guard and is passed to
    // machine.trainMode(); the dicts are put in the Closed state for the whole pass.
    // When `printAdvancement` is set, a status line (progress %, loss since the last
    // print, examples/s) is written to stderr each time at least printInterval
    // examples were processed since the previous print.
    //
    // NOTE(review): this listing appears to be missing lines. The deeper indentation
    // of the zero_grad()/step() calls and of the two status prints suggests enclosing
    // conditionals that are not visible here, and the closing braces of the
    // `if (printAdvancement)` block and of the batch loop are absent — confirm
    // against the original file.
    float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples)
    
    Franck Dary's avatar
    Franck Dary committed
    {
    
    Franck Dary's avatar
    Franck Dary committed
      constexpr int printInterval = 50;
    
      // Examples seen since the last status print (reset after each print).
      int nbExamplesProcessed = 0;
    
      // Examples seen since the start of the pass; used for the progress percentage.
      int totalNbExamplesProcessed = 0;
    
    Franck Dary's avatar
    Franck Dary committed
      float totalLoss = 0.0;
      // Loss accumulated since the last status print.
      float lossSoFar = 0.0;
    
    
      torch::AutoGradMode useGrad(train);
    
    Franck Dary's avatar
    Franck Dary committed
      machine.trainMode(train);
    
      machine.setDictsState(Dict::State::Closed);
    
      auto lossFct = torch::nn::CrossEntropyLoss();
    
    
      auto pastTime = std::chrono::high_resolution_clock::now();
    
    
      for (auto & batch : *loader)
    
    Franck Dary's avatar
    Franck Dary committed
      {
    
          machine.getClassifier()->getOptimizer().zero_grad();
    
    Franck Dary's avatar
    Franck Dary committed
    
    
        // Each batch is a tuple: input data, gold labels, and the state passed to the classifier.
        auto data = std::get<0>(batch);
        auto labels = std::get<1>(batch);
        auto state = std::get<2>(batch);
    
        machine.getClassifier()->setState(state);
    
    Franck Dary's avatar
    Franck Dary committed
    
        auto prediction = machine.getClassifier()->getNN()(data);
    
    Franck Dary's avatar
    Franck Dary committed
        // A 1-D prediction gets a leading dimension of size 1 added before the loss is computed.
        if (prediction.dim() == 1)
          prediction = prediction.unsqueeze(0);
    
    Franck Dary's avatar
    Franck Dary committed
    
    
    Franck Dary's avatar
    Franck Dary committed
        // Flatten labels to 1-D: a 0-dim label becomes a 1-element tensor.
        labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));
    
    
    Franck Dary's avatar
    Franck Dary committed
        auto loss = machine.getClassifier()->getLossMultiplier()*lossFct(prediction, labels);
    
    Franck Dary's avatar
    Franck Dary committed
        try
        {
          totalLoss += loss.item<float>();
          lossSoFar += loss.item<float>();
        } catch(std::exception & e) {util::myThrow(e.what());}
    
    Franck Dary's avatar
    Franck Dary committed
    
    
          machine.getClassifier()->getOptimizer().step();
    
    Franck Dary's avatar
    Franck Dary committed
    
    
        totalNbExamplesProcessed += torch::numel(labels);
    
    
    Franck Dary's avatar
    Franck Dary committed
        if (printAdvancement)
    
    Franck Dary's avatar
    Franck Dary committed
        {
    
          nbExamplesProcessed += torch::numel(labels);
    
    Franck Dary's avatar
    Franck Dary committed
    
    
          if (nbExamplesProcessed >= printInterval)
    
    Franck Dary's avatar
    Franck Dary committed
          {
    
            auto actualTime = std::chrono::high_resolution_clock::now();
            // Elapsed time is measured in milliseconds, then converted to seconds.
            double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
            pastTime = actualTime;
    
            auto speed = (int)(nbExamplesProcessed/seconds);
            auto progression = 100.0*totalNbExamplesProcessed / nbExamples;
            auto statusStr = fmt::format("{:6.2f}% loss={:<7.3f} speed={:<6}ex/s", progression, lossSoFar, speed);
    
              fmt::print(stderr, "\r{:80}\rtraining : {}", "", statusStr);
    
              fmt::print(stderr, "\r{:80}\reval on dev : {}", "", statusStr);
    
    Franck Dary's avatar
    Franck Dary committed
            lossSoFar = 0;
    
            nbExamplesProcessed = 0;
    
    Franck Dary's avatar
    Franck Dary committed
          }
    
      // NOTE(review): when nbExamples is 0 this float division yields inf/NaN — confirm callers never pass an empty dataset.
      return totalLoss / nbExamples;
    
    Franck Dary's avatar
    Franck Dary committed
    }
    
    
    // Runs one training pass over the training data loader and returns the
    // value computed by processDataset for it.
    float Trainer::epoch(bool printAdvancement)
    {
      const auto nbTrainExamples = trainDataset->size().value();
      constexpr bool trainingPass = true;

      return processDataset(dataLoader, trainingPass, printAdvancement, nbTrainExamples);
    }
    
    // Runs one evaluation pass (train = false) over the dev data loader and
    // returns the value computed by processDataset for it.
    float Trainer::evalOnDev(bool printAdvancement)
    {
      return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
    }
    
    Franck Dary's avatar
    Franck Dary committed
    void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int epoch, bool dynamicOracle)
    
    {
      if (currentExampleIndex-lastSavedIndex < (int)threshold)
        return;
      if (contexts.empty())
        return;
    
      auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
    
    Franck Dary's avatar
    Franck Dary committed
      auto filename = fmt::format("{}_{}-{}.{}.{}.tensor", state, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle);
    
      torch::save(tensorToSave, dir/filename);
      lastSavedIndex = currentExampleIndex;
      contexts.clear();
      classes.clear();
    }
    
    // Appends one kLong tensor per row of `context` to the buffered contexts and
    // advances currentExampleIndex by the number of rows.
    void Trainer::Examples::addContext(std::vector<std::vector<long>> & context)
    {
      for (auto & row : context)
      {
        const auto rowLength = static_cast<long>(row.size());

        // from_blob only wraps the row's memory; clone() gives the stored tensor its own copy.
        auto borrowed = torch::from_blob(row.data(), {rowLength}, torch::TensorOptions(torch::kLong));
        contexts.emplace_back(borrowed.clone());
      }

      currentExampleIndex += context.size();
    }
    
    // Pads the buffered classes with a one-element kLong tensor holding `goldIndex`
    // until there are as many classes as buffered contexts.
    void Trainer::Examples::addClass(int goldIndex)
    {
      auto goldLabel = torch::zeros(1, torch::TensorOptions(torch::kLong));
      goldLabel[0] = goldIndex;

      // Every context that has no class yet receives the same gold label tensor.
      const auto wanted = contexts.size();
      for (auto count = classes.size(); count < wanted; ++count)
        classes.emplace_back(goldLabel);
    }
    
    
    Franck Dary's avatar
    Franck Dary committed
    // Replays the gold transitions over `config` with gradient mode disabled,
    // calling extractContext at every step and discarding its result.
    // Presumably the call is made for its side effect on the dicts (as the function
    // name suggests) — TODO confirm.
    //
    // Errors are reported through util::myThrow when context extraction fails or
    // when no transition is appliable. With `debug` set, the config and each
    // (transition, new state, movement) triple are printed to stderr.
    //
    // NOTE(review): this listing appears to be missing lines (the loop header that
    // the `break` and the final closing brace belong to, and the `try {` before
    // extractContext) — confirm against the original file.
    void Trainer::fillDicts(SubConfig & config, bool debug)
    
    {
      torch::AutoGradMode useGrad(false);
    
      config.addPredicted(machine.getPredicted());
    
      config.setStrategy(machine.getStrategyDefinition());
      config.setState(config.getStrategy().getInitialState());
      machine.getClassifier()->setState(config.getState());
    
    Franck Dary's avatar
    Franck Dary committed
        if (debug)
          config.printForDebug(stderr);
    
    
        if (machine.hasSplitWordTransitionSet())
          config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
    
    Franck Dary's avatar
    Franck Dary committed
        auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
        config.setAppliableTransitions(appliableTransitions);
    
          // The returned context is intentionally not stored.
          machine.getClassifier()->getNN()->extractContext(config);
    
        } catch(std::exception & e)
        {
          util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
        }
    
        Transition * goldTransition = nullptr;
    
        // No dynamicOracle argument is passed here, unlike the call in extractExamples.
        goldTransition = machine.getTransitionSet().getBestAppliableTransition(config, appliableTransitions);
    
          
        if (!goldTransition)
        {
          config.printForDebug(stderr);
          util::myThrow("No transition appliable !");
        }
    
        goldTransition->apply(config);
        config.addToHistory(goldTransition->getName());
    
    
        // movement.first is the next state, movement.second the word index movement (see the uses below).
        auto movement = config.getStrategy().getMovement(config, goldTransition->getName());
    
    Franck Dary's avatar
    Franck Dary committed
        if (debug)
          fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", goldTransition->getName(), movement.first, movement.second);
    
        if (movement == Strategy::endMovement)
          break;
    
        config.setState(movement.first);
        machine.getClassifier()->setState(movement.first);
        config.moveWordIndexRelaxed(movement.second);
    
        if (config.needsUpdate())
          config.update();
      }
    }
    
    
    Franck Dary's avatar
    Franck Dary committed
    // Converts the textual name of a train action into its TrainAction value.
    // Matching is exact and case-sensitive; an unrecognised name is reported
    // through util::myThrow.
    Trainer::TrainAction Trainer::str2TrainAction(const std::string & s)
    {
      struct NamedAction
      {
        const char * name;
        TrainAction action;
      };

      static const NamedAction knownActions[] = {
        {"ExtractGold", TrainAction::ExtractGold},
        {"ExtractDynamic", TrainAction::ExtractDynamic},
        {"DeleteExamples", TrainAction::DeleteExamples},
        {"ResetOptimizer", TrainAction::ResetOptimizer},
        {"ResetParameters", TrainAction::ResetParameters},
        {"Save", TrainAction::Save},
      };

      for (const auto & candidate : knownActions)
        if (s == candidate.name)
          return candidate.action;

      util::myThrow(fmt::format("unknown TrainAction '{}'", s));

      // Fallback return value, reached only if util::myThrow returns.
      return TrainAction::ExtractGold;
    }