diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index aeee6b4fe6a64be6f63640a72ef21dd411bf194c..47ecca46eb8d5916ebc7d4494cad950bdb12e2a4 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -116,29 +116,66 @@ void Transition::apply(Config & config, float entropy) curValue = fmt::format("_"); } } - if (config.hasColIndex("ENTROPY")) + // Entropy of the action that attach a word to the tree + if (config.hasColIndex("ENTROPY_ATTACH")) + { + bool mean = false; + if (name.find("SHIFT") == std::string::npos and name.find("REDUCE") == std::string::npos) + { + if (name.find("LEFT") != std::string::npos) + { + auto action = Action::sumToHypothesis("ENTROPY_ATTACH", config.getStack(0), entropy, mean); + action.apply(config, action); + } + else + { + auto action = Action::sumToHypothesis("ENTROPY_ATTACH", config.getWordIndex(), entropy, mean); + action.apply(config, action); + } + } + } + if (config.hasColIndex("ENTROPY_ATTACH_MEAN")) { bool mean = true; if (name.find("SHIFT") == std::string::npos and name.find("REDUCE") == std::string::npos) { if (name.find("LEFT") != std::string::npos) { - auto action = Action::sumToHypothesis("ENTROPY", config.getStack(0), entropy, mean); + auto action = Action::sumToHypothesis("ENTROPY_ATTACH_MEAN", config.getStack(0), entropy, mean); action.apply(config, action); } else { - auto action = Action::sumToHypothesis("ENTROPY", config.getWordIndex(), entropy, mean); + auto action = Action::sumToHypothesis("ENTROPY_ATTACH_MEAN", config.getWordIndex(), entropy, mean); action.apply(config, action); } } } + // Entropy of every action is taken into account + if (config.hasColIndex("ENTROPY_ALL")) + { + bool mean = false; + auto action = Action::sumToHypothesis("ENTROPY_ALL", config.getWordIndex(), entropy, mean); + action.apply(config, action); + } + if (config.hasColIndex("ENTROPY_ALL_MEAN")) + { + bool mean = true; + auto action = Action::sumToHypothesis("ENTROPY_ALL_MEAN", config.getWordIndex(), entropy, mean); + action.apply(config, action); + } if (config.hasColIndex("SURPRISAL")) { float surprisal = -log(config.getChosenActionScore()); auto action = Action::sumToHypothesis("SURPRISAL", config.getWordIndex(), surprisal, false); action.apply(config, action); } + if (config.hasColIndex("SURPRISAL_MEAN")) + { + float surprisal = -log(config.getChosenActionScore()); + auto action = Action::sumToHypothesis("SURPRISAL_MEAN", config.getWordIndex(), surprisal, true); + action.apply(config, action); + } apply(config); }