diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp
index a9b1517ab557f68317acd234407bf1136ad7f22d..49d3bd794a57028289f31e26389a866f5934c053 100644
--- a/reading_machine/include/Action.hpp
+++ b/reading_machine/include/Action.hpp
@@ -44,6 +44,7 @@ class Action
   static Action addHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & hypothesis);
   static Action addToHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & addition);
   static Action sumToHypothesis(const std::string & colName, std::size_t lineIndex, float addition, bool mean);
+  static Action maxWithHypothesis(const std::string & colName, std::size_t lineIndex, float addition);
   static Action addHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis);
   static Action addHypothesisRelativeRelaxed(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis);
   static Action addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition);
diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp
index d978112494d37d1224826901858e3102cc4edb5d..085a05acd6007f64ffbd117187a40af74eca1ee8 100644
--- a/reading_machine/src/Action.cpp
+++ b/reading_machine/src/Action.cpp
@@ -295,6 +295,71 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde
   return {Type::Write, apply, undo, appliable};
 }
 
+// Builds a Write action that maintains, per machine state, a value and an update
+// count in cell (colName, lineIndex). The cell is a comma-separated list of
+// "state=value|count" entries; applying the action updates the entry of the
+// current state so that its value is the largest `addition` applied so far.
+// NOTE: undo is not implemented yet, so applying this action cannot be reverted.
+Action Action::maxWithHypothesis(const std::string & colName, std::size_t lineIndex, float addition)
+{
+  auto apply = [colName, lineIndex, addition](Config & config, Action &)
+  {
+    std::string totalStr = std::string(config.getLastNotEmptyHypConst(colName, lineIndex));
+
+    // A cell that is empty or "_" starts from a zeroed entry for the current state.
+    if (totalStr.empty() || totalStr == "_")
+      totalStr = fmt::format("{}={}|{}", std::string(config.getState()), 0.0, 0);
+
+    // Find the entry of the current state, appending a zeroed one if it is missing.
+    auto byStates = util::split(totalStr, ',');
+    int index = -1;
+    for (unsigned int i = 0; i < byStates.size(); i++)
+    {
+      auto state = util::split(byStates[i], '=')[0];
+      if (state == config.getState())
+      {
+        index = i;
+        break;
+      }
+    }
+    if (index == -1)
+    {
+      byStates.emplace_back(fmt::format("{}={}|{}", std::string(config.getState()), 0.0, 0));
+      index = byStates.size()-1;
+    }
+
+    // Parse "value|count"; an entry without exactly two fields counts as never updated.
+    auto splited = util::split(util::split(byStates[index], '=')[1], '|');
+    float curVal = 0.0;
+    int curNb = 0;
+    if (splited.size() == 2)
+    {
+      curVal = std::stof(splited[0]);
+      curNb = std::stoi(splited[1]);
+    }
+
+    // Keep the maximum: the first observation always wins, later ones only if larger.
+    curVal = (curNb == 0 || addition > curVal) ? addition : curVal;
+    curNb += 1;
+
+    byStates[index] = fmt::format("{}={}|{}", std::string(config.getState()), curVal, curNb);
+
+    config.getLastNotEmptyHyp(colName, lineIndex) = util::join(",", byStates);
+  };
+
+  auto undo = [](Config &, Action &)
+  {
+    //TODO: not done
+  };
+
+  auto appliable = [colName, lineIndex](const Config & config, const Action &)
+  {
+    return config.has(colName, lineIndex, 0);
+  };
+
+  return {Type::Write, apply, undo, appliable};
+}
+
 Action Action::addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition)
 {
   auto apply = [colName, object, relativeIndex, addition](Config & config, Action & a)
diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp
index a50d38aa349c7d09656cdcf749ee856735acab9a..6a4f504286d56f56fe51a8773b6602cdb1eba371 100644
--- a/reading_machine/src/Transition.cpp
+++ b/reading_machine/src/Transition.cpp
@@ -116,67 +116,237 @@ void Transition::apply(Config & config, float entropy)
       curValue = fmt::format("_");
     }
   }
-  // Entropy of the action that attach a word to the tree
-  if (config.hasColIndex("ENTROPY_ATTACH"))
+
+  // Per-word metric columns are named <MEASURE>_<TARGET>_<AGG>:
+  //   MEASURE: ENT = the entropy argument, SUR = -log(score of the chosen action).
+  //   TARGET : CUR = current word index; ATT and TGT pick the stack top and/or the
+  //            current word depending on the transition name (see each block).
+  //   AGG    : ADD / MEAN use sumToHypothesis, MAX uses maxWithHypothesis.
+  float surprisal = -log(config.getChosenActionScore());
+
+  if (config.hasColIndex("ENT_CUR_ADD"))
+  {
+    auto action = Action::sumToHypothesis("ENT_CUR_ADD", config.getWordIndex(), entropy, false);
+    action.apply(config, action);
+  }
+  if (config.hasColIndex("ENT_CUR_MEAN"))
+  {
+    auto action = Action::sumToHypothesis("ENT_CUR_MEAN", config.getWordIndex(), entropy, true);
+    action.apply(config, action);
+  }
+  if (config.hasColIndex("ENT_CUR_MAX"))
+  {
+    auto action = Action::maxWithHypothesis("ENT_CUR_MAX", config.getWordIndex(), entropy);
+    action.apply(config, action);
+  }
+
+  if (config.hasColIndex("ENT_ATT_ADD"))
   {
-    bool mean = false;
-    if (name.find("SHIFT") == std::string::npos and name.find("REDUCE") == std::string::npos)
+    if (name.find("REDUCE") != std::string::npos)
     {
-      if (name.find("LEFT") != std::string::npos)
-      {
-        auto action = Action::sumToHypothesis("ENTROPY_ATTACH", config.getStack(0), entropy, mean);
-        action.apply(config, action);
-      }
-      else
+      auto action = Action::sumToHypothesis("ENT_ATT_ADD", config.getStack(0), entropy, false);
+      action.apply(config, action);
+    }
+    else
+    {
+      auto action = Action::sumToHypothesis("ENT_ATT_ADD", config.getWordIndex(), entropy, false);
+      action.apply(config, action);
+      if (name.find("LEFT") != std::string::npos || name.find("RIGHT") != std::string::npos)
       {
-        auto action = Action::sumToHypothesis("ENTROPY_ATTACH", config.getWordIndex(), entropy, mean);
+        auto action = Action::sumToHypothesis("ENT_ATT_ADD", config.getStack(0), entropy, false);
         action.apply(config, action);
       }
     }
   }
-  if (config.hasColIndex("ENTROPY_ATTACH_MEAN"))
+  if (config.hasColIndex("ENT_ATT_MEAN"))
   {
-    bool mean = true;
-    if (name.find("SHIFT") == std::string::npos and name.find("REDUCE") == std::string::npos)
+    if (name.find("REDUCE") != std::string::npos)
     {
-      if (name.find("LEFT") != std::string::npos)
+      auto action = Action::sumToHypothesis("ENT_ATT_MEAN", config.getStack(0), entropy, true);
+      action.apply(config, action);
+    }
+    else
+    {
+      auto action = Action::sumToHypothesis("ENT_ATT_MEAN", config.getWordIndex(), entropy, true);
+      action.apply(config, action);
+      if (name.find("LEFT") != std::string::npos || name.find("RIGHT") != std::string::npos)
       {
-        auto action = Action::sumToHypothesis("ENTROPY_ATTACH_MEAN", config.getStack(0), entropy, mean);
+        auto action = Action::sumToHypothesis("ENT_ATT_MEAN", config.getStack(0), entropy, true);
         action.apply(config, action);
       }
-      else
+    }
+  }
+  if (config.hasColIndex("ENT_ATT_MAX"))
+  {
+    if (name.find("REDUCE") != std::string::npos)
+    {
+      auto action = Action::maxWithHypothesis("ENT_ATT_MAX", config.getStack(0), entropy);
+      action.apply(config, action);
+    }
+    else
+    {
+      auto action = Action::maxWithHypothesis("ENT_ATT_MAX", config.getWordIndex(), entropy);
+      action.apply(config, action);
+      if (name.find("LEFT") != std::string::npos || name.find("RIGHT") != std::string::npos)
       {
-        auto action = Action::sumToHypothesis("ENTROPY_ATTACH_MEAN", config.getWordIndex(), entropy, mean);
+        auto action = Action::maxWithHypothesis("ENT_ATT_MAX", config.getStack(0), entropy);
         action.apply(config, action);
       }
     }
   }
-  // Entropy of every action is taken into account
-  if (config.hasColIndex("ENTROPY_ALL"))
+
+  if (config.hasColIndex("ENT_TGT_ADD"))
   {
-    bool mean = false;
-    auto action = Action::sumToHypothesis("ENTROPY_ALL", config.getWordIndex(), entropy, mean);
-    action.apply(config, action);
+    if (name.find("REDUCE") != std::string::npos || name.find("LEFT") != std::string::npos)
+    {
+      auto action = Action::sumToHypothesis("ENT_TGT_ADD", config.getStack(0), entropy, false);
+      action.apply(config, action);
+    }
+    else
+    {
+      auto action = Action::sumToHypothesis("ENT_TGT_ADD", config.getWordIndex(), entropy, false);
+      action.apply(config, action);
+    }
   }
-  if (config.hasColIndex("ENTROPY_ALL_MEAN"))
+  if (config.hasColIndex("ENT_TGT_MEAN"))
   {
-    bool mean = true;
-    auto action = Action::sumToHypothesis("ENTROPY_ALL_MEAN", config.getWordIndex(), entropy, mean);
+    if (name.find("REDUCE") != std::string::npos || name.find("LEFT") != std::string::npos)
+    {
+      auto action = Action::sumToHypothesis("ENT_TGT_MEAN", config.getStack(0), entropy, true);
+      action.apply(config, action);
+    }
+    else
+    {
+      auto action = Action::sumToHypothesis("ENT_TGT_MEAN", config.getWordIndex(), entropy, true);
+      action.apply(config, action);
+    }
+  }
+  if (config.hasColIndex("ENT_TGT_MAX"))
+  {
+    if (name.find("REDUCE") != std::string::npos || name.find("LEFT") != std::string::npos)
+    {
+      auto action = Action::maxWithHypothesis("ENT_TGT_MAX", config.getStack(0), entropy);
+      action.apply(config, action);
+    }
+    else
+    {
+      auto action = Action::maxWithHypothesis("ENT_TGT_MAX", config.getWordIndex(), entropy);
+      action.apply(config, action);
+    }
+  }
+
+  if (config.hasColIndex("SUR_CUR_ADD"))
+  {
+    auto action = Action::sumToHypothesis("SUR_CUR_ADD", config.getWordIndex(), surprisal, false);
     action.apply(config, action);
   }
-  if (config.hasColIndex("SURPRISAL"))
+  if (config.hasColIndex("SUR_CUR_MEAN"))
   {
-    float surprisal = -log(config.getChosenActionScore());
-    auto action = Action::sumToHypothesis("SURPRISAL", config.getWordIndex(), surprisal, false);
+    auto action = Action::sumToHypothesis("SUR_CUR_MEAN", config.getWordIndex(), surprisal, true);
     action.apply(config, action);
   }
-  if (config.hasColIndex("SURPRISAL_MEAN"))
+  if (config.hasColIndex("SUR_CUR_MAX"))
   {
-    float surprisal = -log(config.getChosenActionScore());
-    auto action = Action::sumToHypothesis("SURPRISAL_MEAN", config.getWordIndex(), surprisal, true);
+    auto action = Action::maxWithHypothesis("SUR_CUR_MAX", config.getWordIndex(), surprisal);
     action.apply(config, action);
   }
 
