diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index 4e070db55e93466b7d8a9282c6c594bbf9ec50bd..324acb3e76a2af1291885680247eecf212c6043f 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -55,7 +55,10 @@ Action Action::setMultiwordIds(int multiwordSize) { addHypothesisRelative(Config::idColName, Object::Buffer, 0, fmt::format("{}-{}", config.getCurrentWordId()+1, config.getCurrentWordId()+multiwordSize)).apply(config, a); for (int i = 0; i < multiwordSize; i++) + { addHypothesisRelative(Config::idColName, Object::Buffer, i+1, fmt::format("{}", config.getCurrentWordId()+1+i)).apply(config, a); + addHypothesisRelative(Config::isMultiColName, Object::Buffer, i+1, Config::EOSSymbol1).apply(config, a); + } }; auto undo = [multiwordSize](Config & config, Action &) @@ -255,14 +258,22 @@ Action Action::addHypothesisRelative(const std::string & colName, Object object, Action Action::pushWordIndexOnStack() { - auto apply = [](Config & config, Action &) + auto apply = [](Config & config, Action & a) { - config.addToStack(config.getWordIndex()); + if (config.isTokenPredicted(config.getWordIndex())) + { + a.data.emplace_back(); + config.addToStack(config.getWordIndex()); + } }; - auto undo = [](Config & config, Action &) + auto undo = [](Config & config, Action & a) { - config.popStack(); + if (!a.data.empty()) + { + config.popStack(); + a.data.pop_back(); + } }; auto appliable = [](const Config & config, const Action &) @@ -320,7 +331,12 @@ Action Action::endWord() auto appliable = [](const Config & config, const Action &) { - return !util::isEmpty(config.getAsFeature("FORM", config.getWordIndex())); + if (util::isEmpty(config.getAsFeature("FORM", config.getWordIndex()))) + return false; + if (!util::isEmpty(config.getAsFeature(Config::idColName, config.getWordIndex())) and config.getAsFeature(Config::isMultiColName, config.getWordIndex()) != Config::EOSSymbol1) + return false; + + return true; }; return {Type::Write, apply, undo, appliable}; @@ -441,7 +457,7 @@ Action Action::setRoot() break; util::myThrow("The current sentence is too long to be completly held by the data strucure. Consider increasing SubConfig::SpanSize"); } - if (!config.isToken(i)) + if (!config.isTokenPredicted(i)) continue; if (config.getLastNotEmptyHypConst(Config::EOSColName, i) == Config::EOSSymbol1) @@ -464,7 +480,7 @@ Action Action::setRoot() break; util::myThrow("The current sentence is too long to be completly held by the data strucure. Consider increasing SubConfig::SpanSize"); } - if (!config.isToken(i)) + if (!config.isTokenPredicted(i)) continue; if (config.getLastNotEmptyHypConst(Config::EOSColName, i) == Config::EOSSymbol1) @@ -541,8 +557,8 @@ Action Action::updateIds() if (config.getLastNotEmptyHypConst(Config::EOSColName, i) == Config::EOSSymbol1) break; - if (config.isMultiword(i)) - config.getFirstEmpty(Config::idColName, i) = fmt::format("{}-{}", currentId, currentId+config.getMultiwordSize(i)); + if (config.isMultiwordPredicted(i)) + config.getFirstEmpty(Config::idColName, i) = fmt::format("{}-{}", currentId, currentId+config.getMultiwordSizePredicted(i)); else config.getFirstEmpty(Config::idColName, i) = fmt::format("{}", currentId++); } diff --git a/reading_machine/src/Strategy.cpp b/reading_machine/src/Strategy.cpp index a13ac5e284e002ab991429d81fb0da8354a91c1a..769420b74b2ec107de14aff819e92ce564900815 100644 --- a/reading_machine/src/Strategy.cpp +++ b/reading_machine/src/Strategy.cpp @@ -39,7 +39,6 @@ Strategy::Strategy(const std::vector<std::string_view> & lines) if (edges.empty()) util::myThrow("Strategy is empty"); - defaultCycle.pop_back(); std::reverse(defaultCycle.begin(), defaultCycle.end()); originalDefaultCycle = defaultCycle; } @@ -51,8 +50,15 @@ std::pair<std::string, int> Strategy::getMovement(const Config & c, const std::s if (c.stateIsDone()) isDone[c.getState()] = true; - while (defaultCycle.size() && isDone[defaultCycle.back()]) - defaultCycle.pop_back(); + for (unsigned int i = 0; i < defaultCycle.size(); i++) + { + if (isDone[defaultCycle[i]]) + { + while (defaultCycle.size() != i) + defaultCycle.pop_back(); + break; + } + } if (type == Type::Sequential) return getMovementSequential(c, transitionPrefix); @@ -113,6 +119,9 @@ std::pair<std::string, int> Strategy::getMovementIncremental(const Config & c, c if (target.empty()) util::myThrow(fmt::format("no suitable movement found for current state '{}' and transition '{}'", c.getState(), transition)); + if (c.hasStack(0) and c.getStack(0) == c.getWordIndex() and not c.canMoveWordIndex(movement)) + target = c.getState(); + if (!isDone[target]) return {target, c.canMoveWordIndex(movement) ? movement : 0}; diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index 3d27545c00bff1741fe4021a300da2623341bbf4..b0fe02ca3cbed7756a445c4a83108b5738c48935 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -330,7 +330,7 @@ void Transition::initRight(std::string label) auto otherGovIndex = config.getConst(Config::headColName, i, 0); - if (bufferGovIndex == std::to_string(i) || otherGovIndex == std::to_string(wordIndex)) + if (bufferGovIndex == std::to_string(i)) ++cost; if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1) @@ -366,16 +366,13 @@ void Transition::initReduce() cost = [](const Config & config) { - if (!config.has(0, config.getStack(0), 0)) - return 0; - if (!config.isToken(config.getStack(0))) return 0; int cost = 0; auto stackIndex = config.getStack(0); - auto stackGovIndex = config.getConst(Config::headColName, config.getStack(0), 0); + auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0); for (int i = config.getWordIndex(); config.has(0, i, 0); ++i) {