From 8dfbc6968e54e3d74f03a56bb984cbaac5b048c8 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 23 Mar 2021 14:04:50 +0100 Subject: [PATCH] Improved forcing EOS transition, useful for lineByLine mode --- decoder/src/Decoder.cpp | 5 +++-- reading_machine/src/Action.cpp | 25 ++++++++++++++----------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 70eba0e..5394280 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -38,9 +38,10 @@ std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float baseConfig = beam[0].config; - if (machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0") and baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1) + if (baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1) { - machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0")->apply(baseConfig); + auto eosTransition = Transition("EOS b.0"); + eosTransition.apply(baseConfig); if (debug) { fmt::print(stderr, "Forcing EOS transition\n"); diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index 2fc63ce..d978112 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -631,8 +631,12 @@ Action Action::setRoot(int bufferIndex) { int lineIndex = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex); int rootIndex = -1; + int searchStartIndex = lineIndex; + if (searchStartIndex > 0 and config.getAsFeature(Config::idColName, lineIndex) != "1") + searchStartIndex--; + int firstSentIndex = lineIndex; - for (int i = lineIndex; true; --i) + for (int i = searchStartIndex; true; --i) { if (!config.has(0, i, 0)) { @@ -646,6 +650,14 @@ Action Action::setRoot(int bufferIndex) if (config.getAsFeature(Config::EOSColName, i) == Config::EOSSymbol1) break; + firstSentIndex = i; 
+ } + + for (int i = lineIndex; i >= firstSentIndex; --i) + { + if (!config.isTokenPredicted(i)) + continue; + if (std::string(config.getAsFeature(Config::headColName, i)).empty()) { rootIndex = i; @@ -653,20 +665,11 @@ Action Action::setRoot(int bufferIndex) } } - for (int i = lineIndex; true; --i) + for (int i = lineIndex; i >= firstSentIndex; --i) { - if (!config.has(0, i, 0)) - { - if (i < 0) - break; - util::myThrow("The current sentence is too long to be completly held by the data strucure. Consider increasing SubConfig::SpanSize"); - } if (!config.isTokenPredicted(i)) continue; - if (config.getAsFeature(Config::EOSColName, i) == Config::EOSSymbol1) - break; - if (std::string(config.getAsFeature(Config::headColName, i)).empty()) { if (i == rootIndex) -- GitLab