Skip to content
Snippets Groups Projects
Commit 6c8c8e4a authored by Franck Dary's avatar Franck Dary
Browse files

Started implementation of BeamSearch

parent 535710e6
No related branches found
No related tags found
No related merge requests found
......@@ -152,6 +152,19 @@ void applyActionAndTakeTransition(TransitionMachine & tm, Action * action, Confi
tm.takeTransition(transition);
}
struct BeamNode
{
Classifier::WeightedActions weightedActions;
double totalEntropy;
TransitionMachine tm;
Config config;
BeamNode(TransitionMachine & tm, Config & config) : tm(tm), config(config)
{
totalEntropy = 0.0;
}
};
void Decoder::decode()
{
float entropyAccumulator = 0.0;
......@@ -164,8 +177,23 @@ void Decoder::decode()
float currentSpeed = 0.0;
auto pastTime = std::chrono::high_resolution_clock::now();
while (!config.isFinal())
std::vector< std::shared_ptr<BeamNode> > beam;
beam.emplace_back(new BeamNode(tm, config));
auto sortBeam = [&beam]()
{
std::sort(beam.begin(), beam.end(), [](std::shared_ptr<BeamNode> a, std::shared_ptr<BeamNode> b)
{
return a->totalEntropy < b->totalEntropy;
});
};
while (!beam[0]->config.isFinal())
{
for (auto & node : beam)
{
auto & tm = node->tm;
auto & config = node->config;
TransitionMachine::State * currentState = tm.getCurrentState();
Classifier * classifier = currentState->classifier;
config.setCurrentStateName(&currentState->name);
......@@ -191,11 +219,12 @@ void Decoder::decode()
computeAndRecordEntropy(config, weightedActions, entropyAccumulator);
computeAndPrintSequenceEntropy(config, justFlipped, errors, entropyAccumulator, nbActionsInSequence);
}
}
if (ProgramParameters::errorAnalysis)
errors.printStats();
else
config.printAsOutput(stdout);
beam[0]->config.printAsOutput(stdout);
fprintf(stderr, " \n");
}
......
......@@ -46,12 +46,12 @@ class Classifier
/// In train mode, the underlying neural network's parameters will be zero-initialized, instead of being read from a file.
bool trainMode;
/// @brief The FeatureModel of this Classifier.
std::unique_ptr<FeatureModel> fm;
std::shared_ptr<FeatureModel> fm;
/// @brief The ActionSet of this Classifier.
std::unique_ptr<ActionSet> as;
std::shared_ptr<ActionSet> as;
/// @brief The neural network used by this Classifier.
/// The neural network is only used for Classifier of type Prediction.
std::unique_ptr<NeuralNetwork> nn;
std::shared_ptr<NeuralNetwork> nn;
/// @brief A string describing the topology of the underlying neural network.
std::string topology;
/// @brief The oracle being used by this Classifier.
......
......@@ -60,9 +60,9 @@ class TransitionMachine
/// @brief Whether or not the program is in train mode.
bool trainMode;
/// @brief Store the Classifier by their name.
std::map< std::string, std::unique_ptr<Classifier> > str2classifier;
std::map< std::string, std::shared_ptr<Classifier> > str2classifier;
/// @brief Store the State by their name.
std::map< std::string, std::unique_ptr<State> > str2state;
std::map< std::string, std::shared_ptr<State> > str2state;
/// @brief Pointer to the initial State.
State * initialState;
/// @brief Pointer to the current State.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment