Skip to content
Snippets Groups Projects
Commit 535710e6 authored by Franck Dary's avatar Franck Dary
Browse files

Refactored code of Decoder

parent 7d08c2a0
Branches
No related tags found
No related merge requests found
...@@ -9,34 +9,43 @@ Decoder::Decoder(TransitionMachine & tm, Config & config) ...@@ -9,34 +9,43 @@ Decoder::Decoder(TransitionMachine & tm, Config & config)
{ {
} }
void Decoder::decode() struct EndOfDecode : public std::exception
{ {
float entropyAccumulator = 0.0; const char * what() const throw()
int nbActionsInSequence = 0;
bool justFlipped = false;
Errors errors;
errors.newSequence();
int nbActions = 0;
int nbActionsCutoff = 200;
float currentSpeed = 0.0;
auto pastTime = std::chrono::high_resolution_clock::now();
while (!config.isFinal())
{ {
TransitionMachine::State * currentState = tm.getCurrentState(); return "End of Decode";
Classifier * classifier = currentState->classifier; }
config.setCurrentStateName(&currentState->name); };
Dict::currentClassifierName = classifier->name;
if (ProgramParameters::debug) void checkAndRecordError(Config & config, Classifier * classifier, Classifier::WeightedActions & weightedActions, Action * action, Errors & errors)
{
if (classifier->needsTrain() && ProgramParameters::errorAnalysis && (classifier->name == ProgramParameters::classifierName || ProgramParameters::classifierName.empty()))
{
auto zeroCostActions = classifier->getZeroCostActions(config);
if (zeroCostActions.empty())
{ {
fprintf(stderr, "ERROR (%s) : could not find zero cost action for classifier \'%s\'. Aborting.\n", ERRINFO, classifier->name.c_str());
config.printForDebug(stderr); config.printForDebug(stderr);
fprintf(stderr, "State : \'%s\'\n", currentState->name.c_str()); for (auto & a : weightedActions)
{
fprintf(stderr, "%s : ", a.second.second.c_str());
Oracle::explainCostOfAction(stderr, config, a.second.second);
}
exit(1);
}
std::string oAction = zeroCostActions[0];
for (auto & s : zeroCostActions)
if (action->name == s)
oAction = s;
int actionCost = classifier->getActionCost(config, action->name);
int linkLengthPrediction = ActionBank::getLinkLength(config, action->name);
int linkLengthGold = ActionBank::getLinkLength(config, oAction);
errors.add({action->name, oAction, weightedActions, actionCost, linkLengthPrediction, linkLengthGold});
}
} }
auto weightedActions = classifier->weightActions(config); void printAdvancement(Config & config, float currentSpeed)
{
// Print current iter advancement in percentage
if (ProgramParameters::interactive) if (ProgramParameters::interactive)
{ {
int totalSize = config.tapes[0].hyp.size(); int totalSize = config.tapes[0].hyp.size();
...@@ -44,13 +53,24 @@ void Decoder::decode() ...@@ -44,13 +53,24 @@ void Decoder::decode()
if (steps && (steps % 200 == 0 || totalSize-steps < 200)) if (steps && (steps % 200 == 0 || totalSize-steps < 200))
fprintf(stderr, "Decode : %.2f%% speed : %s actions/s\r", 100.0*steps/totalSize, int2humanStr((int)currentSpeed).c_str()); fprintf(stderr, "Decode : %.2f%% speed : %s actions/s\r", 100.0*steps/totalSize, int2humanStr((int)currentSpeed).c_str());
} }
}
void printDebugInfos(FILE * output, Config & config, TransitionMachine & tm, Classifier::WeightedActions & weightedActions)
{
if (ProgramParameters::debug) if (ProgramParameters::debug)
{ {
Classifier::printWeightedActions(stderr, weightedActions); TransitionMachine::State * currentState = tm.getCurrentState();
fprintf(stderr, "\n");
config.printForDebug(output);
fprintf(output, "State : \'%s\'\n", currentState->name.c_str());
Classifier::printWeightedActions(output, weightedActions);
fprintf(output, "\n");
}
} }
std::string & getClassifierAction(Config & config, Classifier::WeightedActions & weightedActions, Classifier * classifier)
{
std::string & predictedAction = weightedActions[0].second.second; std::string & predictedAction = weightedActions[0].second.second;
Action * action = classifier->getAction(predictedAction); Action * action = classifier->getAction(predictedAction);
...@@ -70,7 +90,7 @@ void Decoder::decode() ...@@ -70,7 +90,7 @@ void Decoder::decode()
{ {
while (!config.stackEmpty()) while (!config.stackEmpty())
config.stackPop(); config.stackPop();
continue; throw EndOfDecode();
} }
else else
{ {
...@@ -79,44 +99,11 @@ void Decoder::decode() ...@@ -79,44 +99,11 @@ void Decoder::decode()
} }
} }
if (classifier->needsTrain() && ProgramParameters::errorAnalysis && (classifier->name == ProgramParameters::classifierName || ProgramParameters::classifierName.empty())) return predictedAction;
{
auto zeroCostActions = classifier->getZeroCostActions(config);
if (zeroCostActions.empty())
{
fprintf(stderr, "ERROR (%s) : could not find zero cost action for classifier \'%s\'. Aborting.\n", ERRINFO, classifier->name.c_str());
config.printForDebug(stderr);
for (auto & a : weightedActions)
{
fprintf(stderr, "%s : ", a.second.second.c_str());
Oracle::explainCostOfAction(stderr, config, a.second.second);
}
exit(1);
}
std::string oAction = zeroCostActions[0];
for (auto & s : zeroCostActions)
if (action->name == s)
oAction = s;
int actionCost = classifier->getActionCost(config, action->name);
int linkLengthPrediction = ActionBank::getLinkLength(config, action->name);
int linkLengthGold = ActionBank::getLinkLength(config, oAction);
errors.add({action->name, oAction, weightedActions, actionCost, linkLengthPrediction, linkLengthGold});
} }
TransitionMachine::Transition * transition = tm.getTransition(predictedAction); void computeSpeed(std::chrono::time_point<std::chrono::system_clock> & pastTime, int & nbActions, int & nbActionsCutoff, float & currentSpeed)
{
action->setInfos(transition->headMvt, currentState->name);
action->apply(config);
tm.takeTransition(transition);
float entropy = Classifier::computeEntropy(weightedActions);
config.addToEntropyHistory(entropy);
nbActionsInSequence++;
nbActions++;
if (nbActions >= nbActionsCutoff) if (nbActions >= nbActionsCutoff)
{ {
auto actualTime = std::chrono::high_resolution_clock::now(); auto actualTime = std::chrono::high_resolution_clock::now();
...@@ -128,9 +115,10 @@ void Decoder::decode() ...@@ -128,9 +115,10 @@ void Decoder::decode()
nbActions = 0; nbActions = 0;
} }
}
entropyAccumulator += entropy; void computeAndPrintSequenceEntropy(Config & config, bool & justFlipped, Errors & errors, float & entropyAccumulator, int & nbActionsInSequence)
{
if (ProgramParameters::printEntropy || ProgramParameters::errorAnalysis) if (ProgramParameters::printEntropy || ProgramParameters::errorAnalysis)
{ {
if (config.head >= 1 && config.getTape(ProgramParameters::sequenceDelimiterTape)[config.head-1] != ProgramParameters::sequenceDelimiter) if (config.head >= 1 && config.getTape(ProgramParameters::sequenceDelimiterTape)[config.head-1] != ProgramParameters::sequenceDelimiter)
...@@ -147,7 +135,61 @@ void Decoder::decode() ...@@ -147,7 +135,61 @@ void Decoder::decode()
entropyAccumulator = 0.0; entropyAccumulator = 0.0;
} }
} }
}
void computeAndRecordEntropy(Config & config, Classifier::WeightedActions & weightedActions, float & entropyAccumulator)
{
float entropy = Classifier::computeEntropy(weightedActions);
config.addToEntropyHistory(entropy);
entropyAccumulator += entropy;
}
void applyActionAndTakeTransition(TransitionMachine & tm, Action * action, Config & config)
{
TransitionMachine::Transition * transition = tm.getTransition(action->name);
action->setInfos(transition->headMvt, tm.getCurrentState()->name);
action->apply(config);
tm.takeTransition(transition);
}
void Decoder::decode()
{
float entropyAccumulator = 0.0;
int nbActionsInSequence = 0;
bool justFlipped = false;
Errors errors;
errors.newSequence();
int nbActions = 0;
int nbActionsCutoff = 200;
float currentSpeed = 0.0;
auto pastTime = std::chrono::high_resolution_clock::now();
while (!config.isFinal())
{
TransitionMachine::State * currentState = tm.getCurrentState();
Classifier * classifier = currentState->classifier;
config.setCurrentStateName(&currentState->name);
Dict::currentClassifierName = classifier->name;
auto weightedActions = classifier->weightActions(config);
printAdvancement(config, currentSpeed);
printDebugInfos(stderr, config, tm, weightedActions);
std::string predictedAction;
try {predictedAction = getClassifierAction(config, weightedActions, classifier);}
catch(EndOfDecode &) {continue;};
Action * action = classifier->getAction(predictedAction);
checkAndRecordError(config, classifier, weightedActions, action, errors);
applyActionAndTakeTransition(tm, action, config);
nbActionsInSequence++;
nbActions++;
computeSpeed(pastTime, nbActions, nbActionsCutoff, currentSpeed);
computeAndRecordEntropy(config, weightedActions, entropyAccumulator);
computeAndPrintSequenceEntropy(config, justFlipped, errors, entropyAccumulator, nbActionsInSequence);
} }
if (ProgramParameters::errorAnalysis) if (ProgramParameters::errorAnalysis)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment