From bb7307d4737006ad5e74572a3796d1ed560be0b0 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Mon, 22 Feb 2021 17:51:02 +0100 Subject: [PATCH] Fixed oracles of non rel arc eager --- reading_machine/src/Transition.cpp | 48 +++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 10 deletions(-) diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index 34b3fdb..12e4e4a 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -442,11 +442,14 @@ void Transition::initEagerLeft_rel(std::string label) { auto depIndex = config.getStack(0); auto govIndex = config.getWordIndex(); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); int cost = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config); if (label != config.getConst(Config::deprelColName, depIndex, 0)) ++cost; + if (depGovIndex != std::to_string(govIndex)) + ++cost; return cost; }; @@ -544,18 +547,28 @@ void Transition::initEagerLeft() costDynamic = [](const Config & config) { auto depIndex = config.getStack(0); - auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); auto govIndex = config.getWordIndex(); - - if (depGovIndex == std::to_string(govIndex)) - return 0; + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); int cost = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config); + if (depGovIndex != std::to_string(govIndex)) + cost += 1; + return cost; }; - costStatic = costDynamic; + costStatic = [](const Config & config) + { + auto depIndex = config.getStack(0); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); + auto govIndex = config.getWordIndex(); + + if (depGovIndex == std::to_string(govIndex)) + return 0; + + return 1; + }; } void Transition::initEagerRight_rel(std::string label) @@ -566,13 +579,17 @@ void Transition::initEagerRight_rel(std::string label) costDynamic = [label](const Config & config) { + auto govIndex = config.getStack(0); auto depIndex = config.getWordIndex(); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config); cost += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config); if (label != config.getConst(Config::deprelColName, depIndex, 0)) ++cost; + if (depGovIndex == std::to_string(govIndex)) + ++cost; return cost; }; @@ -669,20 +686,31 @@ void Transition::initEagerRight() costDynamic = [](const Config & config) { - auto govIndex = config.getStack(0); auto depIndex = config.getWordIndex(); + auto govIndex = config.getStack(0); auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); - if (depGovIndex == std::to_string(govIndex)) - return 0; - int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config); cost += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config); + if (depGovIndex != std::to_string(govIndex)) + cost += 1; + return cost; }; - costStatic = costDynamic; + costStatic = [](const Config & config) + { + auto govIndex = config.getStack(0); + auto depIndex = config.getWordIndex(); + auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); + + if (depGovIndex == std::to_string(govIndex)) + return 0; + + return 1; + }; + } void Transition::initReduce_strict() -- GitLab