From 03221c3fccbd4f68ce6a4ae0b441f137b87f97e5 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Fri, 3 Apr 2020 18:59:03 +0200 Subject: [PATCH] Corrected a bug where splitword had 0 cost even when it didn't match the size of the gold multiword --- reading_machine/include/Config.hpp | 2 +- reading_machine/src/Transition.cpp | 6 ++++++ reading_machine/src/TransitionSet.cpp | 5 ++++- trainer/src/Trainer.cpp | 18 ++++++++++++++++++ 4 files changed, 29 insertions(+), 2 deletions(-) diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index 0bc46f5..ea092f0 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -22,7 +22,7 @@ class Config static constexpr const char * deprelColName = "DEPREL"; static constexpr const char * idColName = "ID"; static constexpr int nbHypothesesMax = 1; - static constexpr int maxNbAppliableSplitTransitions = 3; + static constexpr int maxNbAppliableSplitTransitions = 8; public : diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index d45a794..3d27545 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -203,6 +203,12 @@ void Transition::initSplitWord(std::vector<std::string> words) cost = [words](const Config & config) { + if (!config.isMultiword(config.getWordIndex())) + return std::numeric_limits<int>::max(); + + if (config.getMultiwordSize(config.getWordIndex())+2 != (int)words.size()) + return std::numeric_limits<int>::max(); + int cost = 0; for (unsigned int i = 0; i < words.size(); i++) if (!config.has("FORM", config.getWordIndex()+i, 0) or config.getConst("FORM", config.getWordIndex()+i, 0) != words[i]) diff --git a/reading_machine/src/TransitionSet.cpp b/reading_machine/src/TransitionSet.cpp index 8fbb6a5..5d0df94 100644 --- a/reading_machine/src/TransitionSet.cpp +++ b/reading_machine/src/TransitionSet.cpp @@ -46,10 +46,13 @@ std::vector<Transition *> TransitionSet::getNAppliableTransitions(const Config & { std::vector<Transition *> result; - for (unsigned int i = 0; i < transitions.size() && result.size() < n; i++) + for (unsigned int i = 0; i < transitions.size(); i++) if (transitions[i].appliable(c)) result.emplace_back(&transitions[i]); + if ((int)result.size() > n) + util::myThrow(fmt::format("there are {} appliable transitions n = {}\n", result.size(), n)); + return result; } diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 0e61711..57e7e51 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -52,6 +52,24 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch: util::myThrow("No transition appliable !"); } + if (config.isMultiword(config.getWordIndex())) + if (transition->getName() == "ADDCHARTOWORD") + { + config.printForDebug(stderr); + + auto & splitTrans = config.getAppliableSplitTransitions(); + fmt::print(stderr, "splitTrans.size() = {}\n", splitTrans.size()); + for (auto & trans : splitTrans) + fmt::print(stderr, "cost {} : '{}'\n", trans->getCost(config), trans->getName()); + util::myThrow(fmt::format("Transition should have been a split")); + } + if (transition->getName() == "ENDWORD") + if (config.getAsFeature("FORM",config.getWordIndex()) != config.getConst("FORM",config.getWordIndex(),0)) + { + config.printForDebug(stderr); + util::myThrow(fmt::format("Words don't match")); + } + std::vector<std::vector<long>> context; try -- GitLab