diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp index c24b5eb76808bc2fa874a702a5a10ac0c94ca8a7..b6671a3724b9da5c7a209bcdd8c8e1a3557ff147 100644 --- a/reading_machine/include/Transition.hpp +++ b/reading_machine/include/Transition.hpp @@ -14,6 +14,7 @@ class Transition std::vector<Action> sequence; std::function<int(const Config & config)> costDynamic; std::function<int(const Config & config)> costStatic; + std::function<bool(const Config & config)> precondition{[](const Config&){return true;}}; private : @@ -27,15 +28,19 @@ class Transition void initWrite(std::string colName, std::string object, std::string index, std::string value); void initAdd(std::string colName, std::string object, std::string index, std::string value); void initEagerShift(); + void initGoldEagerShift(); void initStandardShift(); void initEagerLeft_rel(std::string label); + void initGoldEagerLeft_rel(std::string label); void initEagerRight_rel(std::string label); + void initGoldEagerRight_rel(std::string label); void initStandardLeft_rel(std::string label); void initStandardRight_rel(std::string label); void initDeprel(std::string label); void initEagerLeft(); void initEagerRight(); void initReduce_strict(); + void initGoldReduce_strict(); void initReduce_relaxed(); void initEOS(int bufferIndex); void initNothing(); diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index 7fcc7ca3fc3ad52672771262c2c233b6e2973ca7..352ca9b3e2903998b83da8d47ae524393a3392ba 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -11,16 +11,24 @@ Transition::Transition(const std::string & name) [this](auto sm){(initAdd(sm[3], sm[1], sm[2], sm[4]));}}, {std::regex("eager_SHIFT"), [this](auto){initEagerShift();}}, + {std::regex("gold_eager_SHIFT"), + [this](auto){initGoldEagerShift();}}, {std::regex("standard_SHIFT"), [this](auto){initStandardShift();}}, {std::regex("REDUCE_strict"), [this](auto){initReduce_strict();}}, + {std::regex("gold_REDUCE_strict"), + [this](auto){initGoldReduce_strict();}}, {std::regex("REDUCE_relaxed"), [this](auto){initReduce_relaxed();}}, {std::regex("eager_LEFT_rel (.+)"), [this](auto sm){(initEagerLeft_rel(sm[1]));}}, + {std::regex("gold_eager_LEFT_rel (.+)"), + [this](auto sm){(initGoldEagerLeft_rel(sm[1]));}}, {std::regex("eager_RIGHT_rel (.+)"), [this](auto sm){(initEagerRight_rel(sm[1]));}}, + {std::regex("gold_eager_RIGHT_rel (.+)"), + [this](auto sm){(initGoldEagerRight_rel(sm[1]));}}, {std::regex("standard_LEFT_rel (.+)"), [this](auto sm){(initStandardLeft_rel(sm[1]));}}, {std::regex("standard_RIGHT_rel (.+)"), @@ -99,6 +107,10 @@ bool Transition::appliable(const Config & config) const for (const Action & action : sequence) if (!action.appliable(config, action)) return false; + + if (!precondition(config)) + return false; + } catch (std::exception & e) { util::myThrow(fmt::format("transition '{}' {}", name, e.what())); @@ -301,6 +313,33 @@ void Transition::initEagerShift() }; } +void Transition::initGoldEagerShift() +{ + sequence.emplace_back(Action::pushWordIndexOnStack()); + sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged()); + + costDynamic = [](const Config & config) + { + if (!config.isToken(config.getWordIndex())) + return 0; + + return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config); + }; + + costStatic = [](const Config &) + { + return 0; + }; + + precondition = [](const Config & config) + { + if (!config.isToken(config.getWordIndex())) + return true; + + return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config) == 0; + }; +} + void Transition::initStandardShift() { sequence.emplace_back(Action::pushWordIndexOnStack()); @@ -346,6 +385,51 @@ void Transition::initEagerLeft_rel(std::string label) }; } +void Transition::initGoldEagerLeft_rel(std::string label) +{ + sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0)); + sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label)); + sequence.emplace_back(Action::popStack(0)); + + costDynamic = [label](const Config & config) + { + auto depIndex = config.getStack(0); + auto govIndex = config.getWordIndex(); + + int cost = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config); + + if (label != config.getConst(Config::deprelColName, depIndex, 0)) + ++cost; + + return cost; + }; + + costStatic = [label](const Config & config) + { + auto depIndex = config.getStack(0); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); + auto govIndex = config.getWordIndex(); + + if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0)) + return 0; + + return 1; + }; + + precondition = [label](const Config & config) + { + auto depIndex = config.getStack(0); + auto govIndex = config.getWordIndex(); + + int cost = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config); + + if (label != config.getConst(Config::deprelColName, depIndex, 0)) + ++cost; + + return cost == 0; + }; +} + void Transition::initStandardLeft_rel(std::string label) { sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Stack, 1)); @@ -428,6 +512,52 @@ void Transition::initEagerRight_rel(std::string label) }; } +void Transition::initGoldEagerRight_rel(std::string label) +{ + sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0)); + sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label)); + sequence.emplace_back(Action::pushWordIndexOnStack()); + sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged()); + + costDynamic = [label](const Config & config) + { + auto depIndex = config.getWordIndex(); + + int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config); + cost += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config); + + if (label != config.getConst(Config::deprelColName, depIndex, 0)) + ++cost; + + return cost; + }; + + costStatic = [label](const Config & config) + { + auto govIndex = config.getStack(0); + auto depIndex = config.getWordIndex(); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); + + if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0)) + return 0; + + return 1; + }; + + precondition = [label](const Config & config) + { + auto depIndex = config.getWordIndex(); + + int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config); + cost += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config); + + if (label != config.getConst(Config::deprelColName, depIndex, 0)) + ++cost; + + return cost == 0; + }; +} + void Transition::initStandardRight_rel(std::string label) { sequence.emplace_back(Action::attach(Config::Object::Stack, 1, Config::Object::Stack, 0)); @@ -500,6 +630,40 @@ void Transition::initReduce_strict() costStatic = costDynamic; } +void Transition::initGoldReduce_strict() +{ + sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0)); + sequence.emplace_back(Action::popStack(0)); + + costDynamic = [](const Config & config) + { + auto stackIndex = config.getStack(0); + auto wordIndex = config.getWordIndex(); + + if (!config.isToken(stackIndex)) + return 0; + + int cost = getNbLinkedWithDeps(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config); + + return cost; + }; + + costStatic = costDynamic; + + precondition = [](const Config & config) + { + auto stackIndex = config.getStack(0); + auto wordIndex = config.getWordIndex(); + + if (!config.isToken(stackIndex)) + return true; + + int cost = getNbLinkedWithDeps(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config); + + return cost == 0; + }; +} + void Transition::initReduce_relaxed() { sequence.emplace_back(Action::popStack(0));