Skip to content
Snippets Groups Projects
Commit cc2b7c14 authored by Franck Dary's avatar Franck Dary
Browse files

Added a new option for decode, printEntropy that gives confidence of the...

Added a new option for decode, printEntropy that gives confidence of the classifier about its predictions for each sequence
parent 0c5192e9
No related branches found
No related tags found
No related merge requests found
......@@ -18,8 +18,6 @@ class Decoder
TransitionMachine & tm;
/// @brief The current configuration of the TransitionMachine
Config & config;
/// @brief is true, decode will print infos on stderr
bool debugMode;
public :
......@@ -28,10 +26,8 @@ class Decoder
/// At the start of the function, bd must contain the input.\n
/// At the end of the function, bd will be terminal.
/// @param tm The trained TransitionMachine
/// @param bd The BD we need to fill
/// @param config The current configuration of the TransitionMachine
/// @param debugMode If true, infos will be printed on stderr.
Decoder(TransitionMachine & tm, Config & config, bool debugMode);
Decoder(TransitionMachine & tm, Config & config);
/// @brief Fill bd using tm.
void decode();
};
......
#include "Decoder.hpp"
#include "util.hpp"
Decoder::Decoder(TransitionMachine & tm, Config & config, bool debugMode)
Decoder::Decoder(TransitionMachine & tm, Config & config)
: tm(tm), config(config)
{
this->debugMode = debugMode;
}
void Decoder::decode()
{
float entropyAccumulator = 0.0;
int nbActionsInSequence = 0;
bool justFlipped = false;
while (!config.isFinal())
{
TransitionMachine::State * currentState = tm.getCurrentState();
......@@ -16,7 +18,7 @@ void Decoder::decode()
config.setCurrentStateName(&currentState->name);
Dict::currentClassifierName = classifier->name;
if (debugMode)
if (ProgramParameters::debug)
{
config.printForDebug(stderr);
fprintf(stderr, "State : \'%s\'\n", currentState->name.c_str());
......@@ -24,7 +26,8 @@ void Decoder::decode()
auto weightedActions = classifier->weightActions(config);
if (debugMode)
if (ProgramParameters::debug)
{
Classifier::printWeightedActions(stderr, weightedActions);
fprintf(stderr, "\n");
......@@ -62,7 +65,32 @@ void Decoder::decode()
TransitionMachine::Transition * transition = tm.getTransition(predictedAction);
tm.takeTransition(transition);
config.moveHead(transition->headMvt);
if (ProgramParameters::printEntropy)
{
nbActionsInSequence++;
for (unsigned int i = 0; i < 2 && i < weightedActions.size(); i++)
{
auto it = weightedActions.begin() + i;
entropyAccumulator -= it->second.first - (it->second.first - weightedActions[0].second.first);
}
if (config.head >= 1 && config.getTape(ProgramParameters::sequenceDelimiterTape)[config.head-1] != ProgramParameters::sequenceDelimiter)
justFlipped = false;
if ((config.head >= 1 && config.getTape(ProgramParameters::sequenceDelimiterTape)[config.head-1] == ProgramParameters::sequenceDelimiter && !justFlipped))
{
justFlipped = true;
entropyAccumulator /= nbActionsInSequence;
nbActionsInSequence = 0;
fprintf(stderr, "Entropy : %.2f\n", entropyAccumulator);
entropyAccumulator = 0.0;
}
}
}
config.printAsOutput(stdout);
......
......@@ -37,6 +37,12 @@ po::options_description getOptionsDescription()
opt.add_options()
("help,h", "Produce this help message")
("debug,d", "Print infos on stderr")
("printEntropy", "Print entropy for each sequence")
("sequenceDelimiterTape", po::value<std::string>()->default_value("EOS"),
"The name of the buffer's tape that contains the delimiter token for a sequence")
("sequenceDelimiter", po::value<std::string>()->default_value("1"),
"The value of the token that act as a delimiter for sequences")
("lang", po::value<std::string>()->default_value("fr"),
"Language you are working with");
......@@ -100,7 +106,10 @@ int main(int argc, char * argv[])
ProgramParameters::input = vm["input"].as<std::string>();
ProgramParameters::mcdName = vm["mcd"].as<std::string>();
ProgramParameters::debug = vm.count("debug") == 0 ? false : true;
ProgramParameters::printEntropy = vm.count("printEntropy") == 0 ? false : true;
ProgramParameters::lang = vm["lang"].as<std::string>();
ProgramParameters::sequenceDelimiterTape = vm["sequenceDelimiterTape"].as<std::string>();
ProgramParameters::sequenceDelimiter = vm["sequenceDelimiter"].as<std::string>();
const char * MACAON_DIR = std::getenv("MACAON_DIR");
std::string slash = "/";
......@@ -116,7 +125,7 @@ int main(int argc, char * argv[])
Config config(bd);
config.readInput(ProgramParameters::input);
Decoder decoder(tapeMachine, config, ProgramParameters::debug);
Decoder decoder(tapeMachine, config);
decoder.decode();
......
......@@ -45,6 +45,9 @@ struct ProgramParameters
static int iterationSize;
static int nbTrain;
static bool randomEmbeddings;
static bool printEntropy;
static std::string sequenceDelimiterTape;
static std::string sequenceDelimiter;
private :
......
......@@ -38,5 +38,8 @@ int ProgramParameters::dynamicEpoch;
float ProgramParameters::dynamicProbability;
bool ProgramParameters::showFeatureRepresentation;
bool ProgramParameters::randomEmbeddings;
bool ProgramParameters::printEntropy;
int ProgramParameters::iterationSize;
int ProgramParameters::nbTrain;
std::string ProgramParameters::sequenceDelimiterTape;
std::string ProgramParameters::sequenceDelimiter;
......@@ -104,7 +104,7 @@ void Trainer::train()
trainConfig.reset();
if(ProgramParameters::shuffleExamples)
trainConfig.shuffle("EOS", "1");
trainConfig.shuffle(ProgramParameters::sequenceDelimiterTape, ProgramParameters::sequenceDelimiter);
for (auto & it : trainCounter)
it.second.first = it.second.second = 0;
......
......@@ -64,6 +64,10 @@ po::options_description getOptionsDescription()
"Is the shell interactive ? Display advancement informations")
("randomEmbeddings", po::value<bool>()->default_value(false),
"When activated, the embeddings will be randomly initialized")
("sequenceDelimiterTape", po::value<std::string>()->default_value("EOS"),
"The name of the buffer's tape that contains the delimiter token for a sequence")
("sequenceDelimiter", po::value<std::string>()->default_value("1"),
"The value of the token that act as a delimiter for sequences")
("shuffle", po::value<bool>()->default_value(true),
"Shuffle examples after each iteration");
......@@ -242,6 +246,8 @@ int main(int argc, char * argv[])
ProgramParameters::interactive = vm["interactive"].as<bool>();
ProgramParameters::shuffleExamples = vm["shuffle"].as<bool>();
ProgramParameters::randomEmbeddings = vm["randomEmbeddings"].as<bool>();
ProgramParameters::sequenceDelimiterTape = vm["sequenceDelimiterTape"].as<std::string>();
ProgramParameters::sequenceDelimiter = vm["sequenceDelimiter"].as<std::string>();
ProgramParameters::learningRate = vm["lr"].as<float>();
ProgramParameters::beta1 = vm["b1"].as<float>();
ProgramParameters::beta2 = vm["b2"].as<float>();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment