Skip to content
Snippets Groups Projects
Select Git revision
  • 359dd3b50623940a3cb488265cdc96f4140ad2a3
  • master default protected
  • fullUD
  • movementInAction
4 results

Oracle.cpp

Blame
  • Decoder.cpp 15.02 KiB
    #include "Decoder.hpp"
    #include "util.hpp"
    #include "Error.hpp"
    #include "ActionBank.hpp"
    #include "ProgramOutput.hpp"
    #include <algorithm>
    #include <chrono>
    
    // Constructs a Decoder, initializing its members tm and config from the
    // arguments of the same name. No other work is done here.
    Decoder::Decoder(TransitionMachine & tm, Config & config)
      : tm(tm),
        config(config)
    {}
    
    // Exception type whose what() reports "End of Decode".
    // Uses `noexcept override` instead of the dynamic exception specification
    // `throw()`, which is deprecated since C++11 and removed in C++20.
    struct EndOfDecode : public std::exception
    {
      const char * what() const noexcept override
      {
        return "End of Decode";
      }
    };
    
    // Exception type whose what() reports "No More Actions".
    // Uses `noexcept override` instead of the dynamic exception specification
    // `throw()`, which is deprecated since C++11 and removed in C++20.
    struct NoMoreActions : public std::exception
    {
      const char * what() const noexcept override
      {
        return "No More Actions";
      }
    };
    
    // Returns whether config has the tape named by
    // ProgramParameters::sequenceDelimiterTape. Judging by the name, that tape
    // carries end-of-sequence predictions (NOTE(review): confirm).
    bool EOSisPredicted(Config & config)
    {
      auto & delimiterTapeName = ProgramParameters::sequenceDelimiterTape;
      const bool hasDelimiterTape = config.hasTape(delimiterTapeName);
      return hasDelimiterTape;
    }
    
    /**
     * Adds to `errors` an entry comparing the classifier's chosen action with a
     * zero-cost (gold) action for the current configuration.
     *
     * Nothing is recorded unless all of these hold:
     *  - classifier->needsTrain() is true;
     *  - ProgramParameters::errorAnalysis is set;
     *  - ProgramParameters::classifierName is empty or equals classifier->name.
     *
     * The gold action is `action` itself when it is one of the zero-cost actions.
     * Otherwise it is the first zero-cost action.
     *
     * @param config          Current configuration, passed to all cost/oracle queries.
     * @param classifier      Classifier that produced weightedActions.
     * @param weightedActions Scored candidate actions; stored in the error entry
     *                        and dumped to stderr on failure.
     * @param action          The action that was chosen.
     * @param errors          Collection that receives the new entry.
     *
     * If the classifier reports no zero-cost action, prints the configuration and
     * the cost explanation of every candidate to stderr, then calls exit(1).
     */
    void checkAndRecordError(Config & config, Classifier * classifier, Classifier::WeightedActions & weightedActions, std::string & action, Errors & errors)
    {
      const bool analysisEnabled = classifier->needsTrain()
        && ProgramParameters::errorAnalysis
        && (classifier->name == ProgramParameters::classifierName || ProgramParameters::classifierName.empty());
      if (!analysisEnabled)
        return;

      auto zeroCostActions = classifier->getZeroCostActions(config);
      if (zeroCostActions.empty())
      {
        fprintf(stderr, "ERROR (%s) : could not find zero cost action for classifier \'%s\'. Aborting.\n", ERRINFO, classifier->name.c_str());
        config.printForDebug(stderr);
        for (auto & a : weightedActions)
        {
          fprintf(stderr, "%s : ", a.second.second.c_str());
          Oracle::explainCostOfAction(stderr, config, a.second.second);
        }
        exit(1);
      }

      // Prefer the predicted action as gold when it is itself zero-cost, so that
      // a correct prediction is recorded as matching its gold action.
      std::string oAction = zeroCostActions[0];
      const auto match = std::find(zeroCostActions.begin(), zeroCostActions.end(), action);
      if (match != zeroCostActions.end())
        oAction = *match;

      int actionCost = classifier->getActionCost(config, action);
      int linkLengthPrediction = ActionBank::getLinkLength(config, action);
      int linkLengthGold = ActionBank::getLinkLength(config, oAction);
      errors.add({action, oAction, weightedActions, actionCost, linkLengthPrediction, linkLengthGold});
    }
    
    /**
     * In interactive mode, prints decoding progress to stderr. The line is
     * terminated by '\r', so successive reports overwrite each other.
     *
     * Progress is the head position of `config` as a percentage of
     * ProgramParameters::tapeSize. A report is printed only for a non-zero head
     * position that is a multiple of nbActionsCutoff, or that lies within
     * nbActionsCutoff of the end of the tape.
     *
     * @param config          Configuration whose head position is the progress.
     * @param currentSpeed    Actions per second; truncated to int for display.
     * @param nbActionsCutoff Reporting period. 0 disables throttling (a report
     *                        on every non-zero step); it used to trigger an
     *                        integer division by zero.
     */
    void printAdvancement(Config & config, float currentSpeed, int nbActionsCutoff)
    {
      if (!ProgramParameters::interactive)
        return;

      int totalSize = ProgramParameters::tapeSize;
      int steps = config.getHead();
      if (!steps)
        return;

      // `steps % 0` is undefined behaviour, so test the zero period first.
      const bool shouldPrint = nbActionsCutoff == 0
        || steps % nbActionsCutoff == 0
        || totalSize - steps < nbActionsCutoff;
      if (shouldPrint)
        fprintf(stderr, "Decode : %.2f%%  speed : %s actions/s\r", 100.0*steps/totalSize, int2humanStr(static_cast<int>(currentSpeed)).c_str());
    }