diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp index 2114b28eb3f5d9a8720c5e1a30b6628f2f713116..91cc04a384419560f24e64a7ccbc56249f4ced17 100644 --- a/reading_machine/include/Transition.hpp +++ b/reading_machine/include/Transition.hpp @@ -23,9 +23,12 @@ class Transition void initWrite(std::string colName, std::string object, std::string index, std::string value); void initAdd(std::string colName, std::string object, std::string index, std::string value); - void initShift(); + void initEagerShift(); + void initStandardShift(); void initEagerLeft_rel(std::string label); void initEagerRight_rel(std::string label); + void initStandardLeft_rel(std::string label); + void initStandardRight_rel(std::string label); void initDeprel(std::string label); void initEagerLeft(); void initEagerRight(); diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 493e89e0c5e264cff7232028038a2e314ebe69f2..320686c1564b25c82176ddfde95dfb41e8495c91 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -360,7 +360,7 @@ void Config::popStack() void Config::swapStack(int relIndex1, int relIndex2) { int tmp = getStack(relIndex1); - getStackRef(relIndex1) = relIndex2; + getStackRef(relIndex1) = getStack(relIndex2); getStackRef(relIndex2) = tmp; } diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index 2dd88ee3210b8a1e2d173bacef5326d9180af572..c2f6c8d4bca1558f2309ced366f9b12c2de0ab6a 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -9,8 +9,10 @@ Transition::Transition(const std::string & name) [this](auto sm){(initWrite(sm[3], sm[1], sm[2], sm[4]));}}, {std::regex("ADD ([bs])\\.(.+) (.+) (.+)"), [this](auto sm){(initAdd(sm[3], sm[1], sm[2], sm[4]));}}, - {std::regex("SHIFT"), - [this](auto){initShift();}}, + {std::regex("eager_SHIFT"), + [this](auto){initEagerShift();}}, + {std::regex("standard_SHIFT"), + [this](auto){initStandardShift();}}, {std::regex("REDUCE_strict"), [this](auto){initReduce_strict();}}, {std::regex("REDUCE_relaxed"), @@ -19,6 +21,10 @@ Transition::Transition(const std::string & name) [this](auto sm){(initEagerLeft_rel(sm[1]));}}, {std::regex("eager_RIGHT_rel (.+)"), [this](auto sm){(initEagerRight_rel(sm[1]));}}, + {std::regex("standard_LEFT_rel (.+)"), + [this](auto sm){(initStandardLeft_rel(sm[1]));}}, + {std::regex("standard_RIGHT_rel (.+)"), + [this](auto sm){(initStandardRight_rel(sm[1]));}}, {std::regex("eager_LEFT"), [this](auto){(initEagerLeft());}}, {std::regex("eager_RIGHT"), @@ -233,7 +239,7 @@ void Transition::initSplit(int index) }; } -void Transition::initShift() +void Transition::initEagerShift() { sequence.emplace_back(Action::pushWordIndexOnStack()); sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged()); @@ -250,6 +256,20 @@ void Transition::initShift() }; } +void Transition::initStandardShift() +{ + sequence.emplace_back(Action::pushWordIndexOnStack()); + sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged()); + + cost = [](const Config & config) + { + if (config.hasStack(0) && config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1) + return std::numeric_limits<int>::max(); + + return 0; + }; +} + void Transition::initEagerLeft_rel(std::string label) { sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0)); @@ -276,6 +296,32 @@ void Transition::initEagerLeft_rel(std::string label) }; } +void Transition::initStandardLeft_rel(std::string label) +{ + sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Stack, 1)); + sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 1, label)); + sequence.emplace_back(Action::popStack(1)); + + cost = [label](const Config & config) + { + auto stackIndex = config.getStack(1); + auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0); + auto wordIndex = config.getStack(0); + if (!(config.isToken(stackIndex) && config.isToken(wordIndex))) + return std::numeric_limits<int>::max(); + + int cost = getNbLinkedWith(config.getWordIndex(), getLastIndexOfSentence(config.getWordIndex(), config), Config::Object::Buffer, stackIndex, config); + + if (stackGovIndex != std::to_string(wordIndex)) + ++cost; + + if (label != config.getConst(Config::deprelColName, stackIndex, 0)) + ++cost; + + return cost; + }; +} + void Transition::initEagerLeft() { sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0)); @@ -329,6 +375,33 @@ void Transition::initEagerRight_rel(std::string label) }; } +void Transition::initStandardRight_rel(std::string label) +{ + sequence.emplace_back(Action::attach(Config::Object::Stack, 1, Config::Object::Stack, 0)); + sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label)); + sequence.emplace_back(Action::popStack(0)); + + cost = [label](const Config & config) + { + auto stackIndex = config.getStack(1); + auto wordIndex = config.getStack(0); + if (!(config.isToken(stackIndex) && config.isToken(wordIndex))) + return std::numeric_limits<int>::max(); + + auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0); + + int cost = getNbLinkedWith(config.getWordIndex(), getLastIndexOfSentence(config.getWordIndex(), config), Config::Object::Buffer, wordIndex, config); + + if (bufferGovIndex != std::to_string(stackIndex)) + ++cost; + + if (label != config.getConst(Config::deprelColName, wordIndex, 0)) + ++cost; + + return cost; + }; +} + void Transition::initEagerRight() { sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));