Skip to content
Snippets Groups Projects
Select Git revision
  • 1b5e289896702204467d5e26ed02637c181d3f0d
  • master default protected
  • ci39
  • ci39-python12
  • py39
  • issue#14
  • endianness
  • bugs_i686
  • bug_test_instfreqplot_arm64
  • bug_test_tfplot
  • gitlab-ci
  • debian
  • v1.0.17
  • v1.0.16
  • v1.0.15
  • v1.0.14
  • v1.0.13
  • v1.0.12
  • v1.0.9
  • v1.0.8
  • v1.0.7
  • v1.0.6
  • v1.0.0
23 results

setup.py

Blame
  • Decoder.cpp 7.39 KiB
    #include "Decoder.hpp"
    #include "util.hpp"
    #include "Error.hpp"
    #include "ActionBank.hpp"
    #include <chrono>
    
    // Stores the transition machine and configuration that decode() uses as the
    // starting point of its single initial beam node.
    Decoder::Decoder(TransitionMachine & machine, Config & configuration)
      : tm(machine), config(configuration)
    {
    }
    
    // Thrown by getClassifierAction when the chosen action is not appliable but the head
    // already sits on the last reference position: there is nothing left to decode for
    // that configuration, so the caller stops processing it.
    struct EndOfDecode : public std::exception
    {
      // noexcept replaces the dynamic exception specification throw(), which was
      // deprecated in C++11 and removed in C++20; override lets the compiler check
      // that this really overrides std::exception::what.
      const char * what() const noexcept override
      {
        return "End of Decode";
      }
    };
    
    // Error-analysis hook. When error analysis is enabled and the classifier needs
    // training and is the selected one (or no classifier is selected), records in
    // `errors` the predicted action next to a zero-cost action of the classifier, together
    // with the predicted action's cost and both link lengths.
    // Aborts the program if the classifier reports no zero-cost action at all, after
    // dumping the configuration and the cost explanation of every weighted action.
    void checkAndRecordError(Config & config, Classifier * classifier, Classifier::WeightedActions & weightedActions, Action * action, Errors & errors)
    {
      if (!classifier->needsTrain() || !ProgramParameters::errorAnalysis)
        return;
      if (classifier->name != ProgramParameters::classifierName && !ProgramParameters::classifierName.empty())
        return;

      auto zeroCostActions = classifier->getZeroCostActions(config);
      if (zeroCostActions.empty())
      {
        fprintf(stderr, "ERROR (%s) : could not find zero cost action for classifier \'%s\'. Aborting.\n", ERRINFO, classifier->name.c_str());
        config.printForDebug(stderr);
        for (auto & weighted : weightedActions)
        {
          fprintf(stderr, "%s : ", weighted.second.second.c_str());
          Oracle::explainCostOfAction(stderr, config, weighted.second.second);
        }
        exit(1);
      }

      // Default to the first zero-cost action, but use the predicted one whenever it is
      // itself zero-cost, so the recorded gold action then equals the prediction.
      std::string oAction = zeroCostActions[0];
      for (auto & candidate : zeroCostActions)
      {
        if (action->name == candidate)
        {
          oAction = candidate;
          break;
        }
      }

      int predictedCost = classifier->getActionCost(config, action->name);
      int predictedLinkLength = ActionBank::getLinkLength(config, action->name);
      int goldLinkLength = ActionBank::getLinkLength(config, oAction);
      errors.add({action->name, oAction, weightedActions, predictedCost, predictedLinkLength, goldLinkLength});
    }
    
    // Interactive mode only. Rewrites one progress line on stderr (ending in '\r') with
    // the head position as a percentage of the first tape's hypothesis length and the
    // current speed. It refreshes every 200 steps and on every one of the last 200
    // steps; nothing is printed while the head is at 0.
    void printAdvancement(Config & config, float currentSpeed)
    {
      if (!ProgramParameters::interactive)
        return;

      int totalSize = config.tapes[0].hyp.size();
      int steps = config.head;
      bool periodicTick = steps % 200 == 0;
      bool nearEnd = totalSize - steps < 200;

      if (steps == 0 || !(periodicTick || nearEnd))
        return;

      fprintf(stderr, "Decode : %.2f%%  speed : %s actions/s\r", 100.0*steps/totalSize, int2humanStr((int)currentSpeed).c_str());
    }
    
    // Debug mode only: writes to `output` the configuration, the name of the transition
    // machine's current state and the weighted actions, followed by a blank line.
    void printDebugInfos(FILE * output, Config & config, TransitionMachine & tm, Classifier::WeightedActions & weightedActions)
    {
      if (!ProgramParameters::debug)
        return;

      auto * state = tm.getCurrentState();

      config.printForDebug(output);
      fprintf(output, "State : \'%s\'\n", state->name.c_str());

      Classifier::printWeightedActions(output, weightedActions);
      fprintf(output, "\n");
    }
    
    // Chooses the action to apply from a classifier's weighted predictions.
    //
    // Returns the action name of the first entry of `weightedActions` whose `.first` flag
    // is set (presumably "this action may be used here" — TODO confirm against
    // Classifier::weightActions), or of the last entry when no entry is flagged.
    // The returned reference points into `weightedActions` and is only valid while that
    // container is unchanged.
    //
    // If the chosen action is not appliable:
    //  - with the head on the last reference position, the stack is emptied and
    //    EndOfDecode is thrown;
    //  - otherwise the program aborts.
    // It also aborts when `weightedActions` is empty.
    std::string & getClassifierAction(Config & config, Classifier::WeightedActions & weightedActions, Classifier * classifier)
    {
      if (weightedActions.empty())
      {
        fprintf(stderr, "ERROR (%s) : classifier \'%s\' proposed no action. Aborting\n", ERRINFO, classifier->name.c_str());
        exit(1);
      }

      // Select by index. Assigning through a reference bound to weightedActions[0] would
      // overwrite the first entry's action name with the chosen one and corrupt the list
      // that is later handed to error analysis.
      unsigned int chosen = 0;
      while (chosen + 1 < weightedActions.size() && !weightedActions[chosen].first)
        chosen++;

      std::string & predictedAction = weightedActions[chosen].second.second;
      Action * action = classifier->getAction(predictedAction);

      if (!action->appliable(config))
      {
        // The analysis is finished but the stack is not empty: flush it and stop.
        if (config.head == (int)config.tapes[0].ref.size()-1)
        {
          while (!config.stackEmpty())
            config.stackPop();
          throw EndOfDecode();
        }

        fprintf(stderr, "ERROR (%s) : action \'%s\' is not appliable. Aborting\n", ERRINFO, predictedAction.c_str());
        exit(1);
      }

      return predictedAction;
    }
    
    // Throughput meter. Once at least `nbActionsCutoff` actions have been counted since
    // `pastTime`:
    //  - stores actions per second in `currentSpeed`;
    //  - moves `pastTime` to now;
    //  - resets `nbActions` to 0.
    // Below the cutoff nothing changes.
    void computeSpeed(std::chrono::time_point<std::chrono::system_clock> & pastTime, int & nbActions, int & nbActionsCutoff, float & currentSpeed)
    {
      if (nbActions < nbActionsCutoff)
        return;

      // Sample the same clock pastTime belongs to. Subtracting a high_resolution_clock
      // time point only compiles where that clock is an alias of system_clock.
      auto actualTime = std::chrono::system_clock::now();
      double seconds = std::chrono::duration<double>(actualTime - pastTime).count();

      // system_clock may stand still or be adjusted backwards. A zero or negative
      // interval would give an infinite or negative speed, and converting inf to int
      // (as the progress display does) is undefined, so keep the previous speed.
      if (seconds > 0.0)
        currentSpeed = nbActions / seconds;

      pastTime = actualTime;
      nbActions = 0;
    }
    
    // Active only when entropy printing or error analysis is enabled. Detects the step
    // right after a sequence delimiter, i.e. the previous cell of the delimiter tape
    // equals ProgramParameters::sequenceDelimiter, and closes the sequence:
    //  - opens a new Errors sequence;
    //  - turns entropyAccumulator into its average over the sequence's actions;
    //  - prints that average if requested;
    //  - resets both accumulators.
    // `justFlipped` makes this fire only once per delimiter; it is cleared again when a
    // non-delimiter cell is seen.
    void computeAndPrintSequenceEntropy(Config & config, bool & justFlipped, Errors & errors, float & entropyAccumulator, int & nbActionsInSequence)
    {
      bool enabled = ProgramParameters::printEntropy || ProgramParameters::errorAnalysis;
      if (!enabled || config.head < 1)
        return;

      bool previousIsDelimiter = config.getTape(ProgramParameters::sequenceDelimiterTape)[config.head-1] == ProgramParameters::sequenceDelimiter;

      if (!previousIsDelimiter)
      {
        justFlipped = false;
        return;
      }

      if (justFlipped)
        return;

      justFlipped = true;
      errors.newSequence();
      entropyAccumulator /= nbActionsInSequence;
      nbActionsInSequence = 0;
      if (ProgramParameters::printEntropy)
        fprintf(stderr, "Entropy : %.2f\n", entropyAccumulator);
      entropyAccumulator = 0.0;
    }
    
    // Computes the entropy of the weighted actions, adds it to the running accumulator
    // and appends it to the configuration's entropy history.
    void computeAndRecordEntropy(Config & config, Classifier::WeightedActions & weightedActions, float & entropyAccumulator)
    {
      const float stepEntropy = Classifier::computeEntropy(weightedActions);
      entropyAccumulator += stepEntropy;
      config.addToEntropyHistory(stepEntropy);
    }
    
    // Looks up the transition matching the action's name, gives the action the
    // transition's head movement and the current state's name, applies the action to
    // `config`, and finally advances the machine along that transition.
    void applyActionAndTakeTransition(TransitionMachine & tm, Action * action, Config & config)
    {
      auto * matching = tm.getTransition(action->name);
      auto * fromState = tm.getCurrentState();

      action->setInfos(matching->headMvt, fromState->name);
      action->apply(config);

      tm.takeTransition(matching);
    }
    
    // One hypothesis of the decoding beam. It owns its own copies of the transition
    // machine and the configuration, plus a weighted-actions slot and an accumulated
    // entropy that starts at 0.
    struct BeamNode
    {
      Classifier::WeightedActions weightedActions;
      double totalEntropy = 0.0;
      TransitionMachine tm;
      Config config;

      BeamNode(TransitionMachine & machine, Config & configuration)
        : tm(machine), config(configuration)
      {
      }
    };
    
    void Decoder::decode()
    {
      float entropyAccumulator = 0.0;
      int nbActionsInSequence = 0;
      bool justFlipped = false;
      Errors errors;
      errors.newSequence();
      int nbActions = 0;
      int nbActionsCutoff = 200;
      float currentSpeed = 0.0;
      auto pastTime = std::chrono::high_resolution_clock::now();
    
      std::vector< std::shared_ptr<BeamNode> > beam;
      beam.emplace_back(new BeamNode(tm, config));
    
      auto sortBeam = [&beam]()
      {
        std::sort(beam.begin(), beam.end(), [](std::shared_ptr<BeamNode> a, std::shared_ptr<BeamNode> b)
            {
              return a->totalEntropy < b->totalEntropy;
            });
      };
    
      while (!beam[0]->config.isFinal())
      {
        for (auto & node : beam)
        {
        auto & tm = node->tm;
        auto & config = node->config;
        TransitionMachine::State * currentState = tm.getCurrentState();
        Classifier * classifier = currentState->classifier;
        config.setCurrentStateName(&currentState->name);
        Dict::currentClassifierName = classifier->name;
    
        auto weightedActions = classifier->weightActions(config);
    
        printAdvancement(config, currentSpeed);
        printDebugInfos(stderr, config, tm, weightedActions);
    
        std::string predictedAction;
        try {predictedAction = getClassifierAction(config, weightedActions, classifier);}
        catch(EndOfDecode &) {continue;};
        Action * action = classifier->getAction(predictedAction);
    
        checkAndRecordError(config, classifier, weightedActions, action, errors);
    
        applyActionAndTakeTransition(tm, action, config);
    
        nbActionsInSequence++;
        nbActions++;
        computeSpeed(pastTime, nbActions, nbActionsCutoff, currentSpeed);
        computeAndRecordEntropy(config, weightedActions, entropyAccumulator);
        computeAndPrintSequenceEntropy(config, justFlipped, errors, entropyAccumulator, nbActionsInSequence);
        }
      }
    
      if (ProgramParameters::errorAnalysis)
        errors.printStats();
      else
        beam[0]->config.printAsOutput(stdout);
    
      fprintf(stderr, "                                                     \n");
    }