diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 70eba0ef0219ad1a1ae3e9d4c820600b7fcc3f5a..5394280fb1bc9b05145df876a6118d26d2f21043 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -38,9 +38,10 @@ std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float baseConfig = beam[0].config; - if (machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0") and baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1) + if (baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1) { - machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0")->apply(baseConfig); + auto eosTransition = Transition("EOS b.0"); + eosTransition.apply(baseConfig); if (debug) { fmt::print(stderr, "Forcing EOS transition\n"); diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index 2fc63ce295dae575b4cb1d0d9e881f8f1639da4e..d978112494d37d1224826901858e3102cc4edb5d 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -631,8 +631,12 @@ Action Action::setRoot(int bufferIndex) { int lineIndex = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex); int rootIndex = -1; + int searchStartIndex = lineIndex; + if (searchStartIndex > 0 and config.getAsFeature(Config::idColName, lineIndex) != "1") + searchStartIndex--; + int firstSentIndex = lineIndex; - for (int i = lineIndex; true; --i) + for (int i = searchStartIndex; true; --i) { if (!config.has(0, i, 0)) { @@ -646,6 +650,14 @@ Action Action::setRoot(int bufferIndex) if (config.getAsFeature(Config::EOSColName, i) == Config::EOSSymbol1) break; + firstSentIndex = i; + } + + for (int i = lineIndex; i >= firstSentIndex; --i) + { + if (!config.isTokenPredicted(i)) + continue; + if (std::string(config.getAsFeature(Config::headColName, i)).empty()) { rootIndex = i; @@ -653,20 +665,11 @@ Action Action::setRoot(int bufferIndex) } } - for (int i = lineIndex; true; --i) + for (int i = lineIndex; i >= firstSentIndex; --i) { - if (!config.has(0, i, 0)) - { - if (i < 0) - break; - util::myThrow("The current sentence is too long to be completly held by the data strucure. Consider increasing SubConfig::SpanSize"); - } if (!config.isTokenPredicted(i)) continue; - if (config.getAsFeature(Config::EOSColName, i) == Config::EOSSymbol1) - break; - if (std::string(config.getAsFeature(Config::headColName, i)).empty()) { if (i == rootIndex)