diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp
index 2ac6dde2348ed8f471e1dc925251fa66d9c6446e..6b14f7606c5d1729a4eb0c3d319b0b78112eb080 100644
--- a/reading_machine/include/Action.hpp
+++ b/reading_machine/include/Action.hpp
@@ -46,7 +46,7 @@ class Action
   static Action addHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis);
   static Action addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition);
   static Action pushWordIndexOnStack();
-  static Action popStack();
+  static Action popStack(int relIndex);
   static Action emptyStack();
   static Action setRoot(int bufferIndex);
   static Action updateIds(int bufferIndex);
diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp
index c847fef06d1b0984a8f026ac7188da9abf388387..527fd23fefcb25387f9396e7d8be5e89731fef0e 100644
--- a/reading_machine/include/Config.hpp
+++ b/reading_machine/include/Config.hpp
@@ -89,6 +89,7 @@ class Config
   const String & getAsFeature(int colIndex, int lineIndex) const;
   ValueIterator getIterator(int colIndex, int lineIndex, int hypothesisIndex);
   ConstValueIterator getConstIterator(int colIndex, int lineIndex, int hypothesisIndex) const;
+  std::size_t & getStackRef(int relativeIndex);
 
   long getRelativeWordIndex(int relativeIndex) const;
 
@@ -112,6 +113,7 @@ class Config
   void addToHistory(const std::string & transition);
   void addToStack(std::size_t index);
   void popStack();
+  void swapStack(int relIndex1, int relIndex2);
   bool isComment(std::size_t lineIndex) const;
   bool isCommentPredicted(std::size_t lineIndex) const;
   bool isMultiword(std::size_t lineIndex) const;
@@ -134,6 +136,7 @@ class Config
   bool hasRelativeWordIndex(Object object, int relativeIndex) const;
   const String & getHistory(int relativeIndex) const;
   std::size_t getStack(int relativeIndex) const;
+  std::size_t getStackSize() const;
   bool hasHistory(int relativeIndex) const;
   bool hasStack(int relativeIndex) const;
   String getState() const;
diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp
index bedf5de4cbaa634c2a139c95e844cc57d6053c5c..2114b28eb3f5d9a8720c5e1a30b6628f2f713116 100644
--- a/reading_machine/include/Transition.hpp
+++ b/reading_machine/include/Transition.hpp
@@ -16,6 +16,11 @@ class Transition
 
   private :
 
+  static int getNbLinkedWith(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config);
+
+  static int getFirstIndexOfSentence(int baseIndex, const Config & config);
+  static int getLastIndexOfSentence(int baseIndex, const Config & config);
+
   void initWrite(std::string colName, std::string object, std::string index, std::string value);
   void initAdd(std::string colName, std::string object, std::string index, std::string value);
   void initShift();
diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp
index 44a9cdb212445aee1ab71ccaa2ce17a7be4a82b4..7bcca61716f38eb6f7ca61ec015ec849a1161500 100644
--- a/reading_machine/src/Action.cpp
+++ b/reading_machine/src/Action.cpp
@@ -268,23 +268,29 @@ Action Action::pushWordIndexOnStack()
   return {Type::Push, apply, undo, appliable};
 }
 
-Action Action::popStack()
+Action Action::popStack(int relIndex)
 {
-  auto apply = [](Config & config, Action & a)
+  auto apply = [relIndex](Config & config, Action & a)
   {
-    auto toSave = config.getStack(0);
+    auto toSave = config.getStack(relIndex);
     a.data.push_back(std::to_string(toSave));
+    // Bubble the element at relIndex up to relative index 0, so that popStack() removes it.
+    for (int i = 0; relIndex-1-i >= 0; i++)
+      config.swapStack(relIndex-i, relIndex-1-i);
     config.popStack();
   };
 
-  auto undo = [](Config & config, Action & a)
+  auto undo = [relIndex](Config & config, Action & a)
   {
     config.addToStack(std::stoi(a.data.back()));
+    // The saved element was just pushed back ; sink it down to the depth it was removed from.
+    for (int i = 0; i+1 <= relIndex; i++)
+      config.swapStack(i, i+1);
   };
 
-  auto appliable = [](const Config & config, const Action &)
+  auto appliable = [relIndex](const Config & config, const Action &)
   {
-    return config.hasStack(0) and config.getStack(0) != config.getWordIndex();
+    return config.hasStack(relIndex) and config.getStack(relIndex) != config.getWordIndex();
   };
 
   return {Type::Pop, apply, undo, appliable};
diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index 048b28dd8554cc800427fa33de9f054019456197..493e89e0c5e264cff7232028038a2e314ebe69f2 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -357,6 +357,14 @@ void Config::popStack()
   stack.pop_back();
 }
 
+// Exchange the stack elements stored at relative indexes relIndex1 and relIndex2 (the values, not the indexes).
+void Config::swapStack(int relIndex1, int relIndex2)
+{
+  std::size_t tmp = getStack(relIndex1);
+  getStackRef(relIndex1) = getStack(relIndex2);
+  getStackRef(relIndex2) = tmp;
+}
+
 bool Config::hasCharacter(int letterIndex) const
 {
   return letterIndex >= 0 and letterIndex < (int)util::getSize(rawInput);
@@ -529,6 +537,12 @@ std::size_t Config::getStack(int relativeIndex) const
   return stack[stack.size()-1-relativeIndex];
 }
 
+// Mutable access to the stack element at relativeIndex, counted from the back ; no bounds check is done here.
+std::size_t & Config::getStackRef(int relativeIndex)
+{
+  return stack[stack.size()-1-relativeIndex];
+}
+
 bool Config::hasHistory(int relativeIndex) const
 {
   return relativeIndex >= 0 && relativeIndex < (int)history.size();
@@ -710,3 +724,9 @@ void Config::setLastAttached(int lastAttached)
   this->lastAttached = lastAttached;
 }
 
+// Number of elements currently stored in the stack.
+std::size_t Config::getStackSize() const
+{
+  return stack.size();
+}
+
diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp
index 7bc41350cd0457f1c0caa8ad0ab0c671771b36e8..28aae30c29ac3ab7f19c034e3b29c95d3a8d7d17 100644
--- a/reading_machine/src/Transition.cpp
+++ b/reading_machine/src/Transition.cpp
@@ -246,22 +246,7 @@ void Transition::initShift()
     if (!config.isToken(config.getWordIndex()))
       return 0;
 
-    auto headGovIndex = config.getConst(Config::headColName, config.getWordIndex(), 0);
-
-    int cost = 0;
-    for (int i = 0; config.hasStack(i); ++i)
-    {
-      if (!config.has(0, config.getStack(i), 0))
-        continue;
-
-      auto stackIndex = config.getStack(i);
-      auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
-
-      if (stackGovIndex == std::to_string(config.getWordIndex()) || headGovIndex == std::to_string(stackIndex))
-        ++cost;
-    }
-
-    return cost;
+    return getNbLinkedWith(0, static_cast<int>(config.getStackSize())-1, Config::Object::Stack, config.getWordIndex(), config);
   };
 }
 
@@ -269,34 +254,18 @@ void Transition::initEagerLeft_rel(std::string label)
 {
   sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
   sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
-  sequence.emplace_back(Action::popStack());
+  sequence.emplace_back(Action::popStack(0));
 
   cost = [label](const Config & config)
   {
     auto stackIndex = config.getStack(0);
+    auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
     auto wordIndex = config.getWordIndex();
     if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
       return std::numeric_limits<int>::max();
 
-    int cost = 0;
-
-    auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
-
-    for (int i = wordIndex+1; config.has(0, i, 0); ++i)
-    {
-      if (!config.isToken(i))
-        continue;
+    int cost = getNbLinkedWith(wordIndex+1, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
 
-      if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
-        break;
-
-      auto otherGovIndex = config.getConst(Config::headColName, i, 0);
-
-      if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex))
-        ++cost;
-    }
-
-    //TODO : Check if this is necessary
     if (stackGovIndex != std::to_string(wordIndex))
       ++cost;
 
@@ -310,34 +279,18 @@ void Transition::initEagerLeft_rel(std::string label)
 void Transition::initEagerLeft()
 {
   sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
-  sequence.emplace_back(Action::popStack());
+  sequence.emplace_back(Action::popStack(0));
 
   cost = [](const Config & config)
   {
     auto stackIndex = config.getStack(0);
+    auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
     auto wordIndex = config.getWordIndex();
     if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
       return std::numeric_limits<int>::max();
 
-    int cost = 0;
-
-    auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
-
-    for (int i = wordIndex+1; config.has(0, i, 0); ++i)
-    {
-      if (!config.isToken(i))
-        continue;
-
-      if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
-        break;
-
-      auto otherGovIndex = config.getConst(Config::headColName, i, 0);
-
-      if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex))
-        ++cost;
-    }
+    int cost = getNbLinkedWith(wordIndex+1, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
 
-    //TODO : Check if this is necessary
     if (stackGovIndex != std::to_string(wordIndex))
       ++cost;
 
@@ -359,37 +312,15 @@ void Transition::initEagerRight_rel(std::string label)
     if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
       return std::numeric_limits<int>::max();
 
-    int cost = 0;
-
     auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0);
 
-    for (int i = wordIndex; config.has(0, i, 0); ++i)
-    {
-      if (!config.isToken(i))
-        continue;
-
-      auto otherGovIndex = config.getConst(Config::headColName, i, 0);
-
-      if (bufferGovIndex == std::to_string(i))
-        ++cost;
-
-      if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
-        break;
-    }
-
-    for (int i = 1; config.hasStack(i); ++i)
-    {
-      if (!config.has(0, config.getStack(i), 0))
-        continue;
-
-      auto otherStackIndex = config.getStack(i);
-      auto otherStackGov = config.getConst(Config::headColName, otherStackIndex, 0);
+    if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
+      return std::numeric_limits<int>::max();
 
-      if (otherStackGov == std::to_string(wordIndex) || bufferGovIndex == std::to_string(otherStackIndex))
-        ++cost;
-    }
+    // NOTE(review): the scan starts at stack index 0, so a link between the stack top and the buffer word is
+    // counted as a cost as well ; the loop this replaces started at index 1 -- confirm this is intended.
+    int cost = getNbLinkedWith(0, static_cast<int>(config.getStackSize())-1, Config::Object::Stack, config.getWordIndex(), config);
 
-    //TODO : Check if this is necessary
     if (bufferGovIndex != std::to_string(stackIndex))
       ++cost;
 
@@ -413,37 +344,15 @@ void Transition::initEagerRight()
     if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
       return std::numeric_limits<int>::max();
 
-    int cost = 0;
-
     auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0);
 
-    for (int i = wordIndex; config.has(0, i, 0); ++i)
-    {
-      if (!config.isToken(i))
-        continue;
-
-      auto otherGovIndex = config.getConst(Config::headColName, i, 0);
-
-      if (bufferGovIndex == std::to_string(i))
-        ++cost;
-
-      if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
-        break;
-    }
-
-    for (int i = 1; config.hasStack(i); ++i)
-    {
-      if (!config.has(0, config.getStack(i), 0))
-        continue;
-
-      auto otherStackIndex = config.getStack(i);
-      auto otherStackGov = config.getConst(Config::headColName, otherStackIndex, 0);
+    if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
+      return std::numeric_limits<int>::max();
 
-      if (otherStackGov == std::to_string(wordIndex) || bufferGovIndex == std::to_string(otherStackIndex))
-        ++cost;
-    }
+    // NOTE(review): the scan starts at stack index 0, so a link between the stack top and the buffer word is
+    // counted as a cost as well ; the loop this replaces started at index 1 -- confirm this is intended.
+    int cost = getNbLinkedWith(0, static_cast<int>(config.getStackSize())-1, Config::Object::Stack, config.getWordIndex(), config);
 
-    //TODO : Check if this is necessary
     if (bufferGovIndex != std::to_string(stackIndex))
       ++cost;
 
@@ -454,33 +363,19 @@ void Transition::initEagerRight()
 void Transition::initReduce_strict()
 {
   sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0));
-  sequence.emplace_back(Action::popStack());
+  sequence.emplace_back(Action::popStack(0));
 
   cost = [](const Config & config)
   {
-    if (!config.isToken(config.getStack(0)))
-      return 0;
-
-    int cost = 0;
-
     auto stackIndex = config.getStack(0);
-    auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
-
-    for (int i = config.getWordIndex(); config.has(0, i, 0); ++i)
-    {
-      if (!config.isToken(i))
-        continue;
-
-      auto otherGovIndex = config.getConst(Config::headColName, i, 0);
+    auto wordIndex = config.getWordIndex();
 
-      if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex))
-        ++cost;
+    if (!config.isToken(stackIndex))
+      return 0;
 
-      if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
-        break;
-    }
+    int cost = getNbLinkedWith(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
 
-    if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
+    if (config.getConst(Config::EOSColName, stackIndex, 0) == Config::EOSSymbol1)
       ++cost;
 
     return cost;
@@ -489,33 +384,19 @@ void Transition::initReduce_strict()
 
 void Transition::initReduce_relaxed()
 {
-  sequence.emplace_back(Action::popStack());
+  sequence.emplace_back(Action::popStack(0));
 
   cost = [](const Config & config)
   {
-    if (!config.isToken(config.getStack(0)))
-      return 0;
-
-    int cost = 0;
-
     auto stackIndex = config.getStack(0);
-    auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
-
-    for (int i = config.getWordIndex(); config.has(0, i, 0); ++i)
-    {
-      if (!config.isToken(i))
-        continue;
-
-      auto otherGovIndex = config.getConst(Config::headColName, i, 0);
+    auto wordIndex = config.getWordIndex();
 
-      if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex))
-        ++cost;
+    if (!config.isToken(stackIndex))
+      return 0;
 
-      if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
-        break;
-    }
+    int cost = getNbLinkedWith(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
 
-    if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
+    if (config.getConst(Config::EOSColName, stackIndex, 0) == Config::EOSSymbol1)
       ++cost;
 
     return cost;
@@ -549,3 +430,70 @@ void Transition::initDeprel(std::string label)
   };
 }
 
+// Count the tokens, among positions firstIndex..lastIndex (both included), that are linked with the word
+// withIndex in the head column : the token governs withIndex, or withIndex governs the token.
+// When object is Stack the positions are relative stack indexes, otherwise they are used as word indexes.
+int Transition::getNbLinkedWith(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config)
+{
+  auto govIndex = config.getConst(Config::headColName, withIndex, 0);
+
+  int nbLinkedWith = 0;
+
+  for (int i = firstIndex; i <= lastIndex; ++i)
+  {
+    int index = i;
+    if (object == Config::Object::Stack)
+      index = config.getStack(i);
+
+    if (!config.isToken(index))
+      continue;
+
+    auto otherGovIndex = config.getConst(Config::headColName, index, 0);
+
+    if (govIndex == std::to_string(index) || otherGovIndex == std::to_string(withIndex))
+      ++nbLinkedWith;
+  }
+
+  return nbLinkedWith;
+}
+
+// Walk backward from baseIndex and return the index of the first token of the sentence it belongs to.
+// An EOS mark on baseIndex itself closes the current sentence, so only an earlier EOS mark stops the walk.
+int Transition::getFirstIndexOfSentence(int baseIndex, const Config & config)
+{
+  int firstIndex = baseIndex;
+
+  for (int i = baseIndex; config.has(0, i, 0); --i)
+  {
+    if (!config.isToken(i))
+      continue;
+
+    if (i != baseIndex && config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
+      break;
+
+    firstIndex = i;
+  }
+
+  return firstIndex;
+}
+
+// Walk forward from baseIndex and return the index of the first token carrying the EOS mark,
+// or of the last token seen when no such mark is met ; baseIndex is returned if no token is seen at all.
+int Transition::getLastIndexOfSentence(int baseIndex, const Config & config)
+{
+  int lastIndex = baseIndex;
+
+  for (int i = baseIndex; config.has(0, i, 0); ++i)
+  {
+    if (!config.isToken(i))
+      continue;
+
+    lastIndex = i;
+
+    if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
+      break;
+  }
+
+  return lastIndex;
+}
+