Skip to content
Snippets Groups Projects
Commit f2ddeb2b authored by Franck Dary's avatar Franck Dary
Browse files

Added the entropy feature

parent 371b3e02
No related branches found
No related tags found
No related merge requests found
......@@ -68,11 +68,14 @@ void Decoder::decode()
config.moveHead(transition->headMvt);
float entropy = Classifier::computeEntropy(weightedActions);
config.addToEntropyHistory(entropy);
if (ProgramParameters::printEntropy)
{
nbActionsInSequence++;
entropyAccumulator += Classifier::computeEntropy(weightedActions);
entropyAccumulator += entropy;
if (config.head >= 1 && config.getTape(ProgramParameters::sequenceDelimiterTape)[config.head-1] != ProgramParameters::sequenceDelimiter)
justFlipped = false;
......
......@@ -88,11 +88,14 @@ std::map<std::string, std::pair<float, std::pair<float, float> > > Trainer::getS
tm.takeTransition(transition);
devConfig->moveHead(transition->headMvt);
float entropy = Classifier::computeEntropy(weightedActions);
devConfig->addToEntropyHistory(entropy);
if (ProgramParameters::printEntropy)
{
nbActionsInSequence++;
entropyAccumulator += Classifier::computeEntropy(weightedActions);
entropyAccumulator += entropy;
if (devConfig->head >= 1 && devConfig->getTape(ProgramParameters::sequenceDelimiterTape)[devConfig->head-1] != ProgramParameters::sequenceDelimiter)
justFlipped = false;
......@@ -267,6 +270,9 @@ void Trainer::train()
TransitionMachine::Transition * transition = tm.getTransition(actionName);
tm.takeTransition(transition);
trainConfig.moveHead(transition->headMvt);
float entropy = Classifier::computeEntropy(weightedActions);
trainConfig.addToEntropyHistory(entropy);
}
nbSteps++;
......
......@@ -122,6 +122,10 @@ class Config
///
/// @return The history of Action of the current state in the TransitionMachine.
std::vector<std::string> & getCurrentStateHistory();
/// @brief Get the history of entropies of the current state in the TransitionMachine.
///
/// @return The history of entropies of the current state in the TransitionMachine.
std::vector<float> & getCurrentStateEntropyHistory();
/// @brief Shuffle the segments of the Config.
///
/// For instance if you call shuffle("EOS", "1");\n
......
......@@ -71,6 +71,14 @@ class FeatureBank
///
/// @return The prefix of the name of the requested Action.
static FeatureModel::FeatureValue actionHistory(Config & config, int index, const std::string & featName);
/// @brief Get a previous entropy in the history of the current state.
///
/// @param config The Config to work with.
/// @param index The relative index of the entropy (e.g. -1 for the entropy of the Action choice that happened before the last Action performed).
/// @param featName The name of this feature.
///
/// @return The discretized value of the desired entropy.
static FeatureModel::FeatureValue entropyHistory(Config & config, int index, const std::string & featName);
/// @brief Get the content of a cell of a tape.
///
/// @param config The Config to work with.
......
......@@ -249,6 +249,11 @@ std::vector<std::string> & Config::getCurrentStateHistory()
return actionHistory[getCurrentStateName()];
}
std::vector<float> & Config::getCurrentStateEntropyHistory()
{
return entropyHistory[getCurrentStateName()];
}
void Config::shuffle(const std::string & delimiterTape, const std::string & delimiter)
{
auto & tape = getTape(delimiterTape);
......
......@@ -139,6 +139,9 @@ std::function<FeatureModel::FeatureValue(Config &)> FeatureBank::str2func(const
if(infos == "tc")
return [index, s](Config & c)
{return actionHistory(c, index, s);};
else if (infos == "entropy")
return [index, s](Config & c)
{return entropyHistory(c, index, s);};
else
{
fprintf(stderr, "ERROR (%s) : unknown feature \'%s\' Aborting.\n", ERRINFO, s.c_str());
......@@ -166,6 +169,20 @@ FeatureModel::FeatureValue FeatureBank::actionHistory(Config & config, int index
return {dict, &featName, dict->getStr(history[history.size()-1-index]), policy};
}
FeatureModel::FeatureValue FeatureBank::entropyHistory(Config & config, int index, const std::string & featName)
{
Dict * dict = Dict::getDict("entropy");
auto policy = dictPolicy2FeaturePolicy(dict->policy);
auto & history = config.getCurrentStateEntropyHistory();
if(index < 0 || index >= (int)history.size())
return {dict, &featName, &Dict::nullValueStr, policy};
std::string value = std::to_string((int)history[history.size()-1-index]);
return {dict, &featName, dict->getStr(value), policy};
}
FeatureModel::FeatureValue FeatureBank::ldep(Config & config, int index, const std::string & object, const std::string & tapeName, const std::string & featName)
{
auto & tape = config.getTape(tapeName);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment