diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp index 4aa259b3a25d1cd1b6e7ba395433098158386b7f..6b14f7606c5d1729a4eb0c3d319b0b78112eb080 100644 --- a/reading_machine/include/Action.hpp +++ b/reading_machine/include/Action.hpp @@ -46,7 +46,7 @@ class Action static Action addHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis); static Action addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition); static Action pushWordIndexOnStack(); - static Action popStack(); + static Action popStack(int relIndex); static Action emptyStack(); static Action setRoot(int bufferIndex); static Action updateIds(int bufferIndex); @@ -60,6 +60,7 @@ class Action static Action setMultiwordIds(int multiwordSize); static Action split(int index); static Action setRootUpdateIdsEmptyStackIfSentChanged(); + static Action deprel(std::string value); }; #endif diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index a047a8a0b2327cf94adf28f550c1333caf2abf66..527fd23fefcb25387f9396e7d8be5e89731fef0e 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -45,6 +45,7 @@ class Config std::vector<String> lines; std::set<std::string> predicted; int lastPoppedStack{-1}; + int lastAttached{-1}; int currentWordId{0}; std::vector<Transition *> appliableSplitTransitions; std::vector<int> appliableTransitions; @@ -88,6 +89,7 @@ class Config const String & getAsFeature(int colIndex, int lineIndex) const; ValueIterator getIterator(int colIndex, int lineIndex, int hypothesisIndex); ConstValueIterator getConstIterator(int colIndex, int lineIndex, int hypothesisIndex) const; + std::size_t & getStackRef(int relativeIndex); long getRelativeWordIndex(int relativeIndex) const; @@ -111,6 +113,7 @@ class Config void addToHistory(const std::string & transition); void addToStack(std::size_t index); void popStack(); + void swapStack(int relIndex1, int relIndex2); bool isComment(std::size_t lineIndex) const; bool isCommentPredicted(std::size_t lineIndex) const; bool isMultiword(std::size_t lineIndex) const; @@ -133,6 +136,7 @@ class Config bool hasRelativeWordIndex(Object object, int relativeIndex) const; const String & getHistory(int relativeIndex) const; std::size_t getStack(int relativeIndex) const; + std::size_t getStackSize() const; bool hasHistory(int relativeIndex) const; bool hasStack(int relativeIndex) const; String getState() const; @@ -141,6 +145,8 @@ class Config void addPredicted(const std::set<std::string> & predicted); bool isPredicted(const std::string & colName) const; int getLastPoppedStack() const; + int getLastAttached() const; + void setLastAttached(int lastAttached); int getCurrentWordId() const; void setCurrentWordId(int currentWordId); void addMissingColumns(); diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp index 487f34e93450df3ccaec5dce8cd21b2cc58e29ce..422cbdd15af21b3c5ffeb60b70df410f755524a2 100644 --- a/reading_machine/include/Transition.hpp +++ b/reading_machine/include/Transition.hpp @@ -16,12 +16,26 @@ class Transition private : + static int getNbLinkedWith(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config); + static int getNbLinkedWithHead(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config); + static int getNbLinkedWithDeps(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config); + + static int getFirstIndexOfSentence(int baseIndex, const Config & config); + static int getLastIndexOfSentence(int baseIndex, const Config & config); + void initWrite(std::string colName, std::string object, std::string index, std::string value); void initAdd(std::string colName, std::string object, std::string index, std::string value); - void initShift(); - void initLeft(std::string label); - void initRight(std::string label); - void initReduce(); + void initEagerShift(); + void initStandardShift(); + void initEagerLeft_rel(std::string label); + void initEagerRight_rel(std::string label); + void initStandardLeft_rel(std::string label); + void initStandardRight_rel(std::string label); + void initDeprel(std::string label); + void initEagerLeft(); + void initEagerRight(); + void initReduce_strict(); + void initReduce_relaxed(); void initEOS(int bufferIndex); void initNothing(); void initIgnoreChar(); diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index 9095082485b2f840222c3552cea0c810c40b8904..7c496eebd632060e82d56be30ee0617d7e7f55e1 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -268,23 +268,27 @@ Action Action::pushWordIndexOnStack() return {Type::Push, apply, undo, appliable}; } -Action Action::popStack() +Action Action::popStack(int relIndex) { - auto apply = [](Config & config, Action & a) + auto apply = [relIndex](Config & config, Action & a) { - auto toSave = config.getStack(0); + auto toSave = config.getStack(relIndex); a.data.push_back(std::to_string(toSave)); + for (int i = 0; relIndex-1-i >= 0; i++) + config.swapStack(relIndex-i, relIndex-1-i); config.popStack(); }; - auto undo = [](Config & config, Action & a) + auto undo = [relIndex](Config & config, Action & a) { config.addToStack(std::stoi(a.data.back())); + for (int i = 0; i+1 <= relIndex; i++) + config.swapStack(i, i+1); }; - auto appliable = [](const Config & config, const Action &) + auto appliable = [relIndex](const Config & config, const Action &) { - return config.hasStack(0) and config.getStack(0) != config.getWordIndex(); + return config.hasStack(relIndex) and config.getStack(relIndex) != config.getWordIndex(); }; return {Type::Pop, apply, undo, appliable}; @@ -380,8 +384,18 @@ Action Action::assertIsEmpty(const std::string & colName, Config::Object object, auto appliable = [colName, object, relativeIndex](const Config & config, const Action &) { - auto lineIndex = config.getRelativeWordIndex(object, relativeIndex); - return util::isEmpty(config.getAsFeature(colName, lineIndex)); + try + { + if (!config.hasRelativeWordIndex(object, relativeIndex)) + return false; + auto lineIndex = config.getRelativeWordIndex(object, relativeIndex); + return util::isEmpty(config.getAsFeature(colName, lineIndex)); + } catch (std::exception & e) + { + util::myThrow(fmt::format("colName='{}' object='{}' relativeIndex='{}' {}", colName, object == Config::Object::Stack ? "Stack" : "Buffer", relativeIndex, e.what())); + } + + return false; }; return {Type::Check, apply, undo, appliable}; @@ -399,8 +413,18 @@ Action Action::assertIsNotEmpty(const std::string & colName, Config::Object obje auto appliable = [colName, object, relativeIndex](const Config & config, const Action &) { - auto lineIndex = config.getRelativeWordIndex(object, relativeIndex); - return !util::isEmpty(config.getAsFeature(colName, lineIndex)); + try + { + if (!config.hasRelativeWordIndex(object, relativeIndex)) + return false; + auto lineIndex = config.getRelativeWordIndex(object, relativeIndex); + return !util::isEmpty(config.getAsFeature(colName, lineIndex)); + } catch (std::exception & e) + { + util::myThrow(fmt::format("colName='{}' object='{}' relativeIndex='{}' {}", colName, object == Config::Object::Stack ? "Stack" : "Buffer", relativeIndex, e.what())); + } + + return false; }; return {Type::Check, apply, undo, appliable}; @@ -577,23 +601,28 @@ Action Action::attach(Config::Object governorObject, int governorIndex, Config:: { auto apply = [governorObject, governorIndex, dependentObject, dependentIndex](Config & config, Action & a) { - long lineIndex = config.getRelativeWordIndex(governorObject, governorIndex); + long govIndex = config.getRelativeWordIndex(governorObject, governorIndex); long depIndex = config.getRelativeWordIndex(dependentObject, dependentIndex); - addHypothesisRelative(Config::headColName, dependentObject, dependentIndex, std::to_string(lineIndex)).apply(config, a); + addHypothesisRelative(Config::headColName, dependentObject, dependentIndex, std::to_string(govIndex)).apply(config, a); addToHypothesisRelative(Config::childsColName, governorObject, governorIndex, std::to_string(depIndex)).apply(config, a); + a.data.emplace_back(std::to_string(config.getLastAttached())); + config.setLastAttached(depIndex); }; auto undo = [governorObject, governorIndex, dependentObject, dependentIndex](Config & config, Action & a) { addHypothesisRelative(Config::headColName, dependentObject, dependentIndex, "").undo(config, a); addToHypothesisRelative(Config::childsColName, governorObject, governorIndex, "").apply(config, a); + config.setLastAttached(std::stoi(a.data.back())); + a.data.pop_back(); }; auto appliable = [governorObject, governorIndex, dependentObject, dependentIndex](const Config & config, const Action & action) { if (!config.hasRelativeWordIndex(governorObject, governorIndex) or !config.hasRelativeWordIndex(dependentObject, dependentIndex)) return false; + long govLineIndex = config.getRelativeWordIndex(governorObject, governorIndex); long depLineIndex = config.getRelativeWordIndex(dependentObject, dependentIndex); @@ -604,6 +633,10 @@ Action Action::attach(Config::Object governorObject, int governorIndex, Config:: if (config.getAsFeature(Config::sentIdColName, govLineIndex) != config.getAsFeature(Config::sentIdColName, depLineIndex)) return false; + // Check if dep is not already attached + if (!util::isEmpty(config.getAsFeature(Config::headColName, depLineIndex))) + return false; + // Check for cycles while (govLineIndex != depLineIndex) { @@ -740,3 +773,23 @@ Action Action::setRootUpdateIdsEmptyStackIfSentChanged() return {Type::Write, apply, undo, appliable}; } +Action Action::deprel(std::string value) +{ + auto apply = [value](Config & config, Action & a) + { + addHypothesis(Config::deprelColName, config.getLastAttached(), value).apply(config, a); + }; + + auto undo = [](Config & config, Action & a) + { + addHypothesis(Config::deprelColName, config.getLastAttached(), "").undo(config, a); + }; + + auto appliable = [](const Config & config, const Action & action) + { + return config.has(0,config.getLastAttached(),0); + }; + + return {Type::Write, apply, undo, appliable}; +} + diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index f68baf3b4a56b739da01aa8d50383133e16b4536..320686c1564b25c82176ddfde95dfb41e8495c91 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -357,6 +357,13 @@ void Config::popStack() stack.pop_back(); } +void Config::swapStack(int relIndex1, int relIndex2) +{ + int tmp = getStack(relIndex1); + getStackRef(relIndex1) = getStack(relIndex2); + getStackRef(relIndex2) = tmp; +} + bool Config::hasCharacter(int letterIndex) const { return letterIndex >= 0 and letterIndex < (int)util::getSize(rawInput); @@ -523,6 +530,13 @@ const Config::String & Config::getHistory(int relativeIndex) const } std::size_t Config::getStack(int relativeIndex) const +{ + if (relativeIndex == -1) + return getLastPoppedStack(); + return stack[stack.size()-1-relativeIndex]; +} + +std::size_t & Config::getStackRef(int relativeIndex) { return stack[stack.size()-1-relativeIndex]; } @@ -534,6 +548,8 @@ bool Config::hasHistory(int relativeIndex) const bool Config::hasStack(int relativeIndex) const { + if (relativeIndex == -1) + return has(0,getLastPoppedStack(),0); return relativeIndex >= 0 && relativeIndex < (int)stack.size(); } @@ -696,3 +712,18 @@ bool Config::isExtraColumn(const std::string & colName) const return false; } +int Config::getLastAttached() const +{ + return lastAttached; +} + +void Config::setLastAttached(int lastAttached) +{ + this->lastAttached = lastAttached; +} + +std::size_t Config::getStackSize() const +{ + return stack.size(); +} + diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index a159f39a6863f3cf061ff26afc823cf1fca68641..1a2ad12bdad2d496fd4449e25914bc0b54212013 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -9,14 +9,28 @@ Transition::Transition(const std::string & name) [this](auto sm){(initWrite(sm[3], sm[1], sm[2], sm[4]));}}, {std::regex("ADD ([bs])\\.(.+) (.+) (.+)"), [this](auto sm){(initAdd(sm[3], sm[1], sm[2], sm[4]));}}, - {std::regex("SHIFT"), - [this](auto){initShift();}}, - {std::regex("REDUCE"), - [this](auto){initReduce();}}, - {std::regex("LEFT (.+)"), - [this](auto sm){(initLeft(sm[1]));}}, - {std::regex("RIGHT (.+)"), - [this](auto sm){(initRight(sm[1]));}}, + {std::regex("eager_SHIFT"), + [this](auto){initEagerShift();}}, + {std::regex("standard_SHIFT"), + [this](auto){initStandardShift();}}, + {std::regex("REDUCE_strict"), + [this](auto){initReduce_strict();}}, + {std::regex("REDUCE_relaxed"), + [this](auto){initReduce_relaxed();}}, + {std::regex("eager_LEFT_rel (.+)"), + [this](auto sm){(initEagerLeft_rel(sm[1]));}}, + {std::regex("eager_RIGHT_rel (.+)"), + [this](auto sm){(initEagerRight_rel(sm[1]));}}, + {std::regex("standard_LEFT_rel (.+)"), + [this](auto sm){(initStandardLeft_rel(sm[1]));}}, + {std::regex("standard_RIGHT_rel (.+)"), + [this](auto sm){(initStandardRight_rel(sm[1]));}}, + {std::regex("eager_LEFT"), + [this](auto){(initEagerLeft());}}, + {std::regex("eager_RIGHT"), + [this](auto){(initEagerRight());}}, + {std::regex("deprel (.+)"), + [this](auto sm){(initDeprel(sm[1]));}}, {std::regex("EOS b\\.(.+)"), [this](auto sm){initEOS(std::stoi(sm[1]));}}, {std::regex("NOTHING"), @@ -67,19 +81,28 @@ void Transition::apply(Config & config) bool Transition::appliable(const Config & config) const { - if (!state.empty() && state != config.getState()) - return false; - - for (const Action & action : sequence) - if (!action.appliable(config, action)) + try + { + if (!state.empty() && state != config.getState()) return false; + for (const Action & action : sequence) + if (!action.appliable(config, action)) + return false; + } catch (std::exception & e) + { + util::myThrow(fmt::format("transition '{}' {}", name, e.what())); + } + return true; } int Transition::getCost(const Config & config) const { - return cost(config); + try {return cost(config);} + catch (std::exception & e) { util::myThrow(fmt::format("transition '{}' {}", name, e.what()));} + + return 0; } const std::string & Transition::getName() const @@ -225,81 +248,101 @@ void Transition::initSplit(int index) }; } -void Transition::initShift() +void Transition::initEagerShift() { sequence.emplace_back(Action::pushWordIndexOnStack()); sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged()); cost = [](const Config & config) { - if (config.hasStack(0) && config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1) - return std::numeric_limits<int>::max(); - if (!config.isToken(config.getWordIndex())) return 0; - auto headGovIndex = config.getConst(Config::headColName, config.getWordIndex(), 0); - - int cost = 0; - for (int i = 0; config.hasStack(i); ++i) - { - if (!config.has(0, config.getStack(i), 0)) - continue; - - auto stackIndex = config.getStack(i); - auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0); + return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config); + }; +} - if (stackGovIndex == std::to_string(config.getWordIndex()) || headGovIndex == std::to_string(stackIndex)) - ++cost; - } +void Transition::initStandardShift() +{ + sequence.emplace_back(Action::pushWordIndexOnStack()); + sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged()); - return cost; + cost = [](const Config & config) + { + return 0; }; } -void Transition::initLeft(std::string label) +void Transition::initEagerLeft_rel(std::string label) { sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0)); sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label)); - sequence.emplace_back(Action::popStack()); + sequence.emplace_back(Action::popStack(0)); cost = [label](const Config & config) { - auto stackIndex = config.getStack(0); - auto wordIndex = config.getWordIndex(); - if (!(config.isToken(stackIndex) && config.isToken(wordIndex))) - return std::numeric_limits<int>::max(); + auto depIndex = config.getStack(0); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); + auto govIndex = config.getWordIndex(); - int cost = 0; + if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0)) + return 0; + + int cost = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config); - auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0); + if (label != config.getConst(Config::deprelColName, depIndex, 0)) + ++cost; - for (int i = wordIndex+1; config.has(0, i, 0); ++i) - { - if (!config.isToken(i)) - continue; + return cost; + }; +} - if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1) - break; +void Transition::initStandardLeft_rel(std::string label) +{ + sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Stack, 1)); + sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 1, label)); + sequence.emplace_back(Action::popStack(1)); - auto otherGovIndex = config.getConst(Config::headColName, i, 0); + cost = [label](const Config & config) + { + auto depIndex = config.getStack(1); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); + auto govIndex = config.getStack(0); - if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex)) - ++cost; - } + if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0)) + return 0; - //TODO : Check if this is necessary - if (stackGovIndex != std::to_string(wordIndex)) - ++cost; + int cost = getNbLinkedWith(config.getWordIndex(), getLastIndexOfSentence(config.getWordIndex(), config), Config::Object::Buffer, depIndex, config); + cost += getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config); - if (label != config.getConst(Config::deprelColName, stackIndex, 0)) + if (label != config.getConst(Config::deprelColName, depIndex, 0)) ++cost; return cost; }; } -void Transition::initRight(std::string label) +void Transition::initEagerLeft() +{ + sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0)); + sequence.emplace_back(Action::popStack(0)); + + cost = [](const Config & config) + { + auto depIndex = config.getStack(0); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); + auto govIndex = config.getWordIndex(); + + if (depGovIndex == std::to_string(govIndex)) + return 0; + + int cost = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config); + + return cost; + }; +} + +void Transition::initEagerRight_rel(std::string label) { sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0)); sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label)); @@ -308,83 +351,102 @@ void Transition::initRight(std::string label) cost = [label](const Config & config) { - auto stackIndex = config.getStack(0); - auto wordIndex = config.getWordIndex(); - if (!(config.isToken(stackIndex) && config.isToken(wordIndex))) - return std::numeric_limits<int>::max(); + auto govIndex = config.getStack(0); + auto depIndex = config.getWordIndex(); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); - int cost = 0; - - auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0); - - for (int i = wordIndex; config.has(0, i, 0); ++i) - { - if (!config.isToken(i)) - continue; + if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0)) + return 0; - auto otherGovIndex = config.getConst(Config::headColName, i, 0); + int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config); + cost += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config); - if (bufferGovIndex == std::to_string(i)) - ++cost; + if (label != config.getConst(Config::deprelColName, depIndex, 0)) + ++cost; - if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1) - break; - } + return cost; + }; +} - for (int i = 1; config.hasStack(i); ++i) - { - if (!config.has(0, config.getStack(i), 0)) - continue; +void Transition::initStandardRight_rel(std::string label) +{ + sequence.emplace_back(Action::attach(Config::Object::Stack, 1, Config::Object::Stack, 0)); + sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label)); + sequence.emplace_back(Action::popStack(0)); - auto otherStackIndex = config.getStack(i); - auto otherStackGov = config.getConst(Config::headColName, otherStackIndex, 0); + cost = [label](const Config & config) + { + auto govIndex = config.getStack(1); + auto depIndex = config.getStack(0); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); - if (otherStackGov == std::to_string(wordIndex) || bufferGovIndex == std::to_string(otherStackIndex)) - ++cost; - } + if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0)) + return 0; - //TODO : Check if this is necessary - if (bufferGovIndex != std::to_string(stackIndex)) - ++cost; + int cost = getNbLinkedWith(config.getWordIndex(), getLastIndexOfSentence(config.getWordIndex(), config), Config::Object::Buffer, depIndex, config); + cost += getNbLinkedWith(2, config.getStackSize()-1, Config::Object::Stack, depIndex, config); - if (label != config.getConst(Config::deprelColName, wordIndex, 0)) + if (label != config.getConst(Config::deprelColName, depIndex, 0)) ++cost; return cost; }; } -void Transition::initReduce() +void Transition::initEagerRight() { - sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0)); - sequence.emplace_back(Action::popStack()); + sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0)); + sequence.emplace_back(Action::pushWordIndexOnStack()); + sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged()); cost = [](const Config & config) { - if (!config.isToken(config.getStack(0))) + auto govIndex = config.getStack(0); + auto depIndex = config.getWordIndex(); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); + + if (depGovIndex == std::to_string(govIndex)) return 0; - int cost = 0; + int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config); + cost += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config); + return cost; + }; +} + +void Transition::initReduce_strict() +{ + sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0)); + sequence.emplace_back(Action::popStack(0)); + + cost = [](const Config & config) + { auto stackIndex = config.getStack(0); - auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0); + auto wordIndex = config.getWordIndex(); - for (int i = config.getWordIndex(); config.has(0, i, 0); ++i) - { - if (!config.isToken(i)) - continue; + if (!config.isToken(stackIndex)) + return 0; - auto otherGovIndex = config.getConst(Config::headColName, i, 0); + int cost = getNbLinkedWithDeps(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config); - if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex)) - ++cost; + return cost; + }; +} - if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1) - break; - } +void Transition::initReduce_relaxed() +{ + sequence.emplace_back(Action::popStack(0)); - if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1) - ++cost; + cost = [](const Config & config) + { + auto stackIndex = config.getStack(0); + auto wordIndex = config.getWordIndex(); + + if (!config.isToken(stackIndex)) + return 0; + + int cost = getNbLinkedWith(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config); return cost; }; @@ -407,3 +469,123 @@ void Transition::initEOS(int bufferIndex) }; } +void Transition::initDeprel(std::string label) +{ + sequence.emplace_back(Action::deprel(label)); + + cost = [label](const Config & config) + { + return config.getConst(Config::deprelColName, config.getLastAttached(), 0) == label ? 0 : 1; + }; +} + +int Transition::getNbLinkedWith(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config) +{ + auto govIndex = config.getConst(Config::headColName, withIndex, 0); + auto govIndexPredicted = config.getAsFeature(Config::headColName, withIndex); + + int nbLinkedWith = 0; + + for (int i = firstIndex; i <= lastIndex; ++i) + { + int index = i; + if (object == Config::Object::Stack) + index = config.getStack(i); + + if (!config.isToken(index)) + continue; + + auto otherGovIndex = config.getConst(Config::headColName, index, 0); + auto otherGovIndexPredicted = config.getAsFeature(Config::headColName, index); + + if (govIndex == std::to_string(index) and util::isEmpty(govIndexPredicted)) + ++nbLinkedWith; + if (otherGovIndex == std::to_string(withIndex) and util::isEmpty(otherGovIndexPredicted)) + ++nbLinkedWith; + } + + return nbLinkedWith; +} + +int Transition::getNbLinkedWithHead(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config) +{ + auto govIndex = config.getConst(Config::headColName, withIndex, 0); + auto govIndexPredicted = config.getAsFeature(Config::headColName, withIndex); + + int nbLinkedWith = 0; + + for (int i = firstIndex; i <= lastIndex; ++i) + { + int index = i; + if (object == Config::Object::Stack) + index = config.getStack(i); + + if (!config.isToken(index)) + continue; + + if (govIndex == std::to_string(index) and util::isEmpty(govIndexPredicted)) + ++nbLinkedWith; + } + + return nbLinkedWith; +} + +int Transition::getNbLinkedWithDeps(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config) +{ + int nbLinkedWith = 0; + + for (int i = firstIndex; i <= lastIndex; ++i) + { + int index = i; + if (object == Config::Object::Stack) + index = config.getStack(i); + + if (!config.isToken(index)) + continue; + + auto otherGovIndex = config.getConst(Config::headColName, index, 0); + auto otherGovIndexPredicted = config.getAsFeature(Config::headColName, index); + + if (otherGovIndex == std::to_string(withIndex) and util::isEmpty(otherGovIndexPredicted)) + ++nbLinkedWith; + } + + return nbLinkedWith; +} + +int Transition::getFirstIndexOfSentence(int baseIndex, const Config & config) +{ + int firstIndex = baseIndex; + + for (int i = baseIndex; config.has(0, i, 0); --i) + { + if (!config.isToken(i)) + continue; + + if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1) + break; + + firstIndex = i; + } + + return firstIndex; +} + +int Transition::getLastIndexOfSentence(int baseIndex, const Config & config) +{ + int lastIndex = baseIndex; + + for (int i = baseIndex; config.has(0, i, 0); ++i) + { + if (!config.isToken(i)) + continue; + + lastIndex = i; + + if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1) + break; + } + + return lastIndex; +} +