Skip to content
Snippets Groups Projects
Commit 8a5f6277 authored by Franck Dary's avatar Franck Dary
Browse files

Surprisal for each word can now be computed by adding SURPRISAL to mcd

parent dbc8b501
Branches
No related tags found
No related merge requests found
...@@ -43,7 +43,7 @@ class Action ...@@ -43,7 +43,7 @@ class Action
static Action moveCharacterIndex(int movement); static Action moveCharacterIndex(int movement);
static Action addHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & hypothesis); static Action addHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & hypothesis);
static Action addToHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & addition); static Action addToHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & addition);
static Action sumToHypothesis(const std::string & colName, std::size_t lineIndex, float addition); static Action sumToHypothesis(const std::string & colName, std::size_t lineIndex, float addition, bool mean);
static Action addHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis); static Action addHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis);
static Action addHypothesisRelativeRelaxed(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis); static Action addHypothesisRelativeRelaxed(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis);
static Action addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition); static Action addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition);
......
...@@ -184,9 +184,9 @@ Action Action::addToHypothesis(const std::string & colName, std::size_t lineInde ...@@ -184,9 +184,9 @@ Action Action::addToHypothesis(const std::string & colName, std::size_t lineInde
return {Type::Write, apply, undo, appliable}; return {Type::Write, apply, undo, appliable};
} }
Action Action::sumToHypothesis(const std::string & colName, std::size_t lineIndex, float addition) Action Action::sumToHypothesis(const std::string & colName, std::size_t lineIndex, float addition, bool mean)
{ {
auto apply = [colName, lineIndex, addition](Config & config, Action &) auto apply = [colName, lineIndex, addition, mean](Config & config, Action &)
{ {
std::string totalStr = config.getLastNotEmptyHypConst(colName, lineIndex).get(); std::string totalStr = config.getLastNotEmptyHypConst(colName, lineIndex).get();
...@@ -210,7 +210,6 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde ...@@ -210,7 +210,6 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde
index = byStates.size()-1; index = byStates.size()-1;
} }
// Knuth’s algorithm for online mean
auto splited = util::split(util::split(byStates[index], '=')[1], '|'); auto splited = util::split(util::split(byStates[index], '=')[1], '|');
float curVal = 0.0; float curVal = 0.0;
int curNb = 0; int curNb = 0;
...@@ -221,15 +220,24 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde ...@@ -221,15 +220,24 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde
} }
curNb += 1; curNb += 1;
if (mean)
{
// Knuth’s algorithm for online mean
float delta = addition - curVal; float delta = addition - curVal;
curVal += delta / curNb; curVal += delta / curNb;
}
else
{
curVal += addition;
}
byStates[index] = fmt::format("{}={}|{}", config.getState(), curVal, curNb); byStates[index] = fmt::format("{}={}|{}", config.getState(), curVal, curNb);
config.getLastNotEmptyHyp(colName, lineIndex) = util::join(",", byStates); config.getLastNotEmptyHyp(colName, lineIndex) = util::join(",", byStates);
}; };
auto undo = [colName, lineIndex, addition](Config & config, Action &) auto undo = [colName, lineIndex, addition, mean](Config & config, Action &)
{ {
std::string totalStr = config.getLastNotEmptyHypConst(colName, lineIndex).get(); std::string totalStr = config.getLastNotEmptyHypConst(colName, lineIndex).get();
...@@ -253,10 +261,10 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde ...@@ -253,10 +261,10 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde
index = byStates.size()-1; index = byStates.size()-1;
} }
// Knuth’s algorithm for online mean
auto splited = util::split(util::split(byStates[index], '=')[1], '|'); auto splited = util::split(util::split(byStates[index], '=')[1], '|');
float curVal = 0.0; float curVal = 0.0;
int curNb = 0; int curNb = 0;
if (splited.size() == 2) if (splited.size() == 2)
{ {
curVal = std::stof(splited[0]); curVal = std::stof(splited[0]);
...@@ -264,7 +272,12 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde ...@@ -264,7 +272,12 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde
} }
curNb -= 1; curNb -= 1;
// Knuth’s algorithm for online mean
if (mean)
curVal = (curNb*curVal - addition) / (curNb - 1); curVal = (curNb*curVal - addition) / (curNb - 1);
else
curVal -= addition;
byStates[index] = fmt::format("{}={}|{}", config.getState(), curVal, curNb); byStates[index] = fmt::format("{}={}|{}", config.getState(), curVal, curNb);
......
...@@ -101,20 +101,27 @@ void Transition::apply(Config & config, float entropy) ...@@ -101,20 +101,27 @@ void Transition::apply(Config & config, float entropy)
{ {
if (config.hasColIndex("ENTROPY")) if (config.hasColIndex("ENTROPY"))
{ {
bool mean = true;
if (name.find("SHIFT") == std::string::npos and name.find("REDUCE") == std::string::npos) if (name.find("SHIFT") == std::string::npos and name.find("REDUCE") == std::string::npos)
{ {
if (name.find("LEFT") != std::string::npos) if (name.find("LEFT") != std::string::npos)
{ {
auto action = Action::sumToHypothesis("ENTROPY", config.getStack(0), entropy); auto action = Action::sumToHypothesis("ENTROPY", config.getStack(0), entropy, mean);
action.apply(config, action); action.apply(config, action);
} }
else else
{ {
auto action = Action::sumToHypothesis("ENTROPY", config.getWordIndex(), entropy); auto action = Action::sumToHypothesis("ENTROPY", config.getWordIndex(), entropy, mean);
action.apply(config, action); action.apply(config, action);
} }
} }
} }
if (config.hasColIndex("SURPRISAL"))
{
float surprisal = -log(config.getChosenActionScore());
auto action = Action::sumToHypothesis("SURPRISAL", config.getWordIndex(), surprisal, false);
action.apply(config, action);
}
apply(config); apply(config);
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment