Skip to content
Snippets Groups Projects
Select Git revision
  • 1c4eea327dcbe753cf65df30b7c3cbe6e56441dd
  • master default
  • object
  • develop protected
  • private_algos
  • cuisine
  • SMOTE
  • revert-76c4cca5
  • archive protected
  • no_graphviz
  • 0.0.2
  • 0.0.1
12 results

LateFusion.py

Blame
  • Decoder.cpp 11.02 KiB
    #include "Decoder.hpp"
    #include "util.hpp"
    #include "Error.hpp"
    #include "ActionBank.hpp"
    #include <chrono>
    
    // Binds the decoder to the transition machine that drives decoding and to the
    // configuration it decodes.
    Decoder::Decoder(TransitionMachine & tm, Config & config) : tm(tm), config(config) {}
    
    // Thrown by getClassifierAction() when no usable action is found while the head is on
    // the last input position; the configuration's stack has already been emptied by then.
    struct EndOfDecode : public std::exception
    {
      // noexcept/override instead of throw(): dynamic exception specifications are deprecated
      // since C++11 and throw() was removed in C++20; override lets the compiler check that
      // this really overrides std::exception::what().
      const char * what() const noexcept override
      {
        return "End of Decode";
      }
    };
    
    // Thrown by getClassifierAction() when the classifier offers fewer valid actions than the
    // requested index (and the head is not on the last input position).
    struct NoMoreActions : public std::exception
    {
      // noexcept/override instead of the deprecated (and, in C++20, removed) throw() spec.
      const char * what() const noexcept override
      {
        return "No More Actions";
      }
    };
    
    // When error analysis is enabled for this trainable classifier, compares the predicted
    // action with the classifier's zero-cost actions and records the outcome in `errors`.
    // Aborts the program if no zero-cost action exists.
    void checkAndRecordError(Config & config, Classifier * classifier, Classifier::WeightedActions & weightedActions, Action * action, Errors & errors)
    {
      if (!classifier->needsTrain() || !ProgramParameters::errorAnalysis)
        return;
      // Only analyse the requested classifier, unless none was named.
      if (classifier->name != ProgramParameters::classifierName && !ProgramParameters::classifierName.empty())
        return;

      auto goldCandidates = classifier->getZeroCostActions(config);
      if (goldCandidates.empty())
      {
        fprintf(stderr, "ERROR (%s) : could not find zero cost action for classifier \'%s\'. Aborting.\n", ERRINFO, classifier->name.c_str());
        config.printForDebug(stderr);
        for (auto & weighted : weightedActions)
        {
          fprintf(stderr, "%s : ", weighted.second.second.c_str());
          Oracle::explainCostOfAction(stderr, config, weighted.second.second);
        }
        exit(1);
      }

      // Gold action: the prediction itself when it is zero-cost, else the first zero-cost action.
      std::string goldAction = goldCandidates[0];
      for (auto & candidate : goldCandidates)
      {
        if (action->name == candidate)
          goldAction = candidate;
      }

      const int predictedCost = classifier->getActionCost(config, action->name);
      const int predictedLinkLength = ActionBank::getLinkLength(config, action->name);
      const int goldLinkLength = ActionBank::getLinkLength(config, goldAction);
      errors.add({action->name, goldAction, weightedActions, predictedCost, predictedLinkLength, goldLinkLength});
    }
    
    // In interactive mode, prints progress (percentage of the first tape processed) and the
    // current speed on a '\r'-refreshed stderr line, every 200 positions and on each of the
    // last 200 positions.
    void printAdvancement(Config & config, float currentSpeed)
    {
      if (!ProgramParameters::interactive)
        return;

      const int total = config.tapes[0].hyp.size();
      const int done = config.head;
      if (done == 0)
        return;

      const bool atCheckpoint = (done % 200 == 0);
      const bool nearEnd = (total - done < 200);
      if (atCheckpoint || nearEnd)
        fprintf(stderr, "Decode : %.2f%%  speed : %s actions/s\r", 100.0*done/total, int2humanStr((int)currentSpeed).c_str());
    }
    
    void printDebugInfos(FILE * output, Config & config, TransitionMachine & tm, Classifier::WeightedActions & weightedActions)
    {
        if (ProgramParameters::debug)
        {
          TransitionMachine::State * currentState = tm.getCurrentState();
    
          config.printForDebug(output);
          fprintf(output, "State : \'%s\'\n", currentState->name.c_str());
    
          Classifier::printWeightedActions(output, weightedActions);
          fprintf(output, "\n");
        }
    }
    
    // Returns (score, action name) of the `index`-th valid entry of `weightedActions`
    // (0-based, counting only entries whose `first` flag is set, in list order).
    // - If that entry is missing or its action is not appliable and the head is on the last
    //   input position: empties the stack and throws EndOfDecode.
    // - Else if there are fewer than index+1 valid entries: throws NoMoreActions.
    // - Else (selected action not appliable): prints an error and exits.
    // `weightedActions` is only read, never modified.
    std::pair<float,std::string> getClassifierAction(Config & config, Classifier::WeightedActions & weightedActions, Classifier * classifier, unsigned int index)
    {
        // Copies, not references into weightedActions: callers reuse the list afterwards
        // (error analysis, entropy, further beam expansions), so writing through a reference
        // bound to weightedActions[0] would corrupt its first entry.
        std::string predictedAction;
        float proba = 0.0;
        bool found = false;

        unsigned int nbValidActions = 0;
        for (auto & weighted : weightedActions)
        {
          if (!weighted.first)
            continue;
          if (nbValidActions == index)
          {
            predictedAction = weighted.second.second;
            proba = weighted.second.first;
            found = true;
            break;
          }
          nbValidActions++;
        }

        // Only look the action up when one was selected (also safe for an empty list).
        Action * action = found ? classifier->getAction(predictedAction) : nullptr;

        if (!found || !action->appliable(config))
        {
          // The analysis is finished but the stack is not empty yet.
          if (config.head == (int)config.tapes[0].ref.size()-1)
          {
            while (!config.stackEmpty())
              config.stackPop();
            throw EndOfDecode();
          }
          else if (!found)
          {
            throw NoMoreActions();
          }
          else
          {
            fprintf(stderr, "ERROR (%s) : action \'%s\' is not appliable. Aborting\n", ERRINFO, predictedAction.c_str());
            exit(1);
          }
        }

        return {proba, predictedAction};
    }
    
    // Once at least `nbActionsCutoff` actions have been counted since `pastTime`, sets
    // `currentSpeed` to actions per second over that interval, then restarts the measurement
    // (pastTime = now, nbActions = 0).
    // The time point uses high_resolution_clock because that is what callers obtain from
    // high_resolution_clock::now(); that clock is an implementation-defined alias (system_clock
    // on libstdc++, steady_clock on libc++/MSVC), so a system_clock parameter does not compile
    // everywhere.
    void computeSpeed(std::chrono::time_point<std::chrono::high_resolution_clock> & pastTime, int & nbActions, int & nbActionsCutoff, float & currentSpeed)
    {
      if (nbActions < nbActionsCutoff)
        return;

      auto actualTime = std::chrono::high_resolution_clock::now();
      double seconds = std::chrono::duration<double>(actualTime - pastTime).count();

      // A zero-length interval (coarse clock) would yield inf, which is later cast to int for
      // display (undefined behaviour). Keep accumulating until time has measurably passed.
      if (seconds <= 0.0)
        return;

      currentSpeed = nbActions / seconds;
      pastTime = actualTime;
      nbActions = 0;
    }
    
    // When entropy printing or error analysis is enabled, detects that the position just
    // before the head on the sequence-delimiter tape holds the sequence delimiter. The first
    // time this happens for a given delimiter, it:
    // - starts a new error sequence;
    // - averages the accumulated entropy over the sequence's actions and optionally prints it;
    // - resets both accumulators.
    // `justFlipped` stays true until the head leaves the delimiter, preventing double handling.
    void computeAndPrintSequenceEntropy(Config & config, bool & justFlipped, Errors & errors, float & entropyAccumulator, int & nbActionsInSequence)
    {
      const bool enabled = ProgramParameters::printEntropy || ProgramParameters::errorAnalysis;
      if (!enabled || config.head < 1)
        return;

      const bool afterDelimiter = config.getTape(ProgramParameters::sequenceDelimiterTape)[config.head-1] == ProgramParameters::sequenceDelimiter;
      if (!afterDelimiter)
      {
        justFlipped = false;
        return;
      }
      if (justFlipped)
        return;

      justFlipped = true;
      errors.newSequence();
      entropyAccumulator /= nbActionsInSequence;
      nbActionsInSequence = 0;
      if (ProgramParameters::printEntropy)
        fprintf(stderr, "Entropy : %.2f\n", entropyAccumulator);
      entropyAccumulator = 0.0;
    }
    
    // Computes the entropy of the scored actions, appends it to the configuration's entropy
    // history and adds it to the running accumulator.
    void computeAndRecordEntropy(Config & config, Classifier::WeightedActions & weightedActions, float & entropyAccumulator)
    {
      const float stepEntropy = Classifier::computeEntropy(weightedActions);
      entropyAccumulator += stepEntropy;
      config.addToEntropyHistory(stepEntropy);
    }
    
    // Looks up the transition named after `action`, gives the action that transition's head
    // movement and the current state's name, applies the action to `config`, then moves `tm`
    // along the transition.
    void applyActionAndTakeTransition(TransitionMachine & tm, Action * action, Config & config)
    {
      auto * edge = tm.getTransition(action->name);
      action->setInfos(edge->headMvt, tm.getCurrentState()->name);
      action->apply(config);
      tm.takeTransition(edge);
    }
    
    struct BeamNode
    {
      Classifier::WeightedActions weightedActions;
      double totalEntropy;
      TransitionMachine tm;
      Config config;
      Action * action;
    
      BeamNode(TransitionMachine & tm, Config & config) : tm(tm), config(config)
      {
        totalEntropy = 0.0;
      }
      BeamNode(BeamNode & other, Action * action, float proba) : tm(other.tm), config(other.config)
      {
        totalEntropy = other.totalEntropy + proba;
        this->action = action;
      }
    };
    
    // Entry point: beam search when the configured beam size exceeds 1, greedy decoding
    // otherwise.
    void Decoder::decode()
    {
      const bool useBeam = ProgramParameters::beamSize > 1;
      if (useBeam)
      {
        decodeBeam();
        return;
      }
      decodeNoBeam();
    }
    
    void Decoder::decodeNoBeam()
    {
      float entropyAccumulator = 0.0;
      int nbActionsInSequence = 0;
      bool justFlipped = false;
      Errors errors;
      errors.newSequence();
      int nbActions = 0;
      int nbActionsCutoff = 200;
      float currentSpeed = 0.0;
      auto pastTime = std::chrono::high_resolution_clock::now();
    
      while (!config.isFinal())
      {
        TransitionMachine::State * currentState = tm.getCurrentState();
        Classifier * classifier = currentState->classifier;
        config.setCurrentStateName(&currentState->name);
        Dict::currentClassifierName = classifier->name;
    
        auto weightedActions = classifier->weightActions(config);
    
        printAdvancement(config, currentSpeed);
        printDebugInfos(stderr, config, tm, weightedActions);
    
        std::pair<float,std::string> predictedAction;
        try {predictedAction = getClassifierAction(config, weightedActions, classifier, 0);}
        catch(EndOfDecode &) {continue;}
        catch(NoMoreActions &) {continue;};
    
        Action * action = classifier->getAction(predictedAction.second);
    
        checkAndRecordError(config, classifier, weightedActions, action, errors);
    
        applyActionAndTakeTransition(tm, action, config);
    
        nbActionsInSequence++;
        nbActions++;
        computeSpeed(pastTime, nbActions, nbActionsCutoff, currentSpeed);
        computeAndRecordEntropy(config, weightedActions, entropyAccumulator);
        computeAndPrintSequenceEntropy(config, justFlipped, errors, entropyAccumulator, nbActionsInSequence);
      }
    
      if (ProgramParameters::errorAnalysis)
        errors.printStats();
      else
        config.printAsOutput(stdout);
    
      fprintf(stderr, "                                                     \n");
    }
    
    void Decoder::decodeBeam()
    {
      float entropyAccumulator = 0.0;
      int nbActionsInSequence = 0;
      bool justFlipped = false;
      Errors errors;
      errors.newSequence();
      int nbActions = 0;
      int nbActionsCutoff = 200;
      float currentSpeed = 0.0;
      auto pastTime = std::chrono::high_resolution_clock::now();
    
      std::vector< std::shared_ptr<BeamNode> > beam;
      std::vector< std::shared_ptr<BeamNode> > otherBeam;
      beam.emplace_back(new BeamNode(tm, config));
    
      auto sortBeam = [&beam]()
      {
        std::sort(beam.begin(), beam.end(), [](std::shared_ptr<BeamNode> a, std::shared_ptr<BeamNode> b)
            {
              return a->totalEntropy > b->totalEntropy;
            });
      };
    
      while (!beam[0]->config.isFinal())
      {
        otherBeam.clear();
    
        bool mustContinue = false;
        for (auto & node : beam)
        {
          auto & tm = node->tm;
          auto & config = node->config;
          TransitionMachine::State * currentState = tm.getCurrentState();
          Classifier * classifier = currentState->classifier;
          config.setCurrentStateName(&currentState->name);
          Dict::currentClassifierName = classifier->name;
    
          node->weightedActions = classifier->weightActions(config);
    
          printAdvancement(config, currentSpeed);
          printDebugInfos(stderr, config, tm, node->weightedActions);
    
          unsigned int nbActionsMax = std::min(std::max(classifier->getNbActions(),(unsigned int)1),(unsigned int)ProgramParameters::nbChilds);
          for (unsigned int actionIndex = 0; actionIndex < nbActionsMax; actionIndex++)
          {
            std::pair<float,std::string> predictedAction;
            try {predictedAction = getClassifierAction(config, node->weightedActions, classifier, actionIndex);}
            catch(EndOfDecode &) {mustContinue = true; break;}
            catch(NoMoreActions &) {break;};
            otherBeam.emplace_back(new BeamNode(*node.get(),classifier->getAction(predictedAction.second), predictedAction.first));
          }
    
          if (mustContinue)
            break;
        }
    
        if (mustContinue)
          continue;
    
        beam = otherBeam;
        sortBeam();
        beam.resize(std::min((int)beam.size(), ProgramParameters::beamSize));
    
        for (auto & node : beam)
        {
          auto & tm = node->tm;
          auto & config = node->config;
          TransitionMachine::State * currentState = tm.getCurrentState();
          Classifier * classifier = currentState->classifier;
          config.setCurrentStateName(&currentState->name);
          Dict::currentClassifierName = classifier->name;
    
          if (node.get() == beam.begin()->get())
          {
            checkAndRecordError(config, classifier, node->weightedActions, node->action, errors);
          }
    
          applyActionAndTakeTransition(tm, node->action, config);
    
          if (node.get() == beam.begin()->get())
          {
            nbActionsInSequence++;
            nbActions++;
            computeSpeed(pastTime, nbActions, nbActionsCutoff, currentSpeed);
            computeAndRecordEntropy(config, node->weightedActions, entropyAccumulator);
            computeAndPrintSequenceEntropy(config, justFlipped, errors, entropyAccumulator, nbActionsInSequence);
          }
        }
      }
    
      if (ProgramParameters::errorAnalysis)
        errors.printStats();
      else
        beam[0]->config.printAsOutput(stdout);
    
      fprintf(stderr, "                                                     \n");
    }