Trainer.cpp
    #include "Trainer.hpp"
    #include "SubConfig.hpp"
    #include <execution>
    
    Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize)
    {
    }
    
    void Trainer::makeDataLoader(std::filesystem::path dir)
    {
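      // Sequential loader over the example tensors previously written to 'dir';
      // workers(0) and max_jobs(0) keep loading single-threaded.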
      trainDataset.reset(new Dataset(dir));
      dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
    }
    
    void Trainer::makeDevDataLoader(std::filesystem::path dir)
    {
      devDataset.reset(new Dataset(dir));
      devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
    }
    
    void Trainer::createDataset(std::vector<BaseConfig> & goldConfigs, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold, bool memcheck)
    {
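      // Wrap each gold configuration in a SubConfig, switch the machine to eval
      // mode, extract training examples, then persist the dictionaries that were
      // grown during extraction.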
      std::vector<SubConfig> configs;
      for (auto & goldConfig : goldConfigs)
        configs.emplace_back(goldConfig, goldConfig.getNbLines());
    
      machine.trainMode(false);
    
      extractExamples(configs, debug, dir, epoch, dynamicOracle, explorationThreshold, memcheck);
    
      machine.saveDicts();
    }
    
    void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold, bool memcheck)
    {
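      // Example-extraction pass: the network is only run forward (for the
      // dynamic oracle), so gradients are disabled for the whole scope.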
      torch::AutoGradMode useGrad(false);
    
      int maxNbExamplesPerFile = 50000;
      std::unordered_map<std::string, Examples> examplesPerState;
      std::mutex examplesMutex;
    
      std::filesystem::create_directories(dir);
    
      auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}.{}", epoch, dynamicOracle);
    
      if (std::filesystem::exists(currentEpochAllExtractedFile))
        return;
    
      fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");
    
      std::atomic<int> totalNbExamples = 0;
    
      if (memcheck)
        fmt::print(stderr, "[{}] Memory : {}\n", util::getTime(), util::getMemUsage());
    
      NeuralNetworkImpl::setDevice(torch::kCPU);
      machine.to(NeuralNetworkImpl::getDevice());
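      // Extraction is forced onto the CPU and processed strictly sequentially
      // (std::execution::seq); the machine is moved back to the preferred
      // device once extraction is done.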
      std::for_each(std::execution::seq, configs.begin(), configs.end(),
        [this, maxNbExamplesPerFile, &examplesPerState, &totalNbExamples, debug, memcheck, dynamicOracle, explorationThreshold, dir, epoch, &examplesMutex](SubConfig & config)
        {
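          // Process one configuration: initialize its strategy and state, then
          // run the transition loop until the strategy signals the end movement.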
          config.addPredicted(machine.getPredicted());
          config.setStrategy(machine.getStrategyDefinition());
          config.setState(config.getStrategy().getInitialState());
    
          while (true)
          {
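            // One iteration = one transition: compute the appliable transitions,
            // extract the classifier context, choose an action (gold or
            // explored), record the example, apply the action and move on.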
            if (debug)
              config.printForDebug(stderr);
    
            if (machine.hasSplitWordTransitionSet())
              config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
    
            auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config);
            config.setAppliableTransitions(appliableTransitions);
    
            torch::Tensor context;
    
            try
            {
              context = machine.getClassifier(config.getState())->getNN()->extractContext(config);
            } catch(std::exception & e)
            {
              util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
            }
    
            Transition * transition = nullptr;
    
            // All best appliable gold transitions are requested (not just one):
            // every one of them is added to goldIndexes as a valid label below.
            auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true);

            // For the parser state, break ties between equally good gold
            // transitions at random.
            Transition * goldTransition = goldTransitions[0];
            if (config.getState() == "parser")
              goldTransition = goldTransitions[std::rand()%goldTransitions.size()];
    
            int nbClasses = machine.getTransitionSet(config.getState()).size();
    
            float bestScore = -std::numeric_limits<float>::max();
    
            float entropy = 0.0;
              
            // Note: choiceWithProbability(1.0) always succeeds, so exploration is
            // effectively unconditional whenever the dynamic oracle is active and
            // the state is neither the tokenizer nor the segmenter.
            if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
            {
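              // Dynamic oracle: follow the model's own prediction (sampled among
              // near-best appliable actions) instead of the gold action.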
              auto & classifier = *machine.getClassifier(config.getState());
              auto prediction = classifier.isRegression()
                ? classifier.getNN()->forward(context, config.getState()).squeeze(0)
                : torch::softmax(classifier.getNN()->forward(context, config.getState()).squeeze(0), 0);
              entropy = NeuralNetworkImpl::entropy(prediction);
          
              std::vector<int> candidates;

              // First pass: find the best score among the appliable transitions.
              for (unsigned int i = 0; i < prediction.size(0); i++)
              {
                float score = prediction[i].item<float>();
                if (score > bestScore and appliableTransitions[i])
                  bestScore = score;
              }

              // Second pass: keep every appliable transition whose score is within
              // explorationThreshold of the best one.
              for (unsigned int i = 0; i < prediction.size(0); i++)
              {
                float score = prediction[i].item<float>();
                if (appliableTransitions[i] and bestScore - score <= explorationThreshold)
                  candidates.emplace_back(i);
              }

              // Sample uniformly among the near-best candidates; if the sampled
              // transition is also a gold one, treat it as the gold label.
              transition = machine.getTransitionSet(config.getState()).getTransition(candidates[std::rand()%candidates.size()]);

              for (auto & trans : goldTransitions)
                if (trans == transition)
                  goldTransition = trans;
            }
            else
            {
              transition = goldTransition;
            }
    
            if (!transition or !goldTransition)
            {
              config.printForDebug(stderr);
              util::myThrow("No transition appliable !");
            }
    
            std::vector<long> goldIndexes;
            bool exampleIsBanned = machine.getClassifier(config.getState())->exampleIsBanned(config);
    
            // Regression transitions have the form "WRITESCORE <object>.<index> <column>";
            // the gold label is that column's numeric value, scaled to a long.
            if (machine.getClassifier(config.getState())->isRegression())
            {
              entropy = 0.0;
              auto errMessage = fmt::format("Invalid regression transition '{}'", transition->getName());
              auto splited = util::split(transition->getName(), ' ');
              if (splited.size() != 3 or splited[0] != "WRITESCORE")
                util::myThrow(errMessage);
              auto col = splited[2];
              splited = util::split(splited[1], '.');
              if (splited.size() != 2)
                util::myThrow(errMessage);
              auto object = Config::str2object(splited[0]);
              int index = std::stoi(splited[1]);
    
              float regressionTarget = std::stof(config.getConst(col, config.getRelativeWordIndex(object, index), 0));
              goldIndexes.emplace_back(util::float2long(regressionTarget));
            }
            else
            {
              // Classification: every gold transition's index is a valid label.
              for (auto & t : goldTransitions)
                goldIndexes.emplace_back(machine.getTransitionSet(config.getState()).getTransitionIndex(t));
            }
    
            if (!exampleIsBanned)
            {
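              // Record the example; the mutex guards the shared buffers (redundant
              // under std::execution::seq, but safe if a parallel policy is used).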
              totalNbExamples += 1;
              if (totalNbExamples >= (int)safetyNbExamplesMax)
                util::error(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));
    
              examplesMutex.lock();
              examplesPerState[config.getState()].addContext(context);
              examplesPerState[config.getState()].addClass(machine.getClassifier(config.getState())->getLossFunction(), nbClasses, goldIndexes);
              examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);
              examplesMutex.unlock();
            }
    
            config.setChosenActionScore(bestScore);
    
            transition->apply(config, entropy);
            config.addToHistory(transition->getName());
    
            auto movement = config.getStrategy().getMovement(config, transition->getName());
            if (debug)
              fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
            if (movement == Strategy::endMovement)
              break;
    
            config.setState(movement.first);
            config.moveWordIndexRelaxed(movement.second);
    
            if (config.needsUpdate())
              config.update();
    
          } // End while true
    
          if (memcheck)
            fmt::print(stderr, "[{}] Memory : {}\n", util::getTime(), util::getMemUsage());
      }); // End for on configs
    
      // Flush whatever is still buffered for each state (threshold 0 forces the save).
      for (auto & it : examplesPerState)
        it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle);
    
      NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice());
      machine.to(NeuralNetworkImpl::getDevice());
    
      // Touch an empty marker file so later runs detect that extraction for this
      // epoch has already been done (see the early-return check above).
      std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
      if (!f)
        util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
      std::fclose(f);
    
      if (memcheck)
        fmt::print(stderr, "[{}] Memory : {}\n", util::getTime(), util::getMemUsage());
      fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(totalNbExamples));
    }
    
    float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples)
    {
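      // Shared train/eval loop: iterate over batches, route each batch to the
      // classifier of the state it was extracted from, and (when training)
      // backpropagate and step that classifier's optimizer.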
      constexpr int printInterval = 50;
      int nbExamplesProcessed = 0;
      int totalNbExamplesProcessed = 0;
      float totalLoss = 0.0;
      float lossSoFar = 0.0;
    
      torch::AutoGradMode useGrad(train);
      machine.trainMode(train);
    
      auto pastTime = std::chrono::high_resolution_clock::now();
    
      for (auto & batch : *loader)
      {
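        // Each batch carries (contexts, labels, state): the state identifies
        // which classifier should consume these examples.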
        auto data = std::get<0>(batch);
        auto labels = std::get<1>(batch);
        auto state = std::get<2>(batch);
    
        if (train)
          machine.getClassifier(state)->getOptimizer().zero_grad();
    
        auto prediction = machine.getClassifier(state)->getNN()->forward(data, state);
        if (prediction.dim() == 1)
          prediction = prediction.unsqueeze(0);
    
        if (machine.getClassifier(state)->isRegression())
        {
          labels = labels.to(torch::kFloat);
          labels /= util::float2longScale;
        }
    
        auto lossParameter = machine.getClassifier(state)->getNN()->getLossParameter(state);
    
        // The raw loss is scaled by exp(-s) and s is added back, where s
        // ('lossParameter') is a learned log-variance term; this resembles the
        // homoscedastic uncertainty weighting of Kendall et al. (2018).
        auto loss = machine.getClassifier(state)->getLossMultiplier(state)
                    * machine.getClassifier(state)->getLossFunction()(prediction, labels)
                    * (1.0/torch::exp(lossParameter)) + lossParameter;
        float lossAsFloat = 0.0;
        try
        {
          lossAsFloat = loss.item<float>();
        } catch(std::exception & e) {util::myThrow(e.what());}
    
        totalLoss += lossAsFloat;
        lossSoFar += lossAsFloat;
    
        if (train)
        {
          loss.backward();
          machine.getClassifier(state)->getOptimizer().step();
        }
    
        totalNbExamplesProcessed += labels.size(0);
    
        if (printAdvancement)
        {
          nbExamplesProcessed += labels.size(0);
    
          if (nbExamplesProcessed >= printInterval)
          {
            auto actualTime = std::chrono::high_resolution_clock::now();
            double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
            pastTime = actualTime;
            auto speed = (int)(nbExamplesProcessed/seconds);
            auto progression = 100.0*totalNbExamplesProcessed / nbExamples;
            auto statusStr = fmt::format(lossSoFar/nbExamplesProcessed < 10.0 ? "{:6.2f}% loss={:<7.3f} speed={:<6}ex/s": "{:6.2f}% loss={:<7.0f} speed={:<6}ex/s", progression, lossSoFar / nbExamplesProcessed, speed);
            if (train)
              fmt::print(stderr, "\r{:80}\rtraining : {}", "", statusStr);
            else
              fmt::print(stderr, "\r{:80}\reval on dev : {}", "", statusStr);
            lossSoFar = 0;
            nbExamplesProcessed = 0;
          }
        }
      }
    
      return totalLoss / nbExamples;
    }
    
    float Trainer::epoch(bool printAdvancement)
    {
      return processDataset(dataLoader, true, printAdvancement, trainDataset->size().value());
    }
    
    float Trainer::evalOnDev(bool printAdvancement)
    {
      return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
    }
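
    // A minimal caller-side sketch (assumed driver code, not part of this file):
    //
    //   Trainer trainer(machine, /*batchSize=*/64);
    //   trainer.createDataset(goldConfigs, /*debug=*/false, examplesDir, epoch,
    //                         /*dynamicOracle=*/false, /*explorationThreshold=*/0.1,
    //                         /*memcheck=*/false);
    //   trainer.makeDataLoader(examplesDir);
    //   float trainLoss = trainer.epoch(/*printAdvancement=*/true);
    //   float devLoss = trainer.evalOnDev(true);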
    
    void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int epoch, bool dynamicOracle)
    {
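      // Persist buffered examples once at least 'threshold' new ones have
      // accumulated: each saved row is a context concatenated with its gold
      // vector, and the filename encodes state, class count, index range,
      // epoch and oracle mode.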
      if (currentExampleIndex-lastSavedIndex < (int)threshold)
        return;
      if (contexts.empty())
        return;
    
      int nbClasses = classes[0].size(0);
    
      auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
      auto filename = fmt::format("{}-{}_{}-{}.{}.{}.tensor", state, nbClasses, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle);
      torch::save(tensorToSave, dir/filename);
      lastSavedIndex = currentExampleIndex;
      contexts.clear();
      classes.clear();
    }
    
    void Trainer::Examples::addContext(torch::Tensor & context)
    {
      contexts.emplace_back(context);
    
      currentExampleIndex += 1;
    }
    
    void Trainer::Examples::addClass(const LossFunction & lossFct, int nbClasses, const std::vector<long> & goldIndexes)
    {
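      // Pad 'classes' until there is one gold vector per buffered context.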
      auto gold = lossFct.getGoldFromClassesIndexes(nbClasses, goldIndexes);
    
      while (classes.size() < contexts.size())
        classes.emplace_back(gold);
    }
    
    Trainer::TrainAction Trainer::str2TrainAction(const std::string & s)
    {
      if (s == "ExtractGold")
        return TrainAction::ExtractGold;
      if (s == "ExtractDynamic")
        return TrainAction::ExtractDynamic;
      if (s == "DeleteExamples")
        return TrainAction::DeleteExamples;
      if (s == "ResetOptimizer")
        return TrainAction::ResetOptimizer;
      if (s == "ResetParameters")
        return TrainAction::ResetParameters;
      if (s == "Save")
        return TrainAction::Save;
    
      util::myThrow(fmt::format("unknown TrainAction '{}'", s));
    
      return TrainAction::ExtractGold; // Never reached: myThrow throws; this only silences the compiler.
    }
    
    void Trainer::extractActionSequence(BaseConfig & config)
    {
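      // Replay the gold action sequence on 'config' and print, every third
      // sentence, the raw input span (characters or FORM features) followed by
      // the corresponding transition indexes.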
      config.addPredicted(machine.getPredicted());
      config.setStrategy(machine.getStrategyDefinition());
      config.setState(config.getStrategy().getInitialState());
    
      int curSeq = 0;
      int curSeqStartIndex = -1;
      int curInputIndex = 0;
      int curInputSeqSize = 0;
      int curOutputSeqSize = 0;
      int maxInputSeqSize = 0;
      int maxOutputSeqSize = 0;
      bool newSent = true;
      std::vector<std::string> transitionsIndexes;
    
      while (true)
      {
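        // The current reading position is a character index in character-level
        // mode, a word index otherwise.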
        if (config.hasCharacter(0))
          curInputIndex = config.getCharacterIndex();
        else
          curInputIndex = config.getWordIndex();
    
        if (curSeqStartIndex == -1 or newSent)
        {
          newSent = false;
          curSeqStartIndex = curInputIndex;
        }
    
        if (machine.hasSplitWordTransitionSet())
          config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
    
        auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config);
        config.setAppliableTransitions(appliableTransitions);
    
        auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true);
    
        Transition * transition = goldTransitions[0];
    
        if (machine.getClassifier(config.getState())->isRegression())
          util::myThrow("Regressions are not supported in extract action sequence mode");
    
        transitionsIndexes.push_back(fmt::format("{}", machine.getTransitionSet(config.getState()).getTransitionIndex(transition)));
        maxOutputSeqSize = std::max(maxOutputSeqSize, curOutputSeqSize++);
        curInputSeqSize = curInputIndex - curSeqStartIndex;
        maxInputSeqSize = std::max(maxInputSeqSize, curInputSeqSize);
        // Every third end of sentence, flush the accumulated input span and its
        // transition indexes to stdout, then start a new segment.
        if (util::split(transition->getName(), ' ')[0] == "EOS")
          if (++curSeq % 3 == 0)
          {
            newSent = true;
            std::string seqText = "";
            for (int i = curSeqStartIndex; i <= curInputIndex; i++)
              seqText += config.hasCharacter(0) ? fmt::format("{}", config.getLetter(i)) : fmt::format("{} ", std::string(config.getAsFeature("FORM", i)));
            fmt::print(stdout, "{}\n{}\n\n", config.hasCharacter(0) ? seqText : util::strip(seqText), util::join(" ", transitionsIndexes));
            curOutputSeqSize = 0;
            curInputSeqSize = 0;
            transitionsIndexes.clear();
          }
    
        transition->apply(config);
        config.addToHistory(transition->getName());
    
        auto movement = config.getStrategy().getMovement(config, transition->getName());
        if (movement == Strategy::endMovement)
          break;
    
        config.setState(movement.first);
        config.moveWordIndexRelaxed(movement.second);
      }
    
      // Flush any remaining partial segment.
      if (curSeqStartIndex != curInputIndex)
      {
        std::string seqText = "";
        for (int i = curSeqStartIndex; i <= curInputIndex; i++)
          seqText += config.hasCharacter(0) ? fmt::format("{}", config.getLetter(i)) : fmt::format("{} ", std::string(config.getAsFeature("FORM", i)));
        fmt::print(stdout, "{}\n{}\n\n", config.hasCharacter(0) ? seqText : util::strip(seqText), util::join(" ", transitionsIndexes));
      }
    
      fmt::print(stderr, "Longest output sequence : {}\n", maxOutputSeqSize);
      fmt::print(stderr, "Longest input sequence : {}\n", maxInputSeqSize);
    }