diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp index 4aa259b3a25d1cd1b6e7ba395433098158386b7f..2ac6dde2348ed8f471e1dc925251fa66d9c6446e 100644 --- a/reading_machine/include/Action.hpp +++ b/reading_machine/include/Action.hpp @@ -60,6 +60,7 @@ class Action static Action setMultiwordIds(int multiwordSize); static Action split(int index); static Action setRootUpdateIdsEmptyStackIfSentChanged(); + static Action deprel(std::string value); }; #endif diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index a047a8a0b2327cf94adf28f550c1333caf2abf66..c847fef06d1b0984a8f026ac7188da9abf388387 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -45,6 +45,7 @@ class Config std::vector<String> lines; std::set<std::string> predicted; int lastPoppedStack{-1}; + int lastAttached{-1}; int currentWordId{0}; std::vector<Transition *> appliableSplitTransitions; std::vector<int> appliableTransitions; @@ -141,6 +142,8 @@ class Config void addPredicted(const std::set<std::string> & predicted); bool isPredicted(const std::string & colName) const; int getLastPoppedStack() const; + int getLastAttached() const; + void setLastAttached(int lastAttached); int getCurrentWordId() const; void setCurrentWordId(int currentWordId); void addMissingColumns(); diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp index 487f34e93450df3ccaec5dce8cd21b2cc58e29ce..e574f9db0ac35e748a818c4e57db6c44c1430a07 100644 --- a/reading_machine/include/Transition.hpp +++ b/reading_machine/include/Transition.hpp @@ -19,8 +19,11 @@ class Transition void initWrite(std::string colName, std::string object, std::string index, std::string value); void initAdd(std::string colName, std::string object, std::string index, std::string value); void initShift(); - void initLeft(std::string label); - void initRight(std::string label); + void initEagerLeft_rel(std::string label); + void initEagerRight_rel(std::string label); + void initDeprel(std::string label); + void initEagerLeft(); + void initEagerRight(); void initReduce(); void initEOS(int bufferIndex); void initNothing(); diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index 9095082485b2f840222c3552cea0c810c40b8904..44a9cdb212445aee1ab71ccaa2ce17a7be4a82b4 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -582,12 +582,16 @@ Action Action::attach(Config::Object governorObject, int governorIndex, Config:: addHypothesisRelative(Config::headColName, dependentObject, dependentIndex, std::to_string(lineIndex)).apply(config, a); addToHypothesisRelative(Config::childsColName, governorObject, governorIndex, std::to_string(depIndex)).apply(config, a); + a.data.emplace_back(std::to_string(config.getLastAttached())); + config.setLastAttached(depIndex); }; auto undo = [governorObject, governorIndex, dependentObject, dependentIndex](Config & config, Action & a) { addHypothesisRelative(Config::headColName, dependentObject, dependentIndex, "").undo(config, a); addToHypothesisRelative(Config::childsColName, governorObject, governorIndex, "").apply(config, a); + config.setLastAttached(std::stoi(a.data.back())); + a.data.pop_back(); }; auto appliable = [governorObject, governorIndex, dependentObject, dependentIndex](const Config & config, const Action & action) @@ -740,3 +744,23 @@ Action Action::setRootUpdateIdsEmptyStackIfSentChanged() return {Type::Write, apply, undo, appliable}; } +Action Action::deprel(std::string value) +{ + auto apply = [value](Config & config, Action & a) + { + addHypothesis(Config::deprelColName, config.getLastAttached(), value).apply(config, a); + }; + + auto undo = [](Config & config, Action & a) + { + addHypothesis(Config::deprelColName, config.getLastAttached(), "").undo(config, a); + }; + + auto appliable = [](const Config & config, const Action & action) + { + return config.has(0,config.getLastAttached(),0); + }; + + return {Type::Write, apply, undo, appliable}; +} + diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index f68baf3b4a56b739da01aa8d50383133e16b4536..048b28dd8554cc800427fa33de9f054019456197 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -524,6 +524,8 @@ const Config::String & Config::getHistory(int relativeIndex) const std::size_t Config::getStack(int relativeIndex) const { + if (relativeIndex == -1) + return getLastPoppedStack(); return stack[stack.size()-1-relativeIndex]; } @@ -534,6 +536,8 @@ bool Config::hasHistory(int relativeIndex) const bool Config::hasStack(int relativeIndex) const { + if (relativeIndex == -1) + return has(0,getLastPoppedStack(),0); return relativeIndex >= 0 && relativeIndex < (int)stack.size(); } @@ -696,3 +700,13 @@ bool Config::isExtraColumn(const std::string & colName) const return false; } +int Config::getLastAttached() const +{ + return lastAttached; +} + +void Config::setLastAttached(int lastAttached) +{ + this->lastAttached = lastAttached; +} + diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index a159f39a6863f3cf061ff26afc823cf1fca68641..6fac88efc436c7a19e2c6d931c181b23daa05146 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -13,10 +13,16 @@ Transition::Transition(const std::string & name) [this](auto){initShift();}}, {std::regex("REDUCE"), [this](auto){initReduce();}}, - {std::regex("LEFT (.+)"), - [this](auto sm){(initLeft(sm[1]));}}, - {std::regex("RIGHT (.+)"), - [this](auto sm){(initRight(sm[1]));}}, + {std::regex("eager_LEFT_rel (.+)"), + [this](auto sm){(initEagerLeft_rel(sm[1]));}}, + {std::regex("eager_RIGHT_rel (.+)"), + [this](auto sm){(initEagerRight_rel(sm[1]));}}, + {std::regex("eager_LEFT"), + [this](auto){(initEagerLeft());}}, + {std::regex("eager_RIGHT"), + [this](auto){(initEagerRight());}}, + {std::regex("deprel (.+)"), + [this](auto sm){(initDeprel(sm[1]));}}, {std::regex("EOS b\\.(.+)"), [this](auto sm){initEOS(std::stoi(sm[1]));}}, {std::regex("NOTHING"), @@ -257,7 +263,7 @@ void Transition::initShift() }; } -void Transition::initLeft(std::string label) +void Transition::initEagerLeft_rel(std::string label) { sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0)); sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label)); @@ -299,7 +305,45 @@ void Transition::initLeft(std::string label) }; } -void Transition::initRight(std::string label) +void Transition::initEagerLeft() +{ + sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0)); + sequence.emplace_back(Action::popStack()); + + cost = [](const Config & config) + { + auto stackIndex = config.getStack(0); + auto wordIndex = config.getWordIndex(); + if (!(config.isToken(stackIndex) && config.isToken(wordIndex))) + return std::numeric_limits<int>::max(); + + int cost = 0; + + auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0); + + for (int i = wordIndex+1; config.has(0, i, 0); ++i) + { + if (!config.isToken(i)) + continue; + + if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1) + break; + + auto otherGovIndex = config.getConst(Config::headColName, i, 0); + + if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex)) + ++cost; + } + + //TODO : Check if this is necessary + if (stackGovIndex != std::to_string(wordIndex)) + ++cost; + + return cost; + }; +} + +void Transition::initEagerRight_rel(std::string label) { sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0)); sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label)); @@ -354,6 +398,57 @@ void Transition::initRight(std::string label) }; } +void Transition::initEagerRight() +{ + sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0)); + sequence.emplace_back(Action::pushWordIndexOnStack()); + sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged()); + + cost = [](const Config & config) + { + auto stackIndex = config.getStack(0); + auto wordIndex = config.getWordIndex(); + if (!(config.isToken(stackIndex) && config.isToken(wordIndex))) + return std::numeric_limits<int>::max(); + + int cost = 0; + + auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0); + + for (int i = wordIndex; config.has(0, i, 0); ++i) + { + if (!config.isToken(i)) + continue; + + auto otherGovIndex = config.getConst(Config::headColName, i, 0); + + if (bufferGovIndex == std::to_string(i)) + ++cost; + + if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1) + break; + } + + for (int i = 1; config.hasStack(i); ++i) + { + if (!config.has(0, config.getStack(i), 0)) + continue; + + auto otherStackIndex = config.getStack(i); + auto otherStackGov = config.getConst(Config::headColName, otherStackIndex, 0); + + if (otherStackGov == std::to_string(wordIndex) || bufferGovIndex == std::to_string(otherStackIndex)) + ++cost; + } + + //TODO : Check if this is necessary + if (bufferGovIndex != std::to_string(stackIndex)) + ++cost; + + return cost; + }; +} + void Transition::initReduce() { sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0)); @@ -407,3 +502,13 @@ void Transition::initEOS(int bufferIndex) }; } +void Transition::initDeprel(std::string label) +{ + sequence.emplace_back(Action::deprel(label)); + + cost = [label](const Config & config) + { + return config.getConst(Config::deprelColName, config.getLastAttached(), 0) == label ? 0 : 1; + }; +} + diff --git a/reading_machine/src/TransitionSet.cpp b/reading_machine/src/TransitionSet.cpp index 5701c70a46a9dabc04cbe53b8058ff5676ce62f7..66da993267f524521836eaaa42358bd1de33a486 100644 --- a/reading_machine/src/TransitionSet.cpp +++ b/reading_machine/src/TransitionSet.cpp @@ -102,6 +102,25 @@ Transition * TransitionSet::getBestAppliableTransition(const Config & c) } } + if (!result) + { + for (unsigned int i = 0; i < transitions.size(); i++) + { + fmt::print(stderr, "{}\n", transitions[i].getName()); + if (!transitions[i].appliable(c)) + { + fmt::print(stderr, "not appliable\n"); + continue; + } + + fmt::print(stderr, "appliable\n"); + + int cost = transitions[i].getCost(c); + + fmt::print(stderr, "cost {}\n", cost); + } + } + return result; }