Skip to content
Snippets Groups Projects
Commit e98bb937 authored by Franck Dary's avatar Franck Dary
Browse files

First implementation of beam search

parent 03817c3c
No related branches found
No related tags found
No related merge requests found
......@@ -19,6 +19,13 @@ class Decoder
/// @brief The current configuration of the TransitionMachine
Config & config;
private :
/// @brief Fill bd using tm, with beam serach.
void decodeBeam();
/// @brief Fill bd using tm, without beam search.
void decodeNoBeam();
public :
/// @brief Use tm to fill bd.
......
......@@ -17,6 +17,14 @@ struct EndOfDecode : public std::exception
}
};
struct NoMoreActions : public std::exception
{
const char * what() const throw()
{
return "No More Actions";
}
};
void checkAndRecordError(Config & config, Classifier * classifier, Classifier::WeightedActions & weightedActions, Action * action, Errors & errors)
{
if (classifier->needsTrain() && ProgramParameters::errorAnalysis && (classifier->name == ProgramParameters::classifierName || ProgramParameters::classifierName.empty()))
......@@ -69,21 +77,28 @@ void printDebugInfos(FILE * output, Config & config, TransitionMachine & tm, Cla
}
}
std::string & getClassifierAction(Config & config, Classifier::WeightedActions & weightedActions, Classifier * classifier)
std::pair<float,std::string> getClassifierAction(Config & config, Classifier::WeightedActions & weightedActions, Classifier * classifier, unsigned int index)
{
std::string & predictedAction = weightedActions[0].second.second;
float proba = weightedActions[0].second.first;
Action * action = classifier->getAction(predictedAction);
unsigned int nbValidActions = 0;
for(unsigned int i = 0; i < weightedActions.size(); i++)
{
predictedAction = weightedActions[i].second.second;
proba = weightedActions[i].second.first;
action = classifier->getAction(predictedAction);
if(weightedActions[i].first)
{
nbValidActions++;
if (nbValidActions-1 == index)
break;
}
}
if(!action->appliable(config))
if(!action->appliable(config) || nbValidActions-1 != index)
{
// First case the analysis is finished but without an empty stack
if (config.head == (int)config.tapes[0].ref.size()-1)
......@@ -92,6 +107,10 @@ std::string & getClassifierAction(Config & config, Classifier::WeightedActions &
config.stackPop();
throw EndOfDecode();
}
else if (nbValidActions-1 != index)
{
throw NoMoreActions();
}
else
{
fprintf(stderr, "ERROR (%s) : action \'%s\' is not appliable. Aborting\n", ERRINFO, predictedAction.c_str());
......@@ -99,7 +118,7 @@ std::string & getClassifierAction(Config & config, Classifier::WeightedActions &
}
}
return predictedAction;
return {proba, predictedAction};
}
void computeSpeed(std::chrono::time_point<std::chrono::system_clock> & pastTime, int & nbActions, int & nbActionsCutoff, float & currentSpeed)
......@@ -158,14 +177,78 @@ struct BeamNode
double totalEntropy;
TransitionMachine tm;
Config config;
Action * action;
BeamNode(TransitionMachine & tm, Config & config) : tm(tm), config(config)
{
totalEntropy = 0.0;
}
BeamNode(BeamNode & other, Action * action, float proba) : tm(other.tm), config(other.config)
{
totalEntropy = other.totalEntropy + proba;
this->action = action;
}
};
void Decoder::decode()
{
if (ProgramParameters::beamSize > 1)
decodeBeam();
else
decodeNoBeam();
}
void Decoder::decodeNoBeam()
{
float entropyAccumulator = 0.0;
int nbActionsInSequence = 0;
bool justFlipped = false;
Errors errors;
errors.newSequence();
int nbActions = 0;
int nbActionsCutoff = 200;
float currentSpeed = 0.0;
auto pastTime = std::chrono::high_resolution_clock::now();
while (!config.isFinal())
{
TransitionMachine::State * currentState = tm.getCurrentState();
Classifier * classifier = currentState->classifier;
config.setCurrentStateName(&currentState->name);
Dict::currentClassifierName = classifier->name;
auto weightedActions = classifier->weightActions(config);
printAdvancement(config, currentSpeed);
printDebugInfos(stderr, config, tm, weightedActions);
std::pair<float,std::string> predictedAction;
try {predictedAction = getClassifierAction(config, weightedActions, classifier, 0);}
catch(EndOfDecode &) {continue;}
catch(NoMoreActions &) {continue;};
Action * action = classifier->getAction(predictedAction.second);
checkAndRecordError(config, classifier, weightedActions, action, errors);
applyActionAndTakeTransition(tm, action, config);
nbActionsInSequence++;
nbActions++;
computeSpeed(pastTime, nbActions, nbActionsCutoff, currentSpeed);
computeAndRecordEntropy(config, weightedActions, entropyAccumulator);
computeAndPrintSequenceEntropy(config, justFlipped, errors, entropyAccumulator, nbActionsInSequence);
}
if (ProgramParameters::errorAnalysis)
errors.printStats();
else
config.printAsOutput(stdout);
fprintf(stderr, " \n");
}
void Decoder::decodeBeam()
{
float entropyAccumulator = 0.0;
int nbActionsInSequence = 0;
......@@ -178,18 +261,22 @@ void Decoder::decode()
auto pastTime = std::chrono::high_resolution_clock::now();
std::vector< std::shared_ptr<BeamNode> > beam;
std::vector< std::shared_ptr<BeamNode> > otherBeam;
beam.emplace_back(new BeamNode(tm, config));
auto sortBeam = [&beam]()
{
std::sort(beam.begin(), beam.end(), [](std::shared_ptr<BeamNode> a, std::shared_ptr<BeamNode> b)
{
return a->totalEntropy < b->totalEntropy;
return a->totalEntropy > b->totalEntropy;
});
};
while (!beam[0]->config.isFinal())
{
otherBeam.clear();
bool mustContinue = false;
for (auto & node : beam)
{
auto & tm = node->tm;
......@@ -199,27 +286,58 @@ void Decoder::decode()
config.setCurrentStateName(&currentState->name);
Dict::currentClassifierName = classifier->name;
auto weightedActions = classifier->weightActions(config);
node->weightedActions = classifier->weightActions(config);
printAdvancement(config, currentSpeed);
printDebugInfos(stderr, config, tm, weightedActions);
printDebugInfos(stderr, config, tm, node->weightedActions);
std::string predictedAction;
try {predictedAction = getClassifierAction(config, weightedActions, classifier);}
catch(EndOfDecode &) {continue;};
Action * action = classifier->getAction(predictedAction);
unsigned int nbActionsMax = std::min(std::max(classifier->getNbActions(),(unsigned int)1),(unsigned int)ProgramParameters::nbChilds);
for (unsigned int actionIndex = 0; actionIndex < nbActionsMax; actionIndex++)
{
std::pair<float,std::string> predictedAction;
try {predictedAction = getClassifierAction(config, node->weightedActions, classifier, actionIndex);}
catch(EndOfDecode &) {mustContinue = true; break;}
catch(NoMoreActions &) {break;};
otherBeam.emplace_back(new BeamNode(*node.get(),classifier->getAction(predictedAction.second), predictedAction.first));
}
checkAndRecordError(config, classifier, weightedActions, action, errors);
if (mustContinue)
break;
}
applyActionAndTakeTransition(tm, action, config);
if (mustContinue)
continue;
beam = otherBeam;
sortBeam();
beam.resize(std::min((int)beam.size(), ProgramParameters::beamSize));
for (auto & node : beam)
{
auto & tm = node->tm;
auto & config = node->config;
TransitionMachine::State * currentState = tm.getCurrentState();
Classifier * classifier = currentState->classifier;
config.setCurrentStateName(&currentState->name);
Dict::currentClassifierName = classifier->name;
if (node.get() == beam.begin()->get())
{
checkAndRecordError(config, classifier, node->weightedActions, node->action, errors);
}
applyActionAndTakeTransition(tm, node->action, config);
if (node.get() == beam.begin()->get())
{
nbActionsInSequence++;
nbActions++;
computeSpeed(pastTime, nbActions, nbActionsCutoff, currentSpeed);
computeAndRecordEntropy(config, weightedActions, entropyAccumulator);
computeAndRecordEntropy(config, node->weightedActions, entropyAccumulator);
computeAndPrintSequenceEntropy(config, justFlipped, errors, entropyAccumulator, nbActionsInSequence);
}
}
}
if (ProgramParameters::errorAnalysis)
errors.printStats();
......
......@@ -64,7 +64,14 @@ po::options_description getOptionsDescription()
("classifier", po::value<std::string>()->default_value(""),
"Name of the monitored classifier, if not specified monitor everyone");
desc.add(req).add(opt).add(analysis);
po::options_description beam("Beam search related options");
beam.add_options()
("beamSize", po::value<int>()->default_value(1),
"Number of nodes to explore for each depth of the tree of all the possible configurations")
("nbChilds", po::value<int>()->default_value(3),
"Number of childs to consider for each explored node");
desc.add(req).add(opt).add(analysis).add(beam);
return desc;
}
......@@ -134,6 +141,8 @@ int main(int argc, char * argv[])
ProgramParameters::sequenceDelimiterTape = vm["sequenceDelimiterTape"].as<std::string>();
ProgramParameters::sequenceDelimiter = vm["sequenceDelimiter"].as<std::string>();
ProgramParameters::showFeatureRepresentation = vm["showFeatureRepresentation"].as<int>();
ProgramParameters::beamSize = vm["beamSize"].as<int>();
ProgramParameters::nbChilds = vm["nbChilds"].as<int>();
ProgramParameters::optimizer = "none";
std::string featureModels = vm["featureModels"].as<std::string>();
if (!featureModels.empty())
......
......@@ -61,6 +61,8 @@ struct ProgramParameters
static std::map<std::string,std::string> featureModelByClassifier;
static int nbErrorsToShow;
static int nbIndividuals;
static int beamSize;
static int nbChilds;
private :
......
......@@ -55,4 +55,6 @@ std::string ProgramParameters::loss;
std::map<std::string,std::string> ProgramParameters::featureModelByClassifier;
int ProgramParameters::nbErrorsToShow;
int ProgramParameters::nbIndividuals;
int ProgramParameters::beamSize;
int ProgramParameters::nbChilds;
......@@ -63,6 +63,10 @@ class ActionSet
///
/// @return A pointer to the Action.
Action * getAction(const std::string & name);
/// @brief Get the number of actions contained in this set.
///
/// @return The number of actions in this set.
unsigned int size();
};
#endif
......@@ -189,6 +189,10 @@ class Classifier
///
/// @return The index of the Action.
int getActionIndex(const std::string & action);
/// @brief Get the number of actions this classifier knows.
///
/// @return The number of actions.
unsigned int getNbActions();
};
#endif
......@@ -86,3 +86,8 @@ Action * ActionSet::getAction(const std::string & name)
return &actions[getActionIndex(name)];
}
unsigned int ActionSet::size()
{
return actions.size();
}
......@@ -314,3 +314,8 @@ NeuralNetwork * Classifier::createNeuralNetwork(const std::string & modelFilenam
return new MLP(modelFilename);
}
unsigned int Classifier::getNbActions()
{
return as->size();
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment