Select Git revision
compute_results.py
-
Charly Lamothe authored
Decoder.cpp 13.96 KiB
#include "Decoder.hpp"
#include "util.hpp"
#include "Error.hpp"
#include "ActionBank.hpp"
#include "ProgramOutput.hpp"
#include <algorithm>
#include <chrono>
#include <memory>
/// Initialises the decoder's `tm` and `config` members from the transition
/// machine and configuration supplied by the caller; no other setup is done.
Decoder::Decoder(TransitionMachine & tm, Config & config) :
  tm(tm),
  config(config)
{}
/// Thrown by getClassifierAction() once the tapes are exhausted; the parser
/// stack has already been emptied when it is raised.
struct EndOfDecode : public std::exception
{
  // `noexcept` replaces the dynamic exception specification `throw()`,
  // which is deprecated since C++11 and removed in C++20.
  const char * what() const noexcept override
  {
    return "End of Decode";
  }
};
/// Thrown by getClassifierAction() when fewer valid actions exist than the
/// requested index and the tapes are not yet exhausted.
struct NoMoreActions : public std::exception
{
  // `noexcept` replaces the dynamic exception specification `throw()`,
  // which is deprecated since C++11 and removed in C++20.
  const char * what() const noexcept override
  {
    return "No More Actions";
  }
};
/// Tells whether `config` has the tape named by
/// ProgramParameters::sequenceDelimiterTape. Sequence-boundary handling in
/// this file is skipped when that tape is absent.
bool EOSisPredicted(Config & config)
{
  const bool hasDelimiterTape = config.hasTape(ProgramParameters::sequenceDelimiterTape);
  return hasDelimiterTape;
}
/// Records, in `errors`, how `action` (the action chosen by `classifier`)
/// compares with a zero-cost reference action, when error analysis applies.
///
/// Recording happens only if `classifier->needsTrain()` is true,
/// ProgramParameters::errorAnalysis is set, and the classifier is the one
/// named by ProgramParameters::classifierName (an empty name selects all).
/// The reference action is `action` itself when it has zero cost, otherwise
/// the first zero-cost action. If the classifier reports no zero-cost action,
/// the configuration and the cost of every weighted action are dumped to
/// stderr and the program exits with status 1.
void checkAndRecordError(Config & config, Classifier * classifier, Classifier::WeightedActions & weightedActions, std::string & action, Errors & errors)
{
  if (!classifier->needsTrain() || !ProgramParameters::errorAnalysis)
    return;

  const bool isSelectedClassifier = classifier->name == ProgramParameters::classifierName || ProgramParameters::classifierName.empty();
  if (!isSelectedClassifier)
    return;

  auto zeroCostActions = classifier->getZeroCostActions(config);
  if (zeroCostActions.empty())
  {
    fprintf(stderr, "ERROR (%s) : could not find zero cost action for classifier \'%s\'. Aborting.\n", ERRINFO, classifier->name.c_str());
    config.printForDebug(stderr);
    for (auto & candidate : weightedActions)
    {
      fprintf(stderr, "%s : ", candidate.second.second.c_str());
      Oracle::explainCostOfAction(stderr, config, candidate.second.second);
    }
    exit(1);
  }

  // Use the predicted action as the reference when it is itself zero-cost.
  bool predictionIsZeroCost = false;
  for (auto & zeroCostAction : zeroCostActions)
  {
    if (action == zeroCostAction)
    {
      predictionIsZeroCost = true;
      break;
    }
  }
  std::string goldAction = predictionIsZeroCost ? action : zeroCostActions[0];

  const int predictedCost = classifier->getActionCost(config, action);
  const int predictedLinkLength = ActionBank::getLinkLength(config, action);
  const int goldLinkLength = ActionBank::getLinkLength(config, goldAction);
  errors.add({action, goldAction, weightedActions, predictedCost, predictedLinkLength, goldLinkLength});
}
/// In interactive mode, rewrites the current stderr line with the decoding
/// progress (head position as a percentage of ProgramParameters::tapeSize)
/// and the current speed in actions per second.
///
/// Nothing is printed at head position 0. Otherwise the line is refreshed
/// every `nbActionsCutoff` positions, and at every position once fewer than
/// `nbActionsCutoff` positions remain before tapeSize.
void printAdvancement(Config & config, float currentSpeed, int nbActionsCutoff)
{
  if (!ProgramParameters::interactive)
    return;

  const int totalSize = ProgramParameters::tapeSize;
  const int steps = config.getHead();
  if (steps == 0)
    return;

  if (steps % nbActionsCutoff == 0 || totalSize - steps < nbActionsCutoff)
  {
    const double percentDone = 100.0*steps/totalSize;
    fprintf(stderr, "Decode : %.2f%% speed : %s actions/s\r", percentDone, int2humanStr((int)currentSpeed).c_str());
  }
}
/// When ProgramParameters::debug is set, writes to `output` the configuration
/// dump, the current state name of `tm` and the weighted actions, then a
/// newline. Does nothing otherwise.
void printDebugInfos(FILE * output, Config & config, TransitionMachine & tm, Classifier::WeightedActions & weightedActions)
{
  if (!ProgramParameters::debug)
    return;

  config.printForDebug(output);
  fprintf(output, "State : \'%s\'\n", tm.getCurrentState().c_str());
  Classifier::printWeightedActions(output, weightedActions);
  fprintf(output, "\n");
}
/// Selects the `index`-th valid action (0-based, counting only entries of
/// `weightedActions` whose `.first` flag is set, in list order) and returns
/// its `(proba, name)` pair, taken from `.second.first` / `.second.second`.
///
/// @param config          configuration the action must be appliable to.
/// @param weightedActions scored actions; read only, never modified.
/// @param classifier      classifier used to resolve the action name.
/// @param index           rank of the wanted action among the valid ones.
/// @throws EndOfDecode    when the wanted action is missing or not appliable
///                        and config.endOfTapes() holds; the stack is emptied
///                        before throwing.
/// @throws NoMoreActions  when fewer than `index + 1` valid actions exist and
///                        the tapes are not exhausted.
/// Exits with status 1 if the selected action is not appliable while the
/// tapes are not exhausted.
std::pair<float,std::string> getClassifierAction(Config & config, Classifier::WeightedActions & weightedActions, Classifier * classifier, unsigned int index)
{
  // Keep copies, not references. The former code bound a reference to
  // weightedActions[0].second.second and assigned through it while scanning.
  // That overwrote the first entry's action name and corrupted
  // weightedActions for error recording and for later calls with a higher
  // index.
  std::string predictedAction;
  float proba = 0.0;
  bool found = false;
  unsigned int nbValidActions = 0;

  for (auto & weightedAction : weightedActions)
  {
    if (!weightedAction.first)
      continue;
    if (nbValidActions == index)
    {
      predictedAction = weightedAction.second.second;
      proba = weightedAction.second.first;
      found = true;
      break;
    }
    nbValidActions++;
  }

  // Only the selected action needs resolving. When nothing was found (this
  // includes an empty list), no Action is dereferenced.
  Action * action = found ? classifier->getAction(predictedAction) : nullptr;

  if (!found || !action->appliable(config))
  {
    // The analysis is finished but the stack may not be empty: drain it.
    if (config.endOfTapes())
    {
      while (!config.stackEmpty())
        config.stackPop();
      throw EndOfDecode();
    }

    if (!found)
      throw NoMoreActions();

    fprintf(stderr, "ERROR (%s) : action \'%s\' is not appliable. Aborting\n", ERRINFO, predictedAction.c_str());
    exit(1);
  }

  return {proba, predictedAction};
}
/// Refreshes `currentSpeed` (actions per second) once at least
/// `nbActionsCutoff` actions have been counted since `pastTime`, then starts
/// a new measurement window (`pastTime` = now, `nbActions` = 0). Does nothing
/// while fewer than `nbActionsCutoff` actions have been counted.
///
/// The function is templated on the clock so a time point is always compared
/// with the clock that produced it. Callers in this file pass
/// high_resolution_clock time points. The former signature hard-coded
/// system_clock, which only matches on standard libraries where the two are
/// the same type.
template <class Clock>
void computeSpeed(std::chrono::time_point<Clock> & pastTime, int & nbActions, int & nbActionsCutoff, float & currentSpeed)
{
  if (nbActions < nbActionsCutoff)
    return;

  const auto actualTime = Clock::now();
  const double seconds = std::chrono::duration<double>(actualTime - pastTime).count();

  // A zero or negative interval (coarse or adjusted clock) would yield an
  // infinite or negative speed; keep the previous estimate instead.
  if (seconds > 0.0)
    currentSpeed = static_cast<float>(nbActions / seconds);

  pastTime = actualTime;
  nbActions = 0;
}
/// Closes the current sequence when a sequence boundary is reached.
///
/// A boundary is detected when index -1 of the
/// ProgramParameters::sequenceDelimiterTape tape (presumably the cell just
/// before the head) equals ProgramParameters::sequenceDelimiter. On the first
/// call that sees the boundary, and only if ProgramParameters::printEntropy or
/// ProgramParameters::errorAnalysis is set, this function:
///  - starts a new sequence in `errors`,
///  - averages `entropyAccumulator` over `nbActionsInSequence` and prints it
///    when printEntropy is set,
///  - resets both counters.
///
/// `justFlipped` records that the current boundary has already been handled.
/// It is cleared as soon as index -1 no longer holds the delimiter. The former
/// code re-ran the closing logic for every action taken while still on the
/// boundary, producing repeated newSequence() calls and "Entropy" lines.
/// Does nothing when the delimiter tape is absent.
void computeAndPrintSequenceEntropy(Config & config, bool & justFlipped, Errors & errors, float & entropyAccumulator, int & nbActionsInSequence)
{
  if (!EOSisPredicted(config))
    return;

  bool boundaryJustReached = false;
  if (config.getHead() >= 1)
  {
    const bool onDelimiter = config.getTape(ProgramParameters::sequenceDelimiterTape)[-1] == ProgramParameters::sequenceDelimiter;
    if (!onDelimiter)
    {
      justFlipped = false;
    }
    else if (!justFlipped)
    {
      justFlipped = true;
      boundaryJustReached = true;
    }
  }

  if (!boundaryJustReached || !(ProgramParameters::printEntropy || ProgramParameters::errorAnalysis))
    return;

  errors.newSequence();
  // Guard against an empty sequence, which would otherwise produce NaN.
  if (nbActionsInSequence > 0)
    entropyAccumulator /= nbActionsInSequence;
  nbActionsInSequence = 0;
  if (ProgramParameters::printEntropy)
    fprintf(stderr, "Entropy : %.2f\n", entropyAccumulator);
  entropyAccumulator = 0.0;
}
/// Computes the entropy of `weightedActions` with Classifier::computeEntropy,
/// appends it to the entropy history of `config`, and adds it to
/// `entropyAccumulator`.
void computeAndRecordEntropy(Config & config, Classifier::WeightedActions & weightedActions, float & entropyAccumulator)
{
  const float actionEntropy = Classifier::computeEntropy(weightedActions);
  entropyAccumulator += actionEntropy;
  config.addToEntropyHistory(actionEntropy);
}
/// Applies the action named `actionName`, looked up in the current classifier
/// of `tm`, to `config`, then takes the transition of `tm` labelled with that
/// name. Before the action is applied, it receives the transition's head
/// movement and the state it is applied from.
void applyActionAndTakeTransition(TransitionMachine & tm, const std::string & actionName, Config & config)
{
  auto * currentClassifier = tm.getCurrentClassifier();
  Action * chosenAction = currentClassifier->getAction(actionName);
  auto * matchingTransition = tm.getTransition(actionName);

  chosenAction->setInfos(matchingTransition->headMvt, tm.getCurrentState());
  chosenAction->apply(config);

  tm.takeTransition(matchingTransition);
}
/// Decodes the input starting from a freshly reset configuration.
/// Uses beam search when ProgramParameters::beamSize is greater than 1 and
/// greedy decoding otherwise, then prints ProgramOutput::instance to stdout.
void Decoder::decode()
{
  config.reset();

  const bool useBeam = ProgramParameters::beamSize > 1;
  if (useBeam)
  {
    decodeBeam();
  }
  else
  {
    decodeNoBeam();
  }

  ProgramOutput::instance.print(stdout);
}
void Decoder::decodeNoBeam()
{
float entropyAccumulator = 0.0;
int nbActionsInSequence = 0;
bool justFlipped = false;
Errors errors;
errors.newSequence();
int nbActions = 0;
int nbActionsCutoff = 200;
float currentSpeed = 0.0;
auto pastTime = std::chrono::high_resolution_clock::now();
FILE * outputFile = stdout;
config.setOutputFile(outputFile);
while (!config.isFinal())
{
config.setCurrentStateName(tm.getCurrentState());
Dict::currentClassifierName = tm.getCurrentClassifier()->name;
auto weightedActions = tm.getCurrentClassifier()->weightActions(config);
printAdvancement(config, currentSpeed, nbActionsCutoff);
printDebugInfos(stderr, config, tm, weightedActions);
std::pair<float,std::string> predictedAction;
try {predictedAction = getClassifierAction(config, weightedActions, tm.getCurrentClassifier(), 0);}
catch(EndOfDecode &) {continue;}
catch(NoMoreActions &) {continue;};
checkAndRecordError(config, tm.getCurrentClassifier(), weightedActions, predictedAction.second, errors);
applyActionAndTakeTransition(tm, predictedAction.second, config);
nbActionsInSequence++;
nbActions++;
computeSpeed(pastTime, nbActions, nbActionsCutoff, currentSpeed);
computeAndRecordEntropy(config, weightedActions, entropyAccumulator);
computeAndPrintSequenceEntropy(config, justFlipped, errors, entropyAccumulator, nbActionsInSequence);
}
if (ProgramParameters::errorAnalysis)
errors.printStats();
config.printTheRest();
if (ProgramParameters::interactive)
fprintf(stderr, " \n");
}
/// One hypothesis of the beam search. Each node holds its own copy of the
/// transition machine and configuration, the action that produced it, and the
/// bookkeeping used to detect sequence boundaries.
struct BeamNode
{
  Classifier::WeightedActions weightedActions;
  TransitionMachine tm;
  Config config;
  std::string action;      // action applied to reach this node ("" for the root)
  int nbActions;           // actions taken since the last reset
  bool justFlipped;        // set by setFlipped() when a boundary was just crossed
  int lastFlippedIndex;    // head position of the last boundary crossing

  /// Ranking score: the configuration's accumulated entropy (the sum of the
  /// `proba` values added by the child constructor, plus any initial value)
  /// divided by nbActions. Returns 0 when no action has been taken.
  double getEntropy()
  {
    if (nbActions == 0)
      return 0.0;
    return config.getEntropy() / nbActions;
  }

  /// Sets justFlipped when the delimiter tape exists, the head moved past
  /// the last recorded crossing, and index -1 of the delimiter tape holds
  /// ProgramParameters::sequenceDelimiter. Records the new crossing position
  /// in that case; clears justFlipped otherwise.
  void setFlipped()
  {
    if (EOSisPredicted(config) && config.getHead() > lastFlippedIndex && config.getTape(ProgramParameters::sequenceDelimiterTape)[-1] == ProgramParameters::sequenceDelimiter)
    {
      justFlipped = true;
      lastFlippedIndex = config.getHead();
    }
    else
    {
      justFlipped = false;
    }
  }

  /// Root node: copies `machine` and `initialConfig`, disables the copy's
  /// output and resets its entropy.
  BeamNode(TransitionMachine & machine, Config & initialConfig) : tm(machine), config(initialConfig)
  {
    justFlipped = false;
    nbActions = 0;
    lastFlippedIndex = 0;
    // The parameters are named differently from the members so that these
    // calls reach the node's own copy. The former parameter `config`
    // shadowed the member, so the caller's Config was modified instead.
    config.setOutputFile(nullptr);
    config.setEntropy(0.0);
  }

  /// Child node: copies the parent's machine and configuration, remembers
  /// `action`, and adds `proba` to the configuration's entropy.
  BeamNode(BeamNode & other, const std::string & action, float proba) : tm(other.tm), config(other.config)
  {
    justFlipped = false;
    lastFlippedIndex = other.lastFlippedIndex;
    this->action = action;
    nbActions = other.nbActions + 1;
    config.setOutputFile(nullptr);
    config.addToEntropy(proba);
  }
};
void Decoder::decodeBeam()
{
float entropyAccumulator = 0.0;
int nbActionsInSequence = 0;
bool justFlipped = false;
Errors errors;
errors.newSequence();
int nbActions = 0;
int nbActionsCutoff = 200;
float currentSpeed = 0.0;
auto pastTime = std::chrono::high_resolution_clock::now();
FILE * outputFile = stdout;
std::vector< std::shared_ptr<BeamNode> > beam;
std::vector< std::shared_ptr<BeamNode> > otherBeam;
std::vector< std::shared_ptr<BeamNode> > justFlippedBeam;
beam.emplace_back(new BeamNode(tm, config));
auto sortBeam = [&beam]()
{
std::sort(beam.begin(), beam.end(), [](const std::shared_ptr<BeamNode> & a, const std::shared_ptr<BeamNode> & b)
{
return a->getEntropy() > b->getEntropy();
});
};
auto printBeam = [](std::vector< std::shared_ptr<BeamNode> > & beam)
{
for (auto & node : beam)
{
node->config.printForDebug(stderr);
fprintf(stderr, "action : %s\n", node->action.c_str());
fprintf(stderr, "nbActions : %d\n", node->nbActions);
fprintf(stderr, "justFlipped : %s\n", node->justFlipped ? "true" : "false");
fprintf(stderr, "lastFlippedIndex : %d\n", node->lastFlippedIndex);
fprintf(stderr, "--------------------------------------------------------------------------------\n");
}
};
bool endOfDecode = false;
while (endOfDecode == false)
{
otherBeam.clear();
bool mustContinue = false;
for (auto & node : beam)
{
node->config.setCurrentStateName(node->tm.getCurrentState());
Dict::currentClassifierName = node->tm.getCurrentClassifier()->name;
node->weightedActions = node->tm.getCurrentClassifier()->weightActions(node->config);
printAdvancement(node->config, currentSpeed, nbActionsCutoff);
unsigned int nbActionsMax = std::min(std::max(node->tm.getCurrentClassifier()->getNbActions(),(unsigned int)1),(unsigned int)ProgramParameters::nbChilds);
for (unsigned int actionIndex = 0; actionIndex < nbActionsMax; actionIndex++)
{
std::pair<float,std::string> predictedAction;
try {predictedAction = getClassifierAction(node->config, node->weightedActions, node->tm.getCurrentClassifier(), actionIndex);}
catch(EndOfDecode &) {mustContinue = true; break;}
catch(NoMoreActions &) {break;};
otherBeam.emplace_back(new BeamNode(*node.get(), predictedAction.second, predictedAction.first));
}
if (mustContinue)
break;
}
if (ProgramParameters::debug)
{
fprintf(stderr, "################################# Beam before sort #################################\n");
printBeam(otherBeam);
fprintf(stderr, "####################################################################################\n");
}
beam = otherBeam;
sortBeam();
beam.resize(std::min((int)beam.size(), ProgramParameters::beamSize));
if (beam.empty())
{
fprintf(stderr, "ERROR (%s) : beam is empty. Aborting.\n", ERRINFO);
exit(1);
}
if (ProgramParameters::debug)
{
fprintf(stderr, "################################# Beam after sort #################################\n");
printBeam(beam);
fprintf(stderr, "###################################################################################\n");
}
for (auto & node : beam)
node->config.setOutputFile(outputFile);
for (auto & node : beam)
{
config.setCurrentStateName(node->tm.getCurrentState());
Dict::currentClassifierName = node->tm.getCurrentClassifier()->name;
if (node.get() == beam.begin()->get())
checkAndRecordError(node->config, node->tm.getCurrentClassifier(), node->weightedActions, node->action, errors);
applyActionAndTakeTransition(node->tm, node->action, node->config);
if (node.get() == beam.begin()->get())
{
nbActionsInSequence++;
nbActions++;
computeSpeed(pastTime, nbActions, nbActionsCutoff, currentSpeed);
computeAndRecordEntropy(node->config, node->weightedActions, entropyAccumulator);
computeAndPrintSequenceEntropy(node->config, justFlipped, errors, entropyAccumulator, nbActionsInSequence);
}
node->setFlipped();
}
for (unsigned int i = 0; i < beam.size(); i++)
{
if (beam[i]->justFlipped)
{
justFlippedBeam.push_back(beam[i]);
beam[i] = beam[beam.size()-1];
beam.pop_back();
i--;
}
}
if ((int)justFlippedBeam.size() >= ProgramParameters::beamSize || (justFlippedBeam.size() && mustContinue))
{
if (mustContinue || justFlippedBeam[0]->config.endOfTapes())
endOfDecode = true;
beam = justFlippedBeam;
justFlippedBeam.clear();
sortBeam();
beam.resize(1);
beam[0]->config.setEntropy(0.0);
beam[0]->nbActions = 0;
}
if (!EOSisPredicted(beam[0]->config) && beam[0]->config.endOfTapes())
endOfDecode = true;
}
if (ProgramParameters::errorAnalysis)
errors.printStats();
for (auto node : beam)
{
node->config.setOutputFile(outputFile);
node->config.printTheRest();
}
if (ProgramParameters::interactive)
fprintf(stderr, " \n");
}