diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp index 1ad7ab700c5c837ff83edac8043f9b804fecd37f..e376f500e244991f10cee0cb166c40a2c15d00fc 100644 --- a/reading_machine/include/Action.hpp +++ b/reading_machine/include/Action.hpp @@ -43,7 +43,7 @@ class Action static Action moveCharacterIndex(int movement); static Action addHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & hypothesis); static Action addToHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & addition); - static Action sumToHypothesis(const std::string & colName, std::size_t lineIndex, float addition); + static Action sumToHypothesis(const std::string & colName, std::size_t lineIndex, float addition, bool mean); static Action addHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis); static Action addHypothesisRelativeRelaxed(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis); static Action addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition); diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index d99e4f4352bdda2bd7eecb45cb1493fb46149b35..5d4dedb43d3995e5e143d5d75f01835835fc5045 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -184,9 +184,9 @@ Action Action::addToHypothesis(const std::string & colName, std::size_t lineInde return {Type::Write, apply, undo, appliable}; } -Action Action::sumToHypothesis(const std::string & colName, std::size_t lineIndex, float addition) +Action Action::sumToHypothesis(const std::string & colName, std::size_t lineIndex, float addition, bool mean) { - auto apply = [colName, lineIndex, addition](Config & config, Action &) + auto apply = [colName, lineIndex, addition, mean](Config & config, Action &) { std::string totalStr = config.getLastNotEmptyHypConst(colName, lineIndex).get(); @@ -210,7 +210,6 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde index = byStates.size()-1; } - // Knuth’s algorithm for online mean auto splited = util::split(util::split(byStates[index], '=')[1], '|'); float curVal = 0.0; int curNb = 0; @@ -221,15 +220,24 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde } curNb += 1; - float delta = addition - curVal; - curVal += delta / curNb; + + if (mean) + { + // Knuth’s algorithm for online mean + float delta = addition - curVal; + curVal += delta / curNb; + } + else + { + curVal += addition; + } byStates[index] = fmt::format("{}={}|{}", config.getState(), curVal, curNb); config.getLastNotEmptyHyp(colName, lineIndex) = util::join(",", byStates); }; - auto undo = [colName, lineIndex, addition](Config & config, Action &) + auto undo = [colName, lineIndex, addition, mean](Config & config, Action &) { std::string totalStr = config.getLastNotEmptyHypConst(colName, lineIndex).get(); @@ -252,11 +260,11 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde byStates.emplace_back(fmt::format("{}={}|{}", config.getState(), 0.0, 0)); index = byStates.size()-1; } - - // Knuth’s algorithm for online mean + auto splited = util::split(util::split(byStates[index], '=')[1], '|'); float curVal = 0.0; int curNb = 0; + if (splited.size() == 2) { curVal = std::stof(splited[0]); @@ -264,7 +272,12 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde } curNb -= 1; - curVal = (curNb*curVal - addition) / (curNb - 1); + + // Knuth’s algorithm for online mean + if (mean) + curVal = (curNb*curVal - addition) / (curNb - 1); + else + curVal -= addition; byStates[index] = fmt::format("{}={}|{}", config.getState(), curVal, curNb); diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index dd83119650e8a53f5198c00e854b18cb6505734d..68d9dbbca5243ec03e03107a80d194fa1eb7d16b 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -101,20 +101,27 @@ void Transition::apply(Config & config, float entropy) { if (config.hasColIndex("ENTROPY")) { + bool mean = true; if (name.find("SHIFT") == std::string::npos and name.find("REDUCE") == std::string::npos) { if (name.find("LEFT") != std::string::npos) { - auto action = Action::sumToHypothesis("ENTROPY", config.getStack(0), entropy); + auto action = Action::sumToHypothesis("ENTROPY", config.getStack(0), entropy, mean); action.apply(config, action); } else { - auto action = Action::sumToHypothesis("ENTROPY", config.getWordIndex(), entropy); + auto action = Action::sumToHypothesis("ENTROPY", config.getWordIndex(), entropy, mean); action.apply(config, action); } } } + if (config.hasColIndex("SURPRISAL")) + { + float surprisal = -log(config.getChosenActionScore()); + auto action = Action::sumToHypothesis("SURPRISAL", config.getWordIndex(), surprisal, false); + action.apply(config, action); + } apply(config); }