Skip to content
Snippets Groups Projects
Select Git revision
  • c73a43d3ccdd0e3b9d04a18909928d125e08251d
  • master default protected
  • loss
  • producer
4 results

install.md

Blame
  • Trainer.cpp 18.28 KiB
    #include "Trainer.hpp"
    #include "util.hpp"
    
    // Build a trainer that only has a training corpus. The dev database and
    // dev configuration are null, so computeScoreOnDev() returns immediately.
    // All remaining initialisation is shared with the five-argument constructor.
    Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config)
    : Trainer(tm, bd, config, nullptr, nullptr)
    {
    }
    
    // Build a trainer with both a training corpus and a development corpus.
    // devBD / devConfig may be null, in which case no dev evaluation is done.
    Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig)
    : tm(tm),
      trainBD(bd),
      trainConfig(config),
      devBD(devBD),
      devConfig(devConfig)
    {
      // Counters start from zero.
      nbActions = 0;
      nbSteps = 0;

      // The speed is re-measured every nbActionsCutoff actions, starting now.
      currentSpeed = 0.0;
      nbActionsCutoff = ProgramParameters::batchSize * 2;
      pastTime = std::chrono::high_resolution_clock::now();
    }
    
    // When random debugging is enabled, draw whether the upcoming step runs in
    // debug mode. If it does, print a separator (and optionally the current
    // time) on stderr so the debug output starts on a clean line.
    void Trainer::setDebugValue()
    {
      if (ProgramParameters::randomDebug)
      {
        ProgramParameters::debug = choiceWithProbability(ProgramParameters::randomDebugProbability);

        if (ProgramParameters::debug)
        {
          // Erase the interactive progress line before writing anything else.
          if (ProgramParameters::interactive)
            fprintf(stderr, "                            \r");

          fprintf(stderr, "\n");

          if (ProgramParameters::printTime)
            fprintf(stderr, "[%s] :\n", getTime().c_str());
        }
      }
    }
    
    void Trainer::computeScoreOnDev()
    {
      if (!devConfig)
        return;
    
      tm.reset();
      devConfig->reset();
    
      if (ProgramParameters::debug)
        fprintf(stderr, "Computing score on dev set\n");
    
      int nbActionsInSequence = 0;
      float entropyAccumulator = 0.0;
      bool justFlipped = false;
      int nbActions = 0;
      int nbActionsCutoff = 2*ProgramParameters::batchSize;
      currentSpeed = 0.0;
      auto pastTime = std::chrono::high_resolution_clock::now();
      std::vector<float> entropies;
    
      while (!devConfig->isFinal())
      {
        setDebugValue();
        devConfig->setCurrentStateName(tm.getCurrentClassifier()->name);
        Dict::currentClassifierName = tm.getCurrentClassifier()->name;
        tm.getCurrentClassifier()->initClassifier(*devConfig);
    
        if(!tm.getCurrentClassifier()->needsTrain())
        {
          int neededActionIndex = tm.getCurrentClassifier()->getOracleActionIndex(*devConfig);
          std::string neededActionName = tm.getCurrentClassifier()->getActionName(neededActionIndex);
          Action * action = tm.getCurrentClassifier()->getAction(neededActionName);
          TransitionMachine::Transition * transition = tm.getTransition(neededActionName);
          action->setInfos(tm.getCurrentClassifier()->name);
    
          action->apply(*devConfig);
          tm.takeTransition(transition);
        }
        else
        {
          // Print current iter advancement in percentage
          if (ProgramParameters::interactive) 
          {
            int totalSize = ProgramParameters::devTapeSize;
            int steps = devConfig->getHead();
            if (steps && (steps % nbActionsCutoff == 0 || totalSize-steps < nbActionsCutoff))
            {
              fprintf(stderr, "                                                      \r");
              fprintf(stderr, "Eval on dev : %.2f%% speed : %s actions/s\r", 100.0*steps/totalSize, int2humanStr((int)currentSpeed).c_str());
            }
          }
    
          auto weightedActions = tm.getCurrentClassifier()->weightActions(*devConfig);
          std::string pAction = "";
          std::string oAction = "";
    
          for (auto & it : weightedActions)
            if (it.first)
            {
              if (pAction.empty())
                pAction = it.second.second;
    
               if (tm.getCurrentClassifier()->getActionCost(*devConfig, it.second.second) == 0)
               {
                 oAction = it.second.second;
                 break;
               }
            }
    
          if (ProgramParameters::devLoss)
          {
            float loss = tm.getCurrentClassifier()->getLoss(*devConfig, tm.getCurrentClassifier()->getActionIndex(oAction));
            TI.addDevLoss(tm.getCurrentClassifier()->name, loss);
          }
    
          std::string actionName;
    
          if (ProgramParameters::devEvalOnGold)
            actionName = oAction;
          else
            actionName = pAction;
    
          Action * action = tm.getCurrentClassifier()->getAction(actionName);
    
          if (ProgramParameters::debug)
          {
            fprintf(stderr, "Speed : %s actions/s\n", int2humanStr((int)currentSpeed).c_str());
            devConfig->printForDebug(stderr);
            fprintf(stderr, "pAction=<%s> action=<%s>\n", pAction.c_str(), actionName.c_str());
          }
    
          TransitionMachine::Transition * transition = tm.getTransition(actionName);
          action->setInfos(tm.getCurrentClassifier()->name);
          devConfig->addToActionsHistory(tm.getCurrentClassifier()->name, actionName, tm.getCurrentClassifier()->getActionCost(*devConfig, actionName));
    
          action->apply(*devConfig);
          tm.takeTransition(transition);
    
          nbActions++;
    
          if (nbActions >= nbActionsCutoff)
          {
            auto actualTime = std::chrono::high_resolution_clock::now();
    
            double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
            currentSpeed = nbActions / seconds;
    
            pastTime = actualTime;
    
            nbActions = 0;
          }
    
          float entropy = Classifier::computeEntropy(weightedActions);
          devConfig->addToEntropyHistory(entropy);
    
          if (ProgramParameters::printEntropy)
          {
            nbActionsInSequence++;
    
            entropyAccumulator += entropy;
    
            if (devConfig->getHead() >= 1 && devConfig->getTape(ProgramParameters::sequenceDelimiterTape)[devConfig->getHead()-1] != ProgramParameters::sequenceDelimiter)
            justFlipped = false;
    
            if ((devConfig->getHead() >= 1 && devConfig->getTape(ProgramParameters::sequenceDelimiterTape)[devConfig->getHead()-1] == ProgramParameters::sequenceDelimiter && !justFlipped))
            {
              justFlipped = true;
              entropyAccumulator /= nbActionsInSequence;
              nbActionsInSequence = 0;
              entropies.emplace_back(entropyAccumulator);
              entropyAccumulator = 0.0;
            }
          }
    
        }
      }
    
      if (ProgramParameters::debug)
        fprintf(stderr, "Dev Config is final\n");
    
      TI.computeDevScores(*devConfig);
    
      if (ProgramParameters::debug)
        fprintf(stderr, "End of %s\n", __func__);
    }
    
    // Put the transition machine and the training configuration back in their
    // initial state and, if requested, shuffle the training configuration
    // using the sequence delimiter settings.
    void Trainer::resetAndShuffle()
    {
      tm.reset();
      trainConfig.reset();

      if (!ProgramParameters::shuffleExamples)
        return;

      trainConfig.shuffle(ProgramParameters::sequenceDelimiterTape, ProgramParameters::sequenceDelimiter);
    }
    
    void Trainer::doStepNoTrain()
    {
      int neededActionIndex = tm.getCurrentClassifier()->getOracleActionIndex(trainConfig);
      std::string neededActionName = tm.getCurrentClassifier()->getActionName(neededActionIndex);
      if (ProgramParameters::debug)
      {
        fprintf(stderr, "Speed : %s actions/s\n", int2humanStr((int)currentSpeed).c_str());
        trainConfig.printForDebug(stderr);
        fprintf(stderr, "action=<%s>\n", neededActionName.c_str());
      }
    
      Action * action = tm.getCurrentClassifier()->getAction(neededActionName);
      TransitionMachine::Transition * transition = tm.getTransition(neededActionName);
      action->setInfos(tm.getCurrentClassifier()->name);
      trainConfig.addToActionsHistory(tm.getCurrentClassifier()->name, action->name, tm.getCurrentClassifier()->getActionCost(trainConfig, action->name));
    
      action->apply(trainConfig);
      tm.takeTransition(transition);
    }
    
    // Perform one training step with the current (trainable) classifier:
    // choose the action to take, update the classifier where appropriate,
    // record and apply the action on trainConfig, follow the matching
    // transition and refresh the speed / entropy statistics.
    //
    // Two kinds of classifiers are handled:
    //  - regular classifiers are trained on a zero-cost action (oAction) and
    //    then follow either their own prediction (pAction) or oAction;
    //  - classifiers whose name starts with "Error_" are updated after a BACK
    //    action, depending on whether the redone action has a lower cost than
    //    the undone one.
    //
    // Throws EndOfIteration when no action is available and trainConfig has
    // reached the end of its tapes. Calls exit(1) when a regular classifier
    // has no zero-cost action, or when the current state name does not match
    // the "Error_<name>" pattern in the error branch.
    void Trainer::doStepTrain()
    {
      // Print the network topology once per classifier.
      if (!TI.isTopologyPrinted(tm.getCurrentClassifier()->name))
      {
        TI.setTopologyPrinted(tm.getCurrentClassifier()->name);
        tm.getCurrentClassifier()->printTopology(stderr);
      }

      // Print current iter advancement in percentage
      if (ProgramParameters::interactive)
      {
        int totalSize = ProgramParameters::iterationSize == -1 ? ProgramParameters::tapeSize : ProgramParameters::iterationSize;
        int steps = ProgramParameters::iterationSize == -1 ? trainConfig.getHead() : nbSteps;
        if (steps % nbActionsCutoff == 0 || totalSize-steps < nbActionsCutoff)
        {
          fprintf(stderr, "                                                      \r");
          fprintf(stderr, "Current Iteration : %.2f%% speed : %s actions/s\r", 100.0*steps/totalSize, int2humanStr((int)currentSpeed).c_str());
        }
      }

      // pAction : first flagged entry of the weighted actions.
      // oAction : the zero-cost action used as training target.
      std::string pAction = "";
      std::string oAction = "";
      bool pActionIsZeroCost = false;

      std::string actionName = "";
      float loss = 0.0;

      Classifier::WeightedActions weightedActions;
      if (tm.getCurrentClassifier()->name.rfind("Error_", 0) != 0)
      {
        // ---- Regular classifier (name does not start with "Error_") ----
        if (!ProgramParameters::featureExtraction)
        {
          weightedActions = tm.getCurrentClassifier()->weightActions(trainConfig);

          for (auto & it : weightedActions)
            if (it.first)
            {
              if (pAction == "")
                pAction = it.second.second;

              if (tm.getCurrentClassifier()->getActionCost(trainConfig, it.second.second) == 0)
              {
                oAction = it.second.second;
                break;
              }
            }

          if (pAction == oAction)
            pActionIsZeroCost = true;
        }
        else
        {
          // The list of zero-cost actions may be empty; indexing it blindly
          // would be undefined behaviour. Leave oAction empty in that case so
          // that the fallbacks below (default action, end of tapes) apply.
          auto zeroCosts = tm.getCurrentClassifier()->getZeroCostActions(trainConfig);
          if (!zeroCosts.empty())
            oAction = zeroCosts[0];
        }

        if (oAction.empty())
          oAction = tm.getCurrentClassifier()->getDefaultAction();

        if (oAction.empty())
        {
          if (trainConfig.endOfTapes())
          {
            while (!trainConfig.stackEmpty())
              trainConfig.stackPop();
            throw EndOfIteration();
          }

          fprintf(stderr, "ERROR (%s) : Unable to find any zero cost action. Aborting.\n", ERRINFO);
          fprintf(stderr, "State : %s\n", tm.getCurrentState().c_str());
          trainConfig.printForDebug(stderr);
          tm.getCurrentClassifier()->explainCostOfActions(stderr, trainConfig);
          exit(1);
        }

        if (!ProgramParameters::featureExtraction)
          loss = tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex(oAction));

        TI.addTrainLoss(tm.getCurrentClassifier()->name, loss);

        int k = ProgramParameters::dynamicEpoch;

        // In feature extraction mode, dump "<model file>\t<action>\t<features>" on stdout.
        if (ProgramParameters::featureExtraction)
        {
          auto features = tm.getCurrentClassifier()->getFeatureModel()->getFeatureDescription(trainConfig).featureValues();
          fprintf(stdout, "%s\t%s\t%s\n", tm.getCurrentClassifier()->getFeatureModel()->filename.c_str(), oAction.c_str(), features.c_str());
        }

        // From epoch k on, follow the prediction with probability
        // dynamicProbability; otherwise follow the prediction only when it is
        // itself zero-cost, and oAction in all other cases.
        if (TI.getEpoch() >= k && choiceWithProbability(ProgramParameters::dynamicProbability))
        {
          actionName = pAction;
          TI.lastActionWasPredicted[trainConfig.getCurrentStateName()] = true;
        }
        else
        {
          if (pActionIsZeroCost)
          {
            actionName = pAction;
            TI.lastActionWasPredicted[trainConfig.getCurrentStateName()] = true;
          }
          else
          {
            actionName = oAction;
            TI.lastActionWasPredicted[trainConfig.getCurrentStateName()] = false;
          }

        }

        if (ProgramParameters::debug)
        {
          fprintf(stderr, "Speed : %s actions/s\n", int2humanStr((int)currentSpeed).c_str());
          trainConfig.printForDebug(stderr);
          tm.getCurrentClassifier()->printWeightedActions(stderr, weightedActions, 10);
          fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str());
        }

      }
      else
      {
        // ---- "Error_" classifier ----
        if (!ProgramParameters::featureExtraction)
        {
          weightedActions = tm.getCurrentClassifier()->weightActions(trainConfig);

          for (auto & it : weightedActions)
            if (it.first)
            {
              pAction = it.second.second;
              break;
            }

          // Pick one zero-cost action at random. With an empty list the
          // modulo below would divide by zero, so leave oAction empty instead.
          auto zeroCosts = tm.getCurrentClassifier()->getZeroCostActions(trainConfig);
          if (!zeroCosts.empty())
            oAction = zeroCosts[rand() % zeroCosts.size()];
        }
        else
        {
          // Same guard as above: never index an empty list.
          auto zeroCosts = tm.getCurrentClassifier()->getZeroCostActions(trainConfig);
          if (!zeroCosts.empty())
            oAction = zeroCosts[0];
        }

        if (oAction.empty())
          oAction = tm.getCurrentClassifier()->getDefaultAction();

        if (oAction.empty())
        {
          if (trainConfig.endOfTapes())
          {
            while (!trainConfig.stackEmpty())
              trainConfig.stackPop();
            throw EndOfIteration();
          }
        }

        // Before dynamicEpoch always follow oAction; afterwards follow the
        // prediction unless it is "EPSILON".
        actionName = pAction;
        if (TI.getEpoch() < ProgramParameters::dynamicEpoch)
          actionName = oAction;
        else if (actionName == "EPSILON")
          actionName = oAction;

        // Recover the name of the regular state from "Error_<name>". The field
        // width keeps sscanf from writing past the end of the buffer.
        char buffer[1024];
        if (sscanf(trainConfig.getCurrentStateName().c_str(), "Error_%1023s", buffer) != 1)
        {
          fprintf(stderr, "ERROR (%s) : unexpected classifier name \'%s\'. Aborting.\n", ERRINFO, trainConfig.getCurrentStateName().c_str());
          exit(1);
        }
        std::string normalStateName(buffer);

        auto & normalHistory = trainConfig.getActionsHistory(normalStateName);

        // If a BACK just happened
        if (normalHistory.size() > 1 && trainConfig.getCurrentStateHistory().size() > 0 && split(trainConfig.getCurrentStateHistory().top(), ' ')[0] == "BACK" && TI.getEpoch() >= ProgramParameters::dynamicEpoch)
        {
          // Compare the action that was undone with the one taken instead.
          auto & lastAction = trainConfig.lastUndoneAction;
          auto & newAction = normalHistory[normalHistory.size()-1];
          auto & lastActionName = lastAction.first;
          auto & newActionName = newAction.first;
          auto lastCost = lastAction.second;
          auto newCost = newAction.second;

          if (ProgramParameters::debug)
          {
            fprintf(stderr, "<%s>(%d) -> <%s>(%d)\n", lastActionName.c_str(), lastCost, newActionName.c_str(), newCost);
          }

          if (TI.lastActionWasPredicted[normalStateName])
          {
            std::string updateInfos;

            if (newCost >= lastCost)
            {
              // Going back did not lower the cost: bad decision.
              // NOTE(review): `if (true)` is a hard-wired toggle; the else
              // branch is an alternative update that is unreachable as written.
              if (true)
              {
                // NOTE(review): the negated (index+1) presumably asks the
                // classifier to learn against this action - confirm in
                // Classifier::trainOnExample.
                loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], -(tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top())+1));
              }
              else
              {
                int nbActions = tm.getCurrentClassifier()->getNbActions();
                int backIndex = tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top());
                float value = 1.0 / (nbActions-1);
                std::vector<float> goldOutput(nbActions, value);
                goldOutput[backIndex] = 0.0;

                loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], goldOutput);
              }

              updateInfos = "predicted : <"+trainConfig.getCurrentStateHistory().top()+">, bad decision";
            }
            else
            {
              // Going back lowered the cost: good decision.
              if (true)
              {
                loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top()));
              }
              else
              {
                int nbActions = tm.getCurrentClassifier()->getNbActions();
                int backIndex = tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top());
                std::vector<float> goldOutput(nbActions, 0.0);
                goldOutput[backIndex] = 1.0;

                loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], goldOutput);
              }

              updateInfos = "predicted : <"+trainConfig.getCurrentStateHistory().top()+">, good decision";
            }

            if (ProgramParameters::debug)
              fprintf(stderr, "Updating neural network \'%s\' : %s\n", tm.getCurrentClassifier()->name.c_str(), updateInfos.c_str());

            TI.addTrainLoss(tm.getCurrentClassifier()->name, loss);
          }

        }

        if (ProgramParameters::debug)
        {
          fprintf(stderr, "Speed : %s actions/s\n", int2humanStr((int)currentSpeed).c_str());
          trainConfig.printForDebug(stderr);
          tm.getCurrentClassifier()->printWeightedActions(stderr, weightedActions, 10);
          fprintf(stderr, "pAction=<%s> oAction=<%s> action=<%s>\n", pAction.c_str(), oAction.c_str(), actionName.c_str());
        }

        // Remember the features of this decision; they are the training input
        // used by the BACK handling above on a later call.
        if (actionName != "EPSILON")
        {
          pendingFD[tm.getCurrentClassifier()->name] = tm.getCurrentClassifier()->getFeatureDescription(trainConfig);
        }
      }

      // Record, apply and follow the chosen action.
      Action * action = tm.getCurrentClassifier()->getAction(actionName);
      TransitionMachine::Transition * transition = tm.getTransition(actionName);
      action->setInfos(tm.getCurrentClassifier()->name);

      trainConfig.addToActionsHistory(tm.getCurrentClassifier()->name, actionName, tm.getCurrentClassifier()->getActionCost(trainConfig, actionName));

      action->apply(trainConfig);
      tm.takeTransition(transition);

      nbActions++;

      // Refresh the speed estimate once every nbActionsCutoff actions.
      if (nbActions >= nbActionsCutoff)
      {
        auto actualTime = std::chrono::high_resolution_clock::now();

        double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
        currentSpeed = nbActions / seconds;

        pastTime = actualTime;

        nbActions = 0;
      }

      float entropy = Classifier::computeEntropy(weightedActions);
      trainConfig.addToEntropyHistory(entropy);
    }
    
    // Notify every trainable classifier of the transition machine that an
    // iteration has ended. Classifiers that need no training are skipped.
    void Trainer::endOfIteration()
    {
      auto allClassifiers = tm.getClassifiers();

      for (auto * classifier : allClassifiers)
      {
        if (!classifier->needsTrain())
          continue;

        classifier->endOfIteration();
      }
    }
    
    // Close the current epoch: print the scores, save the models that must be
    // saved, reset the step counter and move TI to the next epoch.
    // Throws EndOfTraining once the epoch number exceeds nbIter.
    void Trainer::prepareNextEpoch()
    {
      endOfIteration();
      printScoresAndSave(stderr);

      nbSteps = 0;
      TI.nextEpoch();

      // NOTE(review): endOfIteration() runs a second time here, after the
      // epoch change - confirm that this is intended.
      endOfIteration();

      const bool trainingIsOver = TI.getEpoch() > ProgramParameters::nbIter;
      if (trainingIsOver)
        throw EndOfTraining();
    }
    
    // Main training loop. Runs epochs until TI's epoch number exceeds nbIter.
    //
    // With iterationSize == -1 an epoch is one full pass over trainConfig;
    // otherwise an epoch ends every iterationSize steps.
    void Trainer::train()
    {
      Dict::createFiles(ProgramParameters::expPath, "");

      // Optional "[time] " prefix for the banner.
      std::string timePrefix;
      if (ProgramParameters::printTime)
        timePrefix = "["+getTime()+"] ";

      fprintf(stderr, "%sTraining of \'%s\' :\n", timePrefix.c_str(), tm.name.c_str());

      while (TI.getEpoch() <= ProgramParameters::nbIter)
      {
        resetAndShuffle();

        while (!trainConfig.isFinal())
        {
          setDebugValue();
          trainConfig.setCurrentStateName(tm.getCurrentClassifier()->name);
          Dict::currentClassifierName = tm.getCurrentClassifier()->name;
          tm.getCurrentClassifier()->initClassifier(trainConfig);

          if (tm.getCurrentClassifier()->needsTrain())
          {
            // EndOfIteration ends the pass over the training configuration.
            try
            {
              doStepTrain();
            }
            catch (EndOfIteration &)
            {
              break;
            }
          }
          else
          {
            doStepNoTrain();
          }

          ++nbSteps;

          // Fixed-size epochs: close the epoch every iterationSize steps.
          if (ProgramParameters::iterationSize != -1 && nbSteps >= ProgramParameters::iterationSize)
          {
            try
            {
              prepareNextEpoch();
            }
            catch (EndOfTraining &)
            {
              break;
            }
          }
        }

        if (ProgramParameters::debug)
          fprintf(stderr, "Config is final\n");

        // Full-pass epochs: close the epoch once the configuration is final.
        if (ProgramParameters::iterationSize == -1)
        {
          try
          {
            prepareNextEpoch();
          }
          catch (EndOfTraining &)
          {
            break;
          }
        }

        if (ProgramParameters::debug)
          fprintf(stderr, "End of epoch\n");
      }
    }
    
    // Compute the train and dev scores, save every classifier that TI marks
    // as "must save" (model file plus dictionaries), then print the scores.
    //
    // output : stream the scores are printed to. Debug traces go to stderr.
    void Trainer::printScoresAndSave(FILE * output)
    {
      TI.computeTrainScores(trainConfig);
      computeScoreOnDev();
      TI.computeMustSaves();

      auto allClassifiers = tm.getClassifiers();

      for (auto * classifier : allClassifiers)
      {
        if (!TI.mustSave(classifier->name))
          continue;

        if (ProgramParameters::debug)
          fprintf(stderr, "Saving %s...", classifier->name.c_str());

        // The model is written to "<expPath><classifier name>.model".
        classifier->save(ProgramParameters::expPath + classifier->name + ".model");
        Dict::saveDicts(ProgramParameters::expPath, classifier->name);

        if (ProgramParameters::debug)
          fprintf(stderr, "Done !\n");
      }

      TI.printScores(output);

      if (ProgramParameters::debug)
        fprintf(stderr, "End of %s\n", __func__);
    }