+  if (config.hasColIndex("SUR_ATT_ADD"))
+  {
+    if (name.find("REDUCE") != std::string::npos)
+    {
+      auto action = Action::sumToHypothesis("SUR_ATT_ADD", config.getStack(0), surprisal, false);
+      action.apply(config, action);
+    }
+    else
+    {
+      auto action = Action::sumToHypothesis("SUR_ATT_ADD", config.getWordIndex(), surprisal, false);
+      action.apply(config, action);
+      if (name.find("LEFT") != std::string::npos || name.find("RIGHT") != std::string::npos)
+      {
+        auto action = Action::sumToHypothesis("SUR_ATT_ADD", config.getStack(0), surprisal, false);
+        action.apply(config, action);
+      }
+    }
+  }
+  if (config.hasColIndex("SUR_ATT_MEAN"))
+  {
+    if (name.find("REDUCE") != std::string::npos)
+    {
+      auto action = Action::sumToHypothesis("SUR_ATT_MEAN", config.getStack(0), surprisal, true);
+      action.apply(config, action);
+    }
+    else
+    {
+      auto action = Action::sumToHypothesis("SUR_ATT_MEAN", config.getWordIndex(), surprisal, true);
+      action.apply(config, action);
+      if (name.find("LEFT") != std::string::npos || name.find("RIGHT") != std::string::npos)
+      {
+        auto action = Action::sumToHypothesis("SUR_ATT_MEAN", config.getStack(0), surprisal, true);
+        action.apply(config, action);
+      }
+    }
+  }
+  if (config.hasColIndex("SUR_ATT_MAX"))
+  {
+    if (name.find("REDUCE") != std::string::npos)
+    {
+      auto action = Action::maxWithHypothesis("SUR_ATT_MAX", config.getStack(0), surprisal);
+      action.apply(config, action);
+    }
+    else
+    {
+      auto action = Action::maxWithHypothesis("SUR_ATT_MAX", config.getWordIndex(), surprisal);
+      action.apply(config, action);
+      if (name.find("LEFT") != std::string::npos || name.find("RIGHT") != std::string::npos)
+      {
+        auto action = Action::maxWithHypothesis("SUR_ATT_MAX", config.getStack(0), surprisal);
+        action.apply(config, action);
+      }
+    }
+  }
+
+  if (config.hasColIndex("SUR_TGT_ADD"))
+  {
+    if (name.find("REDUCE") != std::string::npos || name.find("LEFT") != std::string::npos)
+    {
+      auto action = Action::sumToHypothesis("SUR_TGT_ADD", config.getStack(0), surprisal, false);
+      action.apply(config, action);
+    }
+    else
+    {
+      auto action = Action::sumToHypothesis("SUR_TGT_ADD", config.getWordIndex(), surprisal, false);
+      action.apply(config, action);
+    }
+  }
+  if (config.hasColIndex("SUR_TGT_MEAN"))
+  {
+    if (name.find("REDUCE") != std::string::npos || name.find("LEFT") != std::string::npos)
+    {
+      auto action = Action::sumToHypothesis("SUR_TGT_MEAN", config.getStack(0), surprisal, true);
+      action.apply(config, action);
+    }
+    else
+    {
+      auto action = Action::sumToHypothesis("SUR_TGT_MEAN", config.getWordIndex(), surprisal, true);
+      action.apply(config, action);
+    }
+  }
+  if (config.hasColIndex("SUR_TGT_MAX"))
+  {
+    if (name.find("REDUCE") != std::string::npos || name.find("LEFT") != std::string::npos)
+    {
+      auto action = Action::maxWithHypothesis("SUR_TGT_MAX", config.getStack(0), surprisal);
+      action.apply(config, action);
+    }
+    else
+    {
+      auto action = Action::maxWithHypothesis("SUR_TGT_MAX", config.getWordIndex(), surprisal);
+      action.apply(config, action);
+    }
+  }
+
+
   apply(config);
 }
 