Skip to content
Snippets Groups Projects
Select Git revision
  • 4ad6e72501a1cc68436d5e5b0eb0e48556ea43d8
  • master default protected
  • correlation
  • 24-non-negative-omp
  • 15-integration-sota
  • 20-coherence-des-arbres-de-predictions
  • 19-add-some-tests
  • 13-visualization
  • 17-adding-new-datasets
  • 12-experiment-pipeline
  • 14-correction-of-multiclass-classif
  • archive/10-gridsearching-of-the-base-forest
  • archive/farah_notation_and_related_work
  • archive/wip_clean_scripts
  • archive/4-implement-omp_forest_classifier
  • archive/5-add-plots-2
  • archive/Leo_Add_first_notebook
17 results

compute_results.py

Blame
  • Decoder.cpp 13.96 KiB
    #include "Decoder.hpp"
    #include "util.hpp"
    #include "Error.hpp"
    #include "ActionBank.hpp"
    #include "ProgramOutput.hpp"
    #include <chrono>
    
    /// Builds a decoder by initializing its `tm` and `config` members from the
    /// transition machine and configuration passed in.
    Decoder::Decoder(TransitionMachine & tm, Config & config)
      : tm(tm),
        config(config)
    {
    }
    
    /// Thrown by getClassifierAction when the tapes are exhausted: the decoding
    /// loops catch it to know that no further action has to be predicted.
    struct EndOfDecode : public std::exception
    {
      /// Returns a fixed, human readable description of the exception.
      /// `noexcept override` replaces the deprecated `throw()` specification and
      /// makes the compiler check that this really overrides std::exception::what.
      const char * what() const noexcept override
      {
        return "End of Decode";
      }
    };
    
    /// Thrown by getClassifierAction when fewer valid actions are available than
    /// the rank that was requested.
    struct NoMoreActions : public std::exception
    {
      /// Returns a fixed, human readable description of the exception.
      /// `noexcept override` replaces the deprecated `throw()` specification and
      /// makes the compiler check that this really overrides std::exception::what.
      const char * what() const noexcept override
      {
        return "No More Actions";
      }
    };
    
    /// Tells whether the configuration owns the tape named by
    /// ProgramParameters::sequenceDelimiterTape.
    bool EOSisPredicted(Config & config)
    {
      const bool hasDelimiterTape = config.hasTape(ProgramParameters::sequenceDelimiterTape);

      return hasDelimiterTape;
    }
    
    /// Records in `errors` how the predicted action compares with a zero-cost
    /// (oracle) action.
    ///
    /// Nothing is recorded unless the classifier needs training, error analysis
    /// is enabled and the classifier is the one named by
    /// ProgramParameters::classifierName (an empty name selects every classifier).
    /// When the classifier reports no zero-cost action, the configuration and the
    /// cost explanation of every weighted action are dumped on stderr and the
    /// program exits with status 1.
    void checkAndRecordError(Config & config, Classifier * classifier, Classifier::WeightedActions & weightedActions, std::string & action, Errors & errors)
    {
      if (!classifier->needsTrain())
        return;
      if (!ProgramParameters::errorAnalysis)
        return;
      if (classifier->name != ProgramParameters::classifierName && !ProgramParameters::classifierName.empty())
        return;

      auto goldActions = classifier->getZeroCostActions(config);

      if (goldActions.empty())
      {
        fprintf(stderr, "ERROR (%s) : could not find zero cost action for classifier \'%s\'. Aborting.\n", ERRINFO, classifier->name.c_str());
        config.printForDebug(stderr);
        for (auto & weighted : weightedActions)
        {
          fprintf(stderr, "%s : ", weighted.second.second.c_str());
          Oracle::explainCostOfAction(stderr, config, weighted.second.second);
        }
        exit(1);
      }

      // The reference action is the prediction itself when it has zero cost,
      // otherwise the first zero-cost action.
      bool predictionIsGold = false;
      for (auto & gold : goldActions)
        if (action == gold)
          predictionIsGold = true;

      std::string oAction = goldActions[0];
      if (predictionIsGold)
        oAction = action;

      int costOfPrediction = classifier->getActionCost(config, action);
      int predictedLinkLength = ActionBank::getLinkLength(config, action);
      int goldLinkLength = ActionBank::getLinkLength(config, oAction);

      errors.add({action, oAction, weightedActions, costOfPrediction, predictedLinkLength, goldLinkLength});
    }
    
    /// In interactive mode, prints on stderr the percentage of the tape already
    /// processed together with the current decoding speed.
    ///
    /// The line is refreshed only when the head position is a non-zero multiple of
    /// `nbActionsCutoff`, or when fewer than `nbActionsCutoff` positions remain.
    void printAdvancement(Config & config, float currentSpeed, int nbActionsCutoff)
    {
      if (!ProgramParameters::interactive)
        return;

      int totalSize = ProgramParameters::tapeSize;
      int steps = config.getHead();

      if (!steps)
        return;

      const bool onCutoff = (steps % nbActionsCutoff == 0);
      const bool nearEnd = (totalSize-steps < nbActionsCutoff);

      if (onCutoff || nearEnd)
        fprintf(stderr, "Decode : %.2f%%  speed : %s actions/s\r", 100.0*steps/totalSize, int2humanStr((int)currentSpeed).c_str());
    }
    
    /// When ProgramParameters::debug is set, writes to `output` the configuration,
    /// the current state of the transition machine and the weighted actions.
    void printDebugInfos(FILE * output, Config & config, TransitionMachine & tm, Classifier::WeightedActions & weightedActions)
    {
      if (!ProgramParameters::debug)
        return;

      config.printForDebug(output);
      fprintf(output, "State : \'%s\'\n", tm.getCurrentState().c_str());

      Classifier::printWeightedActions(output, weightedActions);
      fprintf(output, "\n");
    }
    
    /// Selects the `index`-th valid action (0-based) of `weightedActions`.
    ///
    /// An entry is valid when its `first` member is true; entries are scanned in
    /// the order in which they are stored.
    ///
    /// @return the pair {weight, name} of the selected action.
    /// @throws EndOfDecode   when no appliable action was selected and the tapes
    ///                       are exhausted; the stack is emptied beforehand.
    /// @throws NoMoreActions when fewer than `index`+1 valid actions exist.
    ///
    /// If the selected action is valid but not appliable (and the tapes are not
    /// exhausted) an error is printed and the program exits with status 1.
    ///
    /// Unlike the previous implementation, `weightedActions` is never modified:
    /// the scan used to go through a reference bound to the first entry, which
    /// overwrote that entry's action name with every name visited.
    std::pair<float,std::string> getClassifierAction(Config & config, Classifier::WeightedActions & weightedActions, Classifier * classifier, unsigned int index)
    {
      unsigned int lastExamined = 0;
      unsigned int nbValidActions = 0;
      bool found = false;

      for (unsigned int i = 0; i < weightedActions.size() && !found; i++)
      {
        lastExamined = i;

        if (weightedActions[i].first)
        {
          nbValidActions++;
          found = (nbValidActions-1 == index);
        }
      }

      // With an empty list there is no entry to look at: `found` is false and the
      // action pointer is never dereferenced.
      Action * action = nullptr;
      if (!weightedActions.empty())
        action = classifier->getAction(weightedActions[lastExamined].second.second);

      if (!found || !action->appliable(config))
      {
        // First case the analysis is finished but without an empty stack
        if (config.endOfTapes())
        {
          while (!config.stackEmpty())
            config.stackPop();
          throw EndOfDecode();
        }

        if (!found)
          throw NoMoreActions();

        fprintf(stderr, "ERROR (%s) : action \'%s\' is not appliable. Aborting\n", ERRINFO, weightedActions[lastExamined].second.second.c_str());
        exit(1);
      }

      const float proba = weightedActions[lastExamined].second.first;
      const std::string & predictedAction = weightedActions[lastExamined].second.second;

      return {proba, predictedAction};
    }
    
    /// Refreshes the decoding speed once `nbActions` has reached `nbActionsCutoff`.
    ///
    /// @param pastTime        start of the current measurement window; reset to now
    ///                        when the speed is refreshed.
    /// @param nbActions       actions performed since `pastTime`; reset to 0 when
    ///                        the speed is refreshed.
    /// @param nbActionsCutoff number of actions that triggers a refresh.
    /// @param currentSpeed    receives the speed in actions per second.
    ///
    /// The time point is spelled with the clock that actually produces it
    /// (high_resolution_clock); this is the same type as before wherever that
    /// clock is an alias of system_clock, and it now also compiles where it is not.
    /// When no measurable time has elapsed the previous speed is kept instead of
    /// storing an infinite value.
    void computeSpeed(std::chrono::high_resolution_clock::time_point & pastTime, int & nbActions, int & nbActionsCutoff, float & currentSpeed)
    {
      if (nbActions < nbActionsCutoff)
        return;

      const auto actualTime = std::chrono::high_resolution_clock::now();
      const double seconds = std::chrono::duration<double>(actualTime - pastTime).count();

      if (seconds > 0.0)
        currentSpeed = static_cast<float>(nbActions / seconds);

      pastTime = actualTime;
      nbActions = 0;
    }
    
    /// Closes the current sequence when the previous tape cell carries the
    /// sequence delimiter.
    ///
    /// Does nothing when the configuration has no sequence delimiter tape. Once
    /// the head is at least 1, `justFlipped` reflects whether the cell just before
    /// the head equals ProgramParameters::sequenceDelimiter. While it is set and
    /// entropy printing or error analysis is enabled, a new sequence is opened in
    /// `errors`, the mean entropy of the sequence is optionally printed on stderr,
    /// and both accumulators are reset.
    ///
    /// NOTE(review): the block below runs on every call made while the previous
    /// cell still holds the delimiter, not only on the first one; confirm that
    /// this is intended.
    void computeAndPrintSequenceEntropy(Config & config, bool & justFlipped, Errors & errors, float & entropyAccumulator, int & nbActionsInSequence)
    {
      if (!EOSisPredicted(config))
        return;

      if (config.getHead() >= 1)
        justFlipped = (config.getTape(ProgramParameters::sequenceDelimiterTape)[-1] == ProgramParameters::sequenceDelimiter);

      const bool reportingEnabled = ProgramParameters::printEntropy || ProgramParameters::errorAnalysis;
      if (!justFlipped || !reportingEnabled)
        return;

      errors.newSequence();

      // Mean entropy over the actions of the sequence that just ended.
      entropyAccumulator /= nbActionsInSequence;
      nbActionsInSequence = 0;

      if (ProgramParameters::printEntropy)
        fprintf(stderr, "Entropy : %.2f\n", entropyAccumulator);

      entropyAccumulator = 0.0;
    }
    
    /// Computes the entropy of `weightedActions`, appends it to the entropy
    /// history of the configuration and adds it to `entropyAccumulator`.
    void computeAndRecordEntropy(Config & config, Classifier::WeightedActions & weightedActions, float & entropyAccumulator)
    {
      const float stepEntropy = Classifier::computeEntropy(weightedActions);

      config.addToEntropyHistory(stepEntropy);
      entropyAccumulator += stepEntropy;
    }
    
    /// Applies the action called `actionName` to `config`, then makes the
    /// transition machine take the transition associated with that action.
    ///
    /// Before being applied, the action receives the head movement of the
    /// transition and the name of the state the machine is currently in.
    void applyActionAndTakeTransition(TransitionMachine & tm, const std::string & actionName, Config & config)
    {
      Action * toApply = tm.getCurrentClassifier()->getAction(actionName);
      TransitionMachine::Transition * toTake = tm.getTransition(actionName);

      toApply->setInfos(toTake->headMvt, tm.getCurrentState());
      toApply->apply(config);

      tm.takeTransition(toTake);
    }
    
    /// Entry point of the decoding: resets the configuration, runs the beam
    /// search when ProgramParameters::beamSize is greater than 1 and the greedy
    /// search otherwise, then prints the program output on stdout.
    void Decoder::decode()
    {
      config.reset();

      const bool useBeam = ProgramParameters::beamSize > 1;

      if (!useBeam)
        decodeNoBeam();
      else
        decodeBeam();

      ProgramOutput::instance.print(stdout);
    }
    
    void Decoder::decodeNoBeam()
    {
      float entropyAccumulator = 0.0;
      int nbActionsInSequence = 0;
      bool justFlipped = false;
      Errors errors;
      errors.newSequence();
      int nbActions = 0;
      int nbActionsCutoff = 200;
      float currentSpeed = 0.0;
      auto pastTime = std::chrono::high_resolution_clock::now();
      FILE * outputFile = stdout;
    
      config.setOutputFile(outputFile);
    
      while (!config.isFinal())
      {
        config.setCurrentStateName(tm.getCurrentState());
        Dict::currentClassifierName = tm.getCurrentClassifier()->name;
    
        auto weightedActions = tm.getCurrentClassifier()->weightActions(config);
    
        printAdvancement(config, currentSpeed, nbActionsCutoff);
        printDebugInfos(stderr, config, tm, weightedActions);
    
        std::pair<float,std::string> predictedAction;
        try {predictedAction = getClassifierAction(config, weightedActions, tm.getCurrentClassifier(), 0);}
        catch(EndOfDecode &) {continue;}
        catch(NoMoreActions &) {continue;};
    
        checkAndRecordError(config, tm.getCurrentClassifier(), weightedActions, predictedAction.second, errors);
    
        applyActionAndTakeTransition(tm, predictedAction.second, config);
    
        nbActionsInSequence++;
        nbActions++;
        computeSpeed(pastTime, nbActions, nbActionsCutoff, currentSpeed);
        computeAndRecordEntropy(config, weightedActions, entropyAccumulator);
        computeAndPrintSequenceEntropy(config, justFlipped, errors, entropyAccumulator, nbActionsInSequence);
      }
    
      if (ProgramParameters::errorAnalysis)
        errors.printStats();
    
      config.printTheRest();
    
      if (ProgramParameters::interactive)
        fprintf(stderr, "                                                     \n");
    }
    
    /// One hypothesis of the beam search: a private copy of the transition
    /// machine and of the configuration, plus the action that led to it.
    struct BeamNode
    {
      /// Actions weighted by the classifier for this node (filled by the decoder).
      Classifier::WeightedActions weightedActions;
      /// Transition machine owned by this hypothesis.
      TransitionMachine tm;
      /// Configuration owned by this hypothesis.
      Config config;
      /// Name of the action that must be applied to this node.
      std::string action;
      /// Number of actions accumulated in the score of this hypothesis.
      int nbActions;
      /// True when the last call to setFlipped detected a new sequence delimiter.
      bool justFlipped;
      /// Head position at which the last sequence delimiter was detected.
      int lastFlippedIndex;

      /// Returns the entropy stored in the configuration divided by the number of
      /// actions, or 0 when no action has been counted yet.
      double getEntropy()
      {
        if (nbActions == 0)
          return 0.0;

        return config.getEntropy() / nbActions;
      }

      /// Sets `justFlipped` when the cell just before the head holds the sequence
      /// delimiter and the head has moved past the last position where this was
      /// detected; clears it otherwise.
      void setFlipped()
      {
        if (EOSisPredicted(config) && config.getHead() > lastFlippedIndex && config.getTape(ProgramParameters::sequenceDelimiterTape)[-1] == ProgramParameters::sequenceDelimiter)
        {
          justFlipped = true;
          lastFlippedIndex = config.getHead();
        }
        else
        {
          justFlipped = false;
        }
      }

      /// Builds the root hypothesis from copies of `tm` and `config`.
      BeamNode(TransitionMachine & tm, Config & config) : tm(tm), config(config), nbActions(0), justFlipped(false), lastFlippedIndex(0)
      {
        // `this->` is required: the parameter `config` hides the member, and the
        // node's own copy is the one that must be silenced and reset, not the
        // caller's configuration.
        this->config.setOutputFile(nullptr);
        this->config.setEntropy(0.0);
      }

      /// Builds a child hypothesis of `other` that will apply `action`; `proba`
      /// is added to the entropy stored in the child's configuration.
      BeamNode(BeamNode & other, const std::string & action, float proba) : tm(other.tm), config(other.config), action(action), nbActions(other.nbActions + 1), justFlipped(false), lastFlippedIndex(other.lastFlippedIndex)
      {
        config.setOutputFile(nullptr);
        config.addToEntropy(proba);
      }
    };
    
    /// Beam search decoding.
    ///
    /// Each iteration:
    ///  1. expands every node of the beam into at most ProgramParameters::nbChilds
    ///     children, one per valid action, best ranked first;
    ///  2. keeps the ProgramParameters::beamSize children with the highest
    ///     BeamNode::getEntropy value;
    ///  3. applies to each kept child the action it was created for (statistics
    ///     are only gathered on the best child);
    ///  4. moves the children that just closed a sequence to a waiting list; once
    ///     that list is large enough, or no live hypothesis remains, its best
    ///     element becomes the whole beam and its score is reset.
    ///
    /// The program exits with status 1 if no child could be created.
    void Decoder::decodeBeam()
    {
      float entropyAccumulator = 0.0;
      int nbActionsInSequence = 0;
      bool justFlipped = false;
      Errors errors;
      errors.newSequence();
      int nbActions = 0;
      int nbActionsCutoff = 200;
      float currentSpeed = 0.0;
      auto pastTime = std::chrono::high_resolution_clock::now();

      FILE * outputFile = stdout;

      std::vector< std::shared_ptr<BeamNode> > beam;
      std::vector< std::shared_ptr<BeamNode> > otherBeam;
      std::vector< std::shared_ptr<BeamNode> > justFlippedBeam;
      beam.emplace_back(new BeamNode(tm, config));

      // Orders the beam from the highest to the lowest score.
      auto sortBeam = [&beam]()
      {
        std::sort(beam.begin(), beam.end(), [](const std::shared_ptr<BeamNode> & a, const std::shared_ptr<BeamNode> & b)
            {
              return a->getEntropy() > b->getEntropy();
            });
      };

      // Dumps every node of `nodes` on stderr (debug mode only).
      auto printBeam = [](std::vector< std::shared_ptr<BeamNode> > & nodes)
      {
        for (auto & node : nodes)
        {
          node->config.printForDebug(stderr);
          fprintf(stderr, "action : %s\n", node->action.c_str());
          fprintf(stderr, "nbActions : %d\n", node->nbActions);
          fprintf(stderr, "justFlipped : %s\n", node->justFlipped ? "true" : "false");
          fprintf(stderr, "lastFlippedIndex : %d\n", node->lastFlippedIndex);
          fprintf(stderr, "--------------------------------------------------------------------------------\n");
        }
      };

      bool endOfDecode = false;

      while (endOfDecode == false)
      {
        otherBeam.clear();

        // Set when a node reports the end of the tapes; expansion stops there.
        bool mustContinue = false;
        for (auto & node : beam)
        {
          node->config.setCurrentStateName(node->tm.getCurrentState());
          Dict::currentClassifierName = node->tm.getCurrentClassifier()->name;

          node->weightedActions = node->tm.getCurrentClassifier()->weightActions(node->config);

          printAdvancement(node->config, currentSpeed, nbActionsCutoff);

          unsigned int nbActionsMax = std::min(std::max(node->tm.getCurrentClassifier()->getNbActions(),(unsigned int)1),(unsigned int)ProgramParameters::nbChilds);
          for (unsigned int actionIndex = 0; actionIndex < nbActionsMax; actionIndex++)
          {
            std::pair<float,std::string> predictedAction;
            try {predictedAction = getClassifierAction(node->config, node->weightedActions, node->tm.getCurrentClassifier(), actionIndex);}
            catch (const EndOfDecode &) {mustContinue = true; break;}
            catch (const NoMoreActions &) {break;}
            otherBeam.emplace_back(new BeamNode(*node, predictedAction.second, predictedAction.first));
          }

          if (mustContinue)
            break;
        }

        if (ProgramParameters::debug)
        {
          fprintf(stderr, "################################# Beam before sort #################################\n");
          printBeam(otherBeam);
          fprintf(stderr, "####################################################################################\n");
        }

        beam = otherBeam;
        sortBeam();
        beam.resize(std::min((int)beam.size(), ProgramParameters::beamSize));
        if (beam.empty())
        {
          fprintf(stderr, "ERROR (%s) : beam is empty. Aborting.\n", ERRINFO);
          exit(1);
        }

        if (ProgramParameters::debug)
        {
          fprintf(stderr, "################################# Beam after sort #################################\n");
          printBeam(beam);
          fprintf(stderr, "###################################################################################\n");
        }

        for (auto & node : beam)
          node->config.setOutputFile(outputFile);

        for (auto & node : beam)
        {
          config.setCurrentStateName(node->tm.getCurrentState());
          Dict::currentClassifierName = node->tm.getCurrentClassifier()->name;

          // Statistics are only gathered on the best hypothesis.
          const bool isBest = (node.get() == beam.begin()->get());

          if (isBest)
            checkAndRecordError(node->config, node->tm.getCurrentClassifier(), node->weightedActions, node->action, errors);

          applyActionAndTakeTransition(node->tm, node->action, node->config);

          if (isBest)
          {
            nbActionsInSequence++;
            nbActions++;
            computeSpeed(pastTime, nbActions, nbActionsCutoff, currentSpeed);
            computeAndRecordEntropy(node->config, node->weightedActions, entropyAccumulator);
            computeAndPrintSequenceEntropy(node->config, justFlipped, errors, entropyAccumulator, nbActionsInSequence);
          }
          node->setFlipped();
        }

        // Move the nodes that just closed a sequence to the waiting list. The
        // removed slot is refilled with the last node, so the index only advances
        // when the current node stays in the beam.
        for (unsigned int i = 0; i < beam.size(); )
        {
          if (!beam[i]->justFlipped)
          {
            i++;
            continue;
          }

          justFlippedBeam.push_back(beam[i]);
          beam[i] = beam.back();
          beam.pop_back();
        }

        // The beam was not empty before the loop above, so an empty beam here
        // means that every hypothesis is waiting in justFlippedBeam: they must be
        // promoted now, otherwise beam[0] below would be read from an empty vector.
        const bool noLiveHypothesis = beam.empty();

        if ((int)justFlippedBeam.size() >= ProgramParameters::beamSize || (justFlippedBeam.size() && (mustContinue || noLiveHypothesis)))
        {
          if (mustContinue || justFlippedBeam[0]->config.endOfTapes())
            endOfDecode = true;
          beam = justFlippedBeam;
          justFlippedBeam.clear();
          sortBeam();
          beam.resize(1);
          beam[0]->config.setEntropy(0.0);
          beam[0]->nbActions = 0;
        }

        if (!EOSisPredicted(beam[0]->config) && beam[0]->config.endOfTapes())
          endOfDecode = true;
      }

      if (ProgramParameters::errorAnalysis)
        errors.printStats();

      for (auto & node : beam)
      {
        node->config.setOutputFile(outputFile);
        node->config.printTheRest();
      }

      // Blank line that overwrites the progress line of the interactive mode.
      if (ProgramParameters::interactive)
        fprintf(stderr, "                                                     \n");
    }