diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index 4543513cfb50778c096a6ced53bae2c1762268f1..76601d3cfd0af61edc812da9ac5ec9656eeae52a 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -323,12 +323,8 @@ void Transition::initEagerLeft_rel(std::string label) costDynamic = [label](const Config & config) { auto depIndex = config.getStack(0); - auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); auto govIndex = config.getWordIndex(); - if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0)) - return 0; - int cost = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config); if (label != config.getConst(Config::deprelColName, depIndex, 0)) @@ -366,7 +362,7 @@ void Transition::initStandardLeft_rel(std::string label) return 0; int cost = getNbLinkedWith(config.getWordIndex(), getLastIndexOfSentence(config.getWordIndex(), config), Config::Object::Buffer, depIndex, config); - cost += getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config); + cost += getNbLinkedWith(1, config.getStackSize()-2, Config::Object::Stack, depIndex, config); if (label != config.getConst(Config::deprelColName, depIndex, 0)) ++cost; @@ -408,14 +404,9 @@ void Transition::initEagerRight_rel(std::string label) costDynamic = [label](const Config & config) { - auto govIndex = config.getStack(0); auto depIndex = config.getWordIndex(); - auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); - if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0)) - return 0; - - int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config); + int cost = getNbLinkedWith(1, config.getStackSize()-2, Config::Object::Stack, depIndex, config); cost += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config); if (label != config.getConst(Config::deprelColName, depIndex, 0)) @@ -704,7 +695,6 @@ void Transition::initLowercaseIndex(std::string col, std::string obj, std::strin int Transition::getNbLinkedWith(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config) { auto govIndex = config.getConst(Config::headColName, withIndex, 0); - auto govIndexPredicted = config.getAsFeature(Config::headColName, withIndex); int nbLinkedWith = 0; @@ -718,11 +708,10 @@ int Transition::getNbLinkedWith(int firstIndex, int lastIndex, Config::Object ob continue; auto otherGovIndex = config.getConst(Config::headColName, index, 0); - auto otherGovIndexPredicted = config.getAsFeature(Config::headColName, index); - if (govIndex == std::to_string(index) and util::isEmpty(govIndexPredicted)) + if (govIndex == std::to_string(index)) ++nbLinkedWith; - if (otherGovIndex == std::to_string(withIndex) and util::isEmpty(otherGovIndexPredicted)) + if (otherGovIndex == std::to_string(withIndex)) ++nbLinkedWith; } @@ -732,7 +721,6 @@ int Transition::getNbLinkedWith(int firstIndex, int lastIndex, Config::Object ob int Transition::getNbLinkedWithHead(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config) { auto govIndex = config.getConst(Config::headColName, withIndex, 0); - auto govIndexPredicted = config.getAsFeature(Config::headColName, withIndex); int nbLinkedWith = 0; @@ -745,7 +733,7 @@ int Transition::getNbLinkedWithHead(int firstIndex, int lastIndex, Config::Objec if (!config.isToken(index)) continue; - if (govIndex == std::to_string(index) and util::isEmpty(govIndexPredicted)) + if (govIndex == std::to_string(index)) ++nbLinkedWith; } @@ -766,9 +754,8 @@ int Transition::getNbLinkedWithDeps(int firstIndex, int lastIndex, Config::Objec continue; auto otherGovIndex = config.getConst(Config::headColName, index, 0); - auto otherGovIndexPredicted = config.getAsFeature(Config::headColName, index); - if (otherGovIndex == std::to_string(withIndex) and util::isEmpty(otherGovIndexPredicted)) + if (otherGovIndex == std::to_string(withIndex)) ++nbLinkedWith; } diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index ff4f1a3088577afbd8c9bc1139a3ac0e1ca79495..1d363c5c66bdee84237887c2c4f255142790eda0 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -74,7 +74,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p Transition * transition = nullptr; Transition * goldTransition = nullptr; - goldTransition = machine.getTransitionSet().getBestAppliableTransition(config, appliableTransitions, dynamicOracle); + goldTransition = machine.getTransitionSet().getBestAppliableTransition(config, appliableTransitions, true or dynamicOracle); if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter") {