From 8a5f62773edaf01e9091a2ab0ce4acbc7be04f98 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 17 Nov 2020 18:35:28 +0100 Subject: [PATCH] Surprisal for each word can now be computed by adding SURPRISAL to mcd --- reading_machine/include/Action.hpp | 2 +- reading_machine/src/Action.cpp | 31 +++++++++++++++++++++--------- reading_machine/src/Transition.cpp | 11 +++++++++-- 3 files changed, 32 insertions(+), 12 deletions(-) diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp index 1ad7ab7..e376f50 100644 --- a/reading_machine/include/Action.hpp +++ b/reading_machine/include/Action.hpp @@ -43,7 +43,7 @@ class Action static Action moveCharacterIndex(int movement); static Action addHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & hypothesis); static Action addToHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & addition); - static Action sumToHypothesis(const std::string & colName, std::size_t lineIndex, float addition); + static Action sumToHypothesis(const std::string & colName, std::size_t lineIndex, float addition, bool mean); static Action addHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis); static Action addHypothesisRelativeRelaxed(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis); static Action addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition); diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index d99e4f4..5d4dedb 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -184,9 +184,9 @@ Action Action::addToHypothesis(const std::string & colName, std::size_t lineInde return {Type::Write, apply, undo, appliable}; } -Action Action::sumToHypothesis(const std::string & colName, std::size_t lineIndex, float addition) +Action Action::sumToHypothesis(const std::string & colName, std::size_t lineIndex, float addition, bool mean) { - auto apply = [colName, lineIndex, addition](Config & config, Action &) + auto apply = [colName, lineIndex, addition, mean](Config & config, Action &) { std::string totalStr = config.getLastNotEmptyHypConst(colName, lineIndex).get(); @@ -210,7 +210,6 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde index = byStates.size()-1; } - // Knuth’s algorithm for online mean auto splited = util::split(util::split(byStates[index], '=')[1], '|'); float curVal = 0.0; int curNb = 0; @@ -221,15 +220,24 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde } curNb += 1; - float delta = addition - curVal; - curVal += delta / curNb; + + if (mean) + { + // Knuth’s algorithm for online mean + float delta = addition - curVal; + curVal += delta / curNb; + } + else + { + curVal += addition; + } byStates[index] = fmt::format("{}={}|{}", config.getState(), curVal, curNb); config.getLastNotEmptyHyp(colName, lineIndex) = util::join(",", byStates); }; - auto undo = [colName, lineIndex, addition](Config & config, Action &) + auto undo = [colName, lineIndex, addition, mean](Config & config, Action &) { std::string totalStr = config.getLastNotEmptyHypConst(colName, lineIndex).get(); @@ -252,11 +260,11 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde byStates.emplace_back(fmt::format("{}={}|{}", config.getState(), 0.0, 0)); index = byStates.size()-1; } - - // Knuth’s algorithm for online mean + auto splited = util::split(util::split(byStates[index], '=')[1], '|'); float curVal = 0.0; int curNb = 0; + if (splited.size() == 2) { curVal = std::stof(splited[0]); @@ -264,7 +272,12 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde } curNb -= 1; - curVal = (curNb*curVal - addition) / (curNb - 1); + + // Knuth’s algorithm for online mean + if (mean) + curVal = (curNb*curVal - addition) / (curNb - 1); + else + curVal -= addition; byStates[index] = fmt::format("{}={}|{}", config.getState(), curVal, curNb); diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index dd83119..68d9dbb 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -101,20 +101,27 @@ void Transition::apply(Config & config, float entropy) { if (config.hasColIndex("ENTROPY")) { + bool mean = true; if (name.find("SHIFT") == std::string::npos and name.find("REDUCE") == std::string::npos) { if (name.find("LEFT") != std::string::npos) { - auto action = Action::sumToHypothesis("ENTROPY", config.getStack(0), entropy); + auto action = Action::sumToHypothesis("ENTROPY", config.getStack(0), entropy, mean); action.apply(config, action); } else { - auto action = Action::sumToHypothesis("ENTROPY", config.getWordIndex(), entropy); + auto action = Action::sumToHypothesis("ENTROPY", config.getWordIndex(), entropy, mean); action.apply(config, action); } } } + if (config.hasColIndex("SURPRISAL")) + { + float surprisal = -log(config.getChosenActionScore()); + auto action = Action::sumToHypothesis("SURPRISAL", config.getWordIndex(), surprisal, false); + action.apply(config, action); + } apply(config); } -- GitLab