diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index 34b3fdb097ca16674c73136b4cfd06b3d312504b..12e4e4a8ef8665f71ce106aa93721f8ca5f57cb5 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -442,11 +442,14 @@ void Transition::initEagerLeft_rel(std::string label) { auto depIndex = config.getStack(0); auto govIndex = config.getWordIndex(); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); int cost = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config); if (label != config.getConst(Config::deprelColName, depIndex, 0)) ++cost; + if (depGovIndex != std::to_string(govIndex)) + ++cost; return cost; }; @@ -544,18 +547,28 @@ void Transition::initEagerLeft() costDynamic = [](const Config & config) { auto depIndex = config.getStack(0); - auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); auto govIndex = config.getWordIndex(); - - if (depGovIndex == std::to_string(govIndex)) - return 0; + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); int cost = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config); + if (depGovIndex != std::to_string(govIndex)) + cost += 1; + return cost; }; - costStatic = costDynamic; + costStatic = [](const Config & config) + { + auto depIndex = config.getStack(0); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); + auto govIndex = config.getWordIndex(); + + if (depGovIndex == std::to_string(govIndex)) + return 0; + + return 1; + }; } void Transition::initEagerRight_rel(std::string label) @@ -566,13 +579,17 @@ void Transition::initEagerRight_rel(std::string label) costDynamic = [label](const Config & config) { + auto govIndex = config.getStack(0); auto depIndex = config.getWordIndex(); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config); cost += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config); if (label != config.getConst(Config::deprelColName, depIndex, 0)) ++cost; + if (depGovIndex == std::to_string(govIndex)) + ++cost; return cost; }; @@ -669,20 +686,31 @@ void Transition::initEagerRight() costDynamic = [](const Config & config) { - auto govIndex = config.getStack(0); auto depIndex = config.getWordIndex(); + auto govIndex = config.getStack(0); auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); - if (depGovIndex == std::to_string(govIndex)) - return 0; - int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config); cost += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config); + if (depGovIndex != std::to_string(govIndex)) + cost += 1; + return cost; }; - costStatic = costDynamic; + costStatic = [](const Config & config) + { + auto govIndex = config.getStack(0); + auto depIndex = config.getWordIndex(); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); + + if (depGovIndex == std::to_string(govIndex)) + return 0; + + return 1; + }; + } void Transition::initReduce_strict()