diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp index a5b43617410a1898469055ba340c2a8810ef7604..c24b5eb76808bc2fa874a702a5a10ac0c94ca8a7 100644 --- a/reading_machine/include/Transition.hpp +++ b/reading_machine/include/Transition.hpp @@ -12,7 +12,8 @@ class Transition std::string name; std::string state; std::vector<Action> sequence; - std::function<int(const Config & config)> cost; + std::function<int(const Config & config)> costDynamic; + std::function<int(const Config & config)> costStatic; private : @@ -54,7 +55,8 @@ class Transition Transition(const std::string & name); void apply(Config & config); bool appliable(const Config & config) const; - int getCost(const Config & config) const; + int getCostDynamic(const Config & config) const; + int getCostStatic(const Config & config) const; const std::string & getName() const; }; diff --git a/reading_machine/include/TransitionSet.hpp b/reading_machine/include/TransitionSet.hpp index 8f7b733b50dc6d8feba2cbd88b585fbb2b52e692..1d1bc7543de401c86afabe1c4f9c6db4caaadb6c 100644 --- a/reading_machine/include/TransitionSet.hpp +++ b/reading_machine/include/TransitionSet.hpp @@ -20,8 +20,8 @@ class TransitionSet TransitionSet(const std::vector<std::string> & filenames); TransitionSet(const std::string & filename); - std::vector<std::pair<Transition*, int>> getAppliableTransitionsCosts(const Config & c); - Transition * getBestAppliableTransition(const Config & c); + std::vector<std::pair<Transition*, int>> getAppliableTransitionsCosts(const Config & c, bool dynamic = false); + Transition * getBestAppliableTransition(const Config & c, const std::vector<int> & appliableTransitions, bool dynamic = false); std::vector<Transition *> getNAppliableTransitions(const Config & c, int n); std::vector<int> getAppliableTransitions(const Config & c); std::size_t getTransitionIndex(const Transition * transition) const; diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index 9aed16e4aeadf9459a37a05b913cb6adceb08a43..3a717416131bd398bc3f805dfc1b10404ca0ff76 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -107,9 +107,17 @@ bool Transition::appliable(const Config & config) const return true; } -int Transition::getCost(const Config & config) const +int Transition::getCostDynamic(const Config & config) const { - try {return cost(config);} + try {return costDynamic(config);} + catch (std::exception & e) { util::myThrow(fmt::format("transition '{}' {}", name, e.what()));} + + return 0; +} + +int Transition::getCostStatic(const Config & config) const +{ + try {return costStatic(config);} catch (std::exception & e) { util::myThrow(fmt::format("transition '{}' {}", name, e.what()));} return 0; @@ -127,7 +135,7 @@ void Transition::initWrite(std::string colName, std::string object, std::string sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value)); - cost = [colName, objectValue, indexValue, value](const Config & config) + costDynamic = [colName, objectValue, indexValue, value](const Config & config) { int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); @@ -145,7 +153,7 @@ void Transition::initAdd(std::string colName, std::string object, std::string in sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value)); - cost = [colName, objectValue, indexValue, value](const Config & config) + costDynamic = [colName, objectValue, indexValue, value](const Config & config) { int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); @@ -161,7 +169,7 @@ void Transition::initAdd(std::string colName, std::string object, std::string in void Transition::initNothing() { - cost = [](const Config &) + costDynamic = [](const Config &) { return 0; }; @@ -171,7 +179,7 @@ void Transition::initIgnoreChar() { sequence.emplace_back(Action::ignoreCurrentCharacter()); - cost = [](const Config &) + costDynamic = [](const Config &) { return 0; }; @@ -181,7 +189,7 @@ void Transition::initEndWord() { sequence.emplace_back(Action::endWord()); - cost = [](const Config & config) + costDynamic = [](const Config & config) { if (config.getConst("FORM", config.getWordIndex(), 0) == config.getAsFeature("FORM", config.getWordIndex())) return 0; @@ -196,7 +204,7 @@ void Transition::initAddCharToWord() sequence.emplace_back(Action::addCurCharToCurWord()); sequence.emplace_back(Action::moveCharacterIndex(1)); - cost = [](const Config & config) + costDynamic = [](const Config & config) { if (!config.hasCharacter(config.getCharacterIndex())) return std::numeric_limits<int>::max(); @@ -226,7 +234,7 @@ void Transition::initSplitWord(std::vector<std::string> words) sequence.emplace_back(Action::addHypothesisRelativeRelaxed("FORM", Config::Object::Buffer, i, words[i])); sequence.emplace_back(Action::setMultiwordIds(words.size()-1)); - cost = [words](const Config & config) + costDynamic = [words](const Config & config) { if (!config.isMultiword(config.getWordIndex())) return std::numeric_limits<int>::max(); @@ -247,14 +255,14 @@ void Transition::initSplit(int index) { sequence.emplace_back(Action::split(index)); - cost = [index](const Config & config) + costDynamic = [index](const Config & config) { auto & transitions = config.getAppliableSplitTransitions(); if (index < 0 or index >= (int)transitions.size()) return std::numeric_limits<int>::max(); - return transitions[index]->getCost(config); + return transitions[index]->getCostDynamic(config); }; } @@ -263,13 +271,18 @@ void Transition::initEagerShift() sequence.emplace_back(Action::pushWordIndexOnStack()); sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged()); - cost = [](const Config & config) + costDynamic = [](const Config & config) { if (!config.isToken(config.getWordIndex())) return 0; return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config); }; + + costStatic = [](const Config &) + { + return 0; + }; } void Transition::initStandardShift() @@ -277,7 +290,7 @@ void Transition::initStandardShift() sequence.emplace_back(Action::pushWordIndexOnStack()); sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged()); - cost = [](const Config & config) + costDynamic = [](const Config & config) { return 0; }; @@ -289,7 +302,7 @@ void Transition::initEagerLeft_rel(std::string label) sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label)); sequence.emplace_back(Action::popStack(0)); - cost = [label](const Config & config) + costDynamic = [label](const Config & config) { auto depIndex = config.getStack(0); auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); @@ -305,6 +318,18 @@ void Transition::initEagerLeft_rel(std::string label) return cost; }; + + costStatic = [label](const Config & config) + { + auto depIndex = config.getStack(0); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); + auto govIndex = config.getWordIndex(); + + if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0)) + return 0; + + return 1; + }; } void Transition::initStandardLeft_rel(std::string label) @@ -313,7 +338,7 @@ void Transition::initStandardLeft_rel(std::string label) sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 1, label)); sequence.emplace_back(Action::popStack(1)); - cost = [label](const Config & config) + costDynamic = [label](const Config & config) { auto depIndex = config.getStack(1); auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); @@ -337,7 +362,7 @@ void Transition::initEagerLeft() sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0)); sequence.emplace_back(Action::popStack(0)); - cost = [](const Config & config) + costDynamic = [](const Config & config) { auto depIndex = config.getStack(0); auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); @@ -359,7 +384,7 @@ void Transition::initEagerRight_rel(std::string label) sequence.emplace_back(Action::pushWordIndexOnStack()); sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged()); - cost = [label](const Config & config) + costDynamic = [label](const Config & config) { auto govIndex = config.getStack(0); auto depIndex = config.getWordIndex(); @@ -376,6 +401,18 @@ void Transition::initEagerRight_rel(std::string label) return cost; }; + + costStatic = [label](const Config & config) + { + auto govIndex = config.getStack(0); + auto depIndex = config.getWordIndex(); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); + + if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0)) + return 0; + + return 1; + }; } void Transition::initStandardRight_rel(std::string label) @@ -384,7 +421,7 @@ void Transition::initStandardRight_rel(std::string label) sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label)); sequence.emplace_back(Action::popStack(0)); - cost = [label](const Config & config) + costDynamic = [label](const Config & config) { auto govIndex = config.getStack(1); auto depIndex = config.getStack(0); @@ -409,7 +446,7 @@ void Transition::initEagerRight() sequence.emplace_back(Action::pushWordIndexOnStack()); sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged()); - cost = [](const Config & config) + costDynamic = [](const Config & config) { auto govIndex = config.getStack(0); auto depIndex = config.getWordIndex(); @@ -430,7 +467,20 @@ void Transition::initReduce_strict() sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0)); sequence.emplace_back(Action::popStack(0)); - cost = [](const Config & config) + costDynamic = [](const Config & config) + { + auto stackIndex = config.getStack(0); + auto wordIndex = config.getWordIndex(); + + if (!config.isToken(stackIndex)) + return 0; + + int cost = getNbLinkedWithDeps(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config); + + return cost; + }; + + costDynamic = [](const Config & config) { auto stackIndex = config.getStack(0); auto wordIndex = config.getWordIndex(); @@ -442,13 +492,15 @@ void Transition::initReduce_strict() return cost; }; + + costStatic = costDynamic; } void Transition::initReduce_relaxed() { sequence.emplace_back(Action::popStack(0)); - cost = [](const Config & config) + costDynamic = [](const Config & config) { auto stackIndex = config.getStack(0); auto wordIndex = config.getWordIndex(); @@ -469,7 +521,7 @@ void Transition::initEOS(int bufferIndex) sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Config::Object::Buffer, bufferIndex, Config::EOSSymbol1)); sequence.emplace_back(Action::emptyStack()); - cost = [bufferIndex](const Config & config) + costDynamic = [bufferIndex](const Config & config) { int lineIndex = config.getRelativeWordIndex(Config::Buffer, bufferIndex); if (config.getConst(Config::EOSColName, lineIndex, 0) != Config::EOSSymbol1) @@ -483,7 +535,7 @@ void Transition::initDeprel(std::string label) { sequence.emplace_back(Action::deprel(label)); - cost = [label](const Config & config) + costDynamic = [label](const Config & config) { return config.getConst(Config::deprelColName, config.getLastAttached(), 0) == label ? 0 : 1; }; @@ -509,7 +561,7 @@ void Transition::initTransformSuffix(std::string fromCol, std::string fromObj, s toAddUtf8 = util::splitAsUtf8(toAdd); sequence.emplace_back(Action::transformSuffix(fromCol, fromObjectValue, fromIndexValue, toCol, toObjectValue, toIndexValue, toRemoveUtf8, toAddUtf8)); - cost = [fromObjectValue, fromIndexValue, toObjectValue, toIndexValue, toAddUtf8, toRemoveUtf8, fromCol, toCol](const Config & config) + costDynamic = [fromObjectValue, fromIndexValue, toObjectValue, toIndexValue, toAddUtf8, toRemoveUtf8, fromCol, toCol](const Config & config) { int fromLineIndex = config.getRelativeWordIndex(fromObjectValue, fromIndexValue); int toLineIndex = config.getRelativeWordIndex(toObjectValue, toIndexValue); @@ -533,7 +585,7 @@ void Transition::initUppercase(std::string col, std::string obj, std::string ind sequence.emplace_back(Action::uppercase(col, objectValue, indexValue)); - cost = [col, objectValue, indexValue](const Config & config) + costDynamic = [col, objectValue, indexValue](const Config & config) { int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); auto & expectedValue = config.getConst(col, lineIndex, 0); @@ -556,7 +608,7 @@ void Transition::initUppercaseIndex(std::string col, std::string obj, std::strin sequence.emplace_back(Action::uppercaseIndex(col, objectValue, indexValue, inIndexValue)); - cost = [col, objectValue, indexValue, inIndexValue](const Config & config) + costDynamic = [col, objectValue, indexValue, inIndexValue](const Config & config) { int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); auto & expectedValue = config.getConst(col, lineIndex, 0); @@ -580,7 +632,7 @@ void Transition::initLowercase(std::string col, std::string obj, std::string ind sequence.emplace_back(Action::lowercase(col, objectValue, indexValue)); - cost = [col, objectValue, indexValue](const Config & config) + costDynamic = [col, objectValue, indexValue](const Config & config) { int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); auto & expectedValue = config.getConst(col, lineIndex, 0); @@ -603,7 +655,7 @@ void Transition::initLowercaseIndex(std::string col, std::string obj, std::strin sequence.emplace_back(Action::lowercaseIndex(col, objectValue, indexValue, inIndexValue)); - cost = [col, objectValue, indexValue, inIndexValue](const Config & config) + costDynamic = [col, objectValue, indexValue, inIndexValue](const Config & config) { int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); auto & expectedValue = config.getConst(col, lineIndex, 0); diff --git a/reading_machine/src/TransitionSet.cpp b/reading_machine/src/TransitionSet.cpp index 5701c70a46a9dabc04cbe53b8058ff5676ce62f7..8c5c1a88712b155d21e0ee4615394c70e639c25a 100644 --- a/reading_machine/src/TransitionSet.cpp +++ b/reading_machine/src/TransitionSet.cpp @@ -35,14 +35,14 @@ void TransitionSet::addTransitionsFromFile(const std::string & filename) std::fclose(file); } -std::vector<std::pair<Transition*, int>> TransitionSet::getAppliableTransitionsCosts(const Config & c) +std::vector<std::pair<Transition*, int>> TransitionSet::getAppliableTransitionsCosts(const Config & c, bool dynamic) { using Pair = std::pair<Transition*, int>; std::vector<Pair> appliableTransitions; for (unsigned int i = 0; i < transitions.size(); i++) if (transitions[i].appliable(c)) - appliableTransitions.emplace_back(&transitions[i], transitions[i].getCost(c)); + appliableTransitions.emplace_back(&transitions[i], dynamic ? transitions[i].getCostDynamic(c) : transitions[i].getCostStatic(c)); std::sort(appliableTransitions.begin(), appliableTransitions.end(), [](const Pair & a, const Pair & b) @@ -80,17 +80,17 @@ std::vector<int> TransitionSet::getAppliableTransitions(const Config & c) return result; } -Transition * TransitionSet::getBestAppliableTransition(const Config & c) +Transition * TransitionSet::getBestAppliableTransition(const Config & c, const std::vector<int> & appliableTransitions, bool dynamic) { Transition * result = nullptr; int bestCost = std::numeric_limits<int>::max(); for (unsigned int i = 0; i < transitions.size(); i++) { - if (!transitions[i].appliable(c)) + if (!appliableTransitions[i]) continue; - int cost = transitions[i].getCost(c); + int cost = dynamic ? transitions[i].getCostDynamic(c) : transitions[i].getCostStatic(c); if (cost == 0) return &transitions[i]; diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index a12c6c547e293a50443606ff23c342331154b1c1..ea9031be23d290a4d30682795eb1af9603cab9db 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -73,7 +73,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p Transition * transition = nullptr; Transition * goldTransition = nullptr; - goldTransition = machine.getTransitionSet().getBestAppliableTransition(config); + goldTransition = machine.getTransitionSet().getBestAppliableTransition(config, appliableTransitions, dynamicOracle); if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter") { @@ -301,7 +301,7 @@ void Trainer::fillDicts(SubConfig & config, bool debug) } Transition * goldTransition = nullptr; - goldTransition = machine.getTransitionSet().getBestAppliableTransition(config); + goldTransition = machine.getTransitionSet().getBestAppliableTransition(config, appliableTransitions); if (!goldTransition) {