diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp index 91cc04a384419560f24e64a7ccbc56249f4ced17..422cbdd15af21b3c5ffeb60b70df410f755524a2 100644 --- a/reading_machine/include/Transition.hpp +++ b/reading_machine/include/Transition.hpp @@ -17,6 +17,8 @@ class Transition private : static int getNbLinkedWith(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config); + static int getNbLinkedWithHead(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config); + static int getNbLinkedWithDeps(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config); static int getFirstIndexOfSentence(int baseIndex, const Config & config); static int getLastIndexOfSentence(int baseIndex, const Config & config); diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index c2f6c8d4bca1558f2309ced366f9b12c2de0ab6a..2e3400c23d4578f08dfcf6458328a58394b3d633 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -246,9 +246,6 @@ void Transition::initEagerShift() cost = [](const Config & config) { - if (config.hasStack(0) && config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1) - return std::numeric_limits<int>::max(); - if (!config.isToken(config.getWordIndex())) return 0; @@ -263,9 +260,6 @@ void Transition::initStandardShift() cost = [](const Config & config) { - if (config.hasStack(0) && config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1) - return std::numeric_limits<int>::max(); - return 0; }; } @@ -278,18 +272,16 @@ void Transition::initEagerLeft_rel(std::string label) cost = [label](const Config & config) { - auto stackIndex = config.getStack(0); - auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0); - auto wordIndex = config.getWordIndex(); - if (!(config.isToken(stackIndex) && config.isToken(wordIndex))) - return std::numeric_limits<int>::max(); + auto depIndex = config.getStack(0); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); + auto govIndex = config.getWordIndex(); - int cost = getNbLinkedWith(wordIndex+1, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config); + if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0)) + return 0; - if (stackGovIndex != std::to_string(wordIndex)) - ++cost; + int cost = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config); - if (label != config.getConst(Config::deprelColName, stackIndex, 0)) + if (label != config.getConst(Config::deprelColName, depIndex, 0)) ++cost; return cost; @@ -304,18 +296,17 @@ void Transition::initStandardLeft_rel(std::string label) cost = [label](const Config & config) { - auto stackIndex = config.getStack(1); - auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0); - auto wordIndex = config.getStack(0); - if (!(config.isToken(stackIndex) && config.isToken(wordIndex))) - return std::numeric_limits<int>::max(); + auto depIndex = config.getStack(1); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); + auto govIndex = config.getStack(0); - int cost = getNbLinkedWith(config.getWordIndex(), getLastIndexOfSentence(config.getWordIndex(), config), Config::Object::Buffer, stackIndex, config); + if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0)) + return 0; - if (stackGovIndex != std::to_string(wordIndex)) - ++cost; + int cost = getNbLinkedWith(config.getWordIndex(), getLastIndexOfSentence(config.getWordIndex(), config), Config::Object::Buffer, depIndex, config); + cost += getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config); - if (label != config.getConst(Config::deprelColName, stackIndex, 0)) + if (label != config.getConst(Config::deprelColName, depIndex, 0)) ++cost; return cost; @@ -329,16 +320,14 @@ void Transition::initEagerLeft() cost = [](const Config & config) { - auto stackIndex = config.getStack(0); - auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0); - auto wordIndex = config.getWordIndex(); - if (!(config.isToken(stackIndex) && config.isToken(wordIndex))) - return std::numeric_limits<int>::max(); + auto depIndex = config.getStack(0); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); + auto govIndex = config.getWordIndex(); - int cost = getNbLinkedWith(wordIndex+1, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config); + if (depGovIndex == std::to_string(govIndex)) + return 0; - if (stackGovIndex != std::to_string(wordIndex)) - ++cost; + int cost = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config); return cost; }; @@ -353,22 +342,17 @@ void Transition::initEagerRight_rel(std::string label) cost = [label](const Config & config) { - auto stackIndex = config.getStack(0); - auto wordIndex = config.getWordIndex(); - if (!(config.isToken(stackIndex) && config.isToken(wordIndex))) - return std::numeric_limits<int>::max(); + auto govIndex = config.getStack(0); + auto depIndex = config.getWordIndex(); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); - auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0); - - if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1) - return std::numeric_limits<int>::max(); - - int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config); + if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0)) + return 0; - if (bufferGovIndex != std::to_string(stackIndex)) - ++cost; + int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config); + cost += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config); - if (label != config.getConst(Config::deprelColName, wordIndex, 0)) + if (label != config.getConst(Config::deprelColName, depIndex, 0)) ++cost; return cost; @@ -383,19 +367,17 @@ void Transition::initStandardRight_rel(std::string label) cost = [label](const Config & config) { - auto stackIndex = config.getStack(1); - auto wordIndex = config.getStack(0); - if (!(config.isToken(stackIndex) && config.isToken(wordIndex))) - return std::numeric_limits<int>::max(); + auto govIndex = config.getStack(1); + auto depIndex = config.getStack(0); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); - auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0); + if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0)) + return 0; - int cost = getNbLinkedWith(config.getWordIndex(), getLastIndexOfSentence(config.getWordIndex(), config), Config::Object::Buffer, wordIndex, config); + int cost = getNbLinkedWith(config.getWordIndex(), getLastIndexOfSentence(config.getWordIndex(), config), Config::Object::Buffer, depIndex, config); + cost += getNbLinkedWith(2, config.getStackSize()-1, Config::Object::Stack, depIndex, config); - if (bufferGovIndex != std::to_string(stackIndex)) - ++cost; - - if (label != config.getConst(Config::deprelColName, wordIndex, 0)) + if (label != config.getConst(Config::deprelColName, depIndex, 0)) ++cost; return cost; @@ -410,20 +392,15 @@ void Transition::initEagerRight() cost = [](const Config & config) { - auto stackIndex = config.getStack(0); - auto wordIndex = config.getWordIndex(); - if (!(config.isToken(stackIndex) && config.isToken(wordIndex))) - return std::numeric_limits<int>::max(); - - auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0); + auto govIndex = config.getStack(0); + auto depIndex = config.getWordIndex(); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); - if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1) - return std::numeric_limits<int>::max(); + if (depGovIndex == std::to_string(govIndex)) + return 0; - int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config); - - if (bufferGovIndex != std::to_string(stackIndex)) - ++cost; + int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config); + cost += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config); return cost; }; @@ -442,10 +419,7 @@ void Transition::initReduce_strict() if (!config.isToken(stackIndex)) return 0; - int cost = getNbLinkedWith(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config); - - if (config.getConst(Config::EOSColName, stackIndex, 0) == Config::EOSSymbol1) - ++cost; + int cost = getNbLinkedWithDeps(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config); return cost; }; @@ -465,9 +439,6 @@ void Transition::initReduce_relaxed() int cost = getNbLinkedWith(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config); - if (config.getConst(Config::EOSColName, stackIndex, 0) == Config::EOSSymbol1) - ++cost; - return cost; }; } @@ -527,6 +498,52 @@ int Transition::getNbLinkedWith(int firstIndex, int lastIndex, Config::Object ob return nbLinkedWith; } +int Transition::getNbLinkedWithHead(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config) +{ + auto govIndex = config.getConst(Config::headColName, withIndex, 0); + auto govIndexPredicted = config.getAsFeature(Config::headColName, withIndex); + + int nbLinkedWith = 0; + + for (int i = firstIndex; i <= lastIndex; ++i) + { + int index = i; + if (object == Config::Object::Stack) + index = config.getStack(i); + + if (!config.isToken(index)) + continue; + + if (govIndex == std::to_string(index) and util::isEmpty(govIndexPredicted)) + ++nbLinkedWith; + } + + return nbLinkedWith; +} + +int Transition::getNbLinkedWithDeps(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config) +{ + int nbLinkedWith = 0; + + for (int i = firstIndex; i <= lastIndex; ++i) + { + int index = i; + if (object == Config::Object::Stack) + index = config.getStack(i); + + if (!config.isToken(index)) + continue; + + auto otherGovIndex = config.getConst(Config::headColName, index, 0); + auto otherGovIndexPredicted = config.getAsFeature(Config::headColName, index); + + if (otherGovIndex == std::to_string(withIndex) and util::isEmpty(otherGovIndexPredicted)) + ++nbLinkedWith; + } + + return nbLinkedWith; +} + int Transition::getFirstIndexOfSentence(int baseIndex, const Config & config) { int firstIndex = baseIndex;