Select Git revision
Decoder.cpp 7.39 KiB
#include "Decoder.hpp"
#include "util.hpp"
#include "Error.hpp"
#include "ActionBank.hpp"
#include <chrono>
#include <memory>
/// Builds a decoder around the transition machine and configuration that
/// decode() will use as the starting point of its beam.
Decoder::Decoder(TransitionMachine & machine, Config & configuration) : tm(machine), config(configuration) {}
/// Thrown by getClassifierAction() when the predicted action is not appliable
/// while the head is on the last cell of the input: the stack has already been
/// emptied, and the decode loop catches this to skip the rest of the step for
/// that beam node.
///
/// what() is declared `noexcept override` instead of the deprecated dynamic
/// exception specification `throw()` (removed in C++20).
struct EndOfDecode : public std::exception
{
  const char * what() const noexcept override
  {
    return "End of Decode";
  }
};
/// Error-analysis hook run after each prediction.
///
/// Does nothing unless the classifier needs training, error analysis is
/// enabled, and the classifier is the one selected by
/// ProgramParameters::classifierName (an empty name selects every classifier).
/// Otherwise it asks the oracle for the zero-cost actions in `config` and
/// records in `errors` the predicted action, a gold action, the weighted
/// actions, the predicted action's cost and the link lengths of both actions.
///
/// If the oracle reports no zero-cost action, the configuration and the cost
/// explanation of every weighted action are dumped to stderr and the program
/// exits with status 1.
void checkAndRecordError(Config & config, Classifier * classifier, Classifier::WeightedActions & weightedActions, Action * action, Errors & errors)
{
  if (!classifier->needsTrain() || !ProgramParameters::errorAnalysis)
    return;
  if (classifier->name != ProgramParameters::classifierName && !ProgramParameters::classifierName.empty())
    return;

  auto zeroCostActions = classifier->getZeroCostActions(config);
  if (zeroCostActions.empty())
  {
    fprintf(stderr, "ERROR (%s) : could not find zero cost action for classifier \'%s\'. Aborting.\n", ERRINFO, classifier->name.c_str());
    config.printForDebug(stderr);
    for (auto & weighted : weightedActions)
    {
      auto & candidateName = weighted.second.second;
      fprintf(stderr, "%s : ", candidateName.c_str());
      Oracle::explainCostOfAction(stderr, config, candidateName);
    }
    exit(1);
  }

  // Gold action: the predicted one when it is itself zero-cost, otherwise the
  // first zero-cost action returned by the oracle.
  std::string goldAction = zeroCostActions[0];
  for (auto & zeroCost : zeroCostActions)
  {
    if (action->name == zeroCost)
    {
      goldAction = zeroCost;
      break;
    }
  }

  int predictedCost = classifier->getActionCost(config, action->name);
  int predictedLinkLength = ActionBank::getLinkLength(config, action->name);
  int goldLinkLength = ActionBank::getLinkLength(config, goldAction);
  errors.add({action->name, goldAction, weightedActions, predictedCost, predictedLinkLength, goldLinkLength});
}
/// Interactive progress report on stderr: percentage of tape 0's hypothesis
/// reached by the head, plus the current speed in actions/s. The line ends
/// with '\r' so successive reports overwrite each other. A report is printed
/// only when the head is non-zero and is either on a multiple of 200 or within
/// the last 200 cells.
void printAdvancement(Config & config, float currentSpeed)
{
  if (!ProgramParameters::interactive)
    return;

  const int total = config.tapes[0].hyp.size();
  const int done = config.head;
  const bool onStride = done % 200 == 0;
  const bool nearEnd = total - done < 200;
  if (done == 0 || !(onStride || nearEnd))
    return;

  const double percent = 100.0*done/total;
  fprintf(stderr, "Decode : %.2f%% speed : %s actions/s\r", percent, int2humanStr(static_cast<int>(currentSpeed)).c_str());
}
/// In debug mode, writes to `output` the configuration, the name of the
/// machine's current state and the classifier's weighted actions, followed by
/// a blank line. Does nothing otherwise.
void printDebugInfos(FILE * output, Config & config, TransitionMachine & tm, Classifier::WeightedActions & weightedActions)
{
  if (!ProgramParameters::debug)
    return;

  auto * state = tm.getCurrentState();
  config.printForDebug(output);
  fprintf(output, "State : \'%s\'\n", state->name.c_str());
  Classifier::printWeightedActions(output, weightedActions);
  fprintf(output, "\n");
}
/// Selects the action to apply from a classifier's weighted actions.
///
/// The first entry whose `first` flag is set is selected. If no entry is
/// flagged, the last entry is selected.
/// NOTE(review): `first` presumably marks entries the classifier considers
/// appliable; confirm against Classifier::WeightedActions.
///
/// If the selected action is not appliable in `config`:
///  - when the head is on the last cell of tape 0's reference, the stack is
///    emptied and EndOfDecode is thrown;
///  - otherwise an error is printed and the program exits with status 1.
/// An empty `weightedActions` is also reported as an error (it previously led
/// to an out-of-bounds access).
///
/// @return the selected action name, as a reference into `weightedActions`.
std::string & getClassifierAction(Config & config, Classifier::WeightedActions & weightedActions, Classifier * classifier)
{
  if (weightedActions.empty())
  {
    fprintf(stderr, "ERROR (%s) : classifier \'%s\' returned no weighted actions. Aborting\n", ERRINFO, classifier->name.c_str());
    exit(1);
  }

  // Select by index. The previous version re-assigned a `std::string &` bound
  // to weightedActions[0], which overwrote the top entry's action name with the
  // selected one and corrupted the list later recorded for error analysis.
  std::size_t chosen = weightedActions.size() - 1;
  for (std::size_t i = 0; i < weightedActions.size(); i++)
  {
    if (weightedActions[i].first)
    {
      chosen = i;
      break;
    }
  }

  std::string & predictedAction = weightedActions[chosen].second.second;
  Action * action = classifier->getAction(predictedAction);

  if (!action->appliable(config))
  {
    // The analysis is finished but the stack is not empty: flush it and let
    // the caller stop working on this node.
    if (config.head == (int)config.tapes[0].ref.size()-1)
    {
      while (!config.stackEmpty())
        config.stackPop();
      throw EndOfDecode();
    }

    fprintf(stderr, "ERROR (%s) : action \'%s\' is not appliable. Aborting\n", ERRINFO, predictedAction.c_str());
    exit(1);
  }

  return predictedAction;
}
/// Refreshes `currentSpeed` (actions per second) once at least
/// `nbActionsCutoff` actions have been counted since the last refresh. It then
/// restarts the measurement window (`pastTime` = now, `nbActions` = 0).
/// Below the cutoff, nothing is modified.
///
/// `pastTime` uses high_resolution_clock, the same clock the caller obtains it
/// from and the one read here. The previous signature hard-coded system_clock,
/// which only compiles where high_resolution_clock is an alias of system_clock.
/// Some standard libraries (e.g. libc++) alias it to steady_clock instead.
void computeSpeed(std::chrono::high_resolution_clock::time_point & pastTime, int & nbActions, int & nbActionsCutoff, float & currentSpeed)
{
  if (nbActions < nbActionsCutoff)
    return;

  const auto now = std::chrono::high_resolution_clock::now();
  const double seconds = std::chrono::duration<double>(now - pastTime).count();
  // A zero (coarse clock) or negative (clock adjusted) interval would give an
  // infinite or negative speed; keep the previous value in that case.
  if (seconds > 0.0)
    currentSpeed = nbActions / seconds;
  pastTime = now;
  nbActions = 0;
}
/// Sequence-boundary bookkeeping, active only when entropy printing or error
/// analysis is enabled.
///
/// It looks at the cell just before the head on the sequence-delimiter tape.
/// When that cell is not the sequence delimiter, `justFlipped` is cleared.
/// When it is the delimiter and `justFlipped` is still clear, the sequence is
/// closed exactly once:
///  - `justFlipped` is set;
///  - a new error-analysis sequence is started;
///  - the accumulated entropy is turned into a per-action mean, printed if
///    requested, and both counters are reset.
void computeAndPrintSequenceEntropy(Config & config, bool & justFlipped, Errors & errors, float & entropyAccumulator, int & nbActionsInSequence)
{
  const bool tracking = ProgramParameters::printEntropy || ProgramParameters::errorAnalysis;
  if (!tracking || config.head < 1)
    return;

  const bool onDelimiter = config.getTape(ProgramParameters::sequenceDelimiterTape)[config.head-1] == ProgramParameters::sequenceDelimiter;
  if (!onDelimiter)
  {
    justFlipped = false;
    return;
  }
  if (justFlipped)
    return;

  justFlipped = true;
  errors.newSequence();
  entropyAccumulator /= nbActionsInSequence;
  nbActionsInSequence = 0;
  if (ProgramParameters::printEntropy)
    fprintf(stderr, "Entropy : %.2f\n", entropyAccumulator);
  entropyAccumulator = 0.0;
}
/// Computes the entropy of this step's weighted actions, appends it to the
/// configuration's entropy history and adds it to the running accumulator.
void computeAndRecordEntropy(Config & config, Classifier::WeightedActions & weightedActions, float & entropyAccumulator)
{
  const float stepEntropy = Classifier::computeEntropy(weightedActions);
  config.addToEntropyHistory(stepEntropy);
  entropyAccumulator += stepEntropy;
}
/// Applies `action` to `config` and moves the transition machine along the
/// transition looked up by the action's name. Before applying, the action is
/// given that transition's head movement and the name of the state it is
/// applied from.
void applyActionAndTakeTransition(TransitionMachine & tm, Action * action, Config & config)
{
  TransitionMachine::Transition * edge = tm.getTransition(action->name);
  auto & sourceStateName = tm.getCurrentState()->name;
  action->setInfos(edge->headMvt, sourceStateName);
  action->apply(config);
  tm.takeTransition(edge);
}
/// One hypothesis of the decoding beam. It owns copies of the transition
/// machine and configuration it was built from, so hypotheses evolve
/// independently. `totalEntropy` starts at 0.
struct BeamNode
{
  Classifier::WeightedActions weightedActions;
  double totalEntropy;
  TransitionMachine tm;
  Config config;

  BeamNode(TransitionMachine & machine, Config & configuration)
    : totalEntropy(0.0), tm(machine), config(configuration)
  {
  }
};
void Decoder::decode()
{
float entropyAccumulator = 0.0;
int nbActionsInSequence = 0;
bool justFlipped = false;
Errors errors;
errors.newSequence();
int nbActions = 0;
int nbActionsCutoff = 200;
float currentSpeed = 0.0;
auto pastTime = std::chrono::high_resolution_clock::now();
std::vector< std::shared_ptr<BeamNode> > beam;
beam.emplace_back(new BeamNode(tm, config));
auto sortBeam = [&beam]()
{
std::sort(beam.begin(), beam.end(), [](std::shared_ptr<BeamNode> a, std::shared_ptr<BeamNode> b)
{
return a->totalEntropy < b->totalEntropy;
});
};
while (!beam[0]->config.isFinal())
{
for (auto & node : beam)
{
auto & tm = node->tm;
auto & config = node->config;
TransitionMachine::State * currentState = tm.getCurrentState();
Classifier * classifier = currentState->classifier;
config.setCurrentStateName(¤tState->name);
Dict::currentClassifierName = classifier->name;
auto weightedActions = classifier->weightActions(config);
printAdvancement(config, currentSpeed);
printDebugInfos(stderr, config, tm, weightedActions);
std::string predictedAction;
try {predictedAction = getClassifierAction(config, weightedActions, classifier);}
catch(EndOfDecode &) {continue;};
Action * action = classifier->getAction(predictedAction);
checkAndRecordError(config, classifier, weightedActions, action, errors);
applyActionAndTakeTransition(tm, action, config);
nbActionsInSequence++;
nbActions++;
computeSpeed(pastTime, nbActions, nbActionsCutoff, currentSpeed);
computeAndRecordEntropy(config, weightedActions, entropyAccumulator);
computeAndPrintSequenceEntropy(config, justFlipped, errors, entropyAccumulator, nbActionsInSequence);
}
}
if (ProgramParameters::errorAnalysis)
errors.printStats();
else
beam[0]->config.printAsOutput(stdout);
fprintf(stderr, " \n");
}