diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index 0bc46f525392d9854b4b24bef524b97f549bf197..ea092f08103127795c0453dd5f5811b709052c3a 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -22,7 +22,7 @@ class Config static constexpr const char * deprelColName = "DEPREL"; static constexpr const char * idColName = "ID"; static constexpr int nbHypothesesMax = 1; - static constexpr int maxNbAppliableSplitTransitions = 3; + static constexpr int maxNbAppliableSplitTransitions = 8; public : diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index d45a7945c494c235ff9d2d550560bdc3eb60148e..3d27545c00bff1741fe4021a300da2623341bbf4 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -203,6 +203,12 @@ void Transition::initSplitWord(std::vector<std::string> words) cost = [words](const Config & config) { + if (!config.isMultiword(config.getWordIndex())) + return std::numeric_limits<int>::max(); + + if (config.getMultiwordSize(config.getWordIndex())+2 != (int)words.size()) + return std::numeric_limits<int>::max(); + int cost = 0; for (unsigned int i = 0; i < words.size(); i++) if (!config.has("FORM", config.getWordIndex()+i, 0) or config.getConst("FORM", config.getWordIndex()+i, 0) != words[i]) diff --git a/reading_machine/src/TransitionSet.cpp b/reading_machine/src/TransitionSet.cpp index 8fbb6a5f5003d980deb99e597b71c2d28618766f..5d0df94fb8748855ad426b36a9a4b8002f9246b4 100644 --- a/reading_machine/src/TransitionSet.cpp +++ b/reading_machine/src/TransitionSet.cpp @@ -46,10 +46,13 @@ std::vector<Transition *> TransitionSet::getNAppliableTransitions(const Config & { std::vector<Transition *> result; - for (unsigned int i = 0; i < transitions.size() && result.size() < n; i++) + for (unsigned int i = 0; i < transitions.size(); i++) if (transitions[i].appliable(c)) result.emplace_back(&transitions[i]); + if ((int)result.size() > n) + util::myThrow(fmt::format("there are {} appliable transitions n = {}\n", result.size(), n)); + return result; } diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 0e61711696916342b2224eeca9bac5db58d6b298..57e7e51d014c75836118fad853924370900e96e6 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -52,6 +52,24 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch: util::myThrow("No transition appliable !"); } + if (config.isMultiword(config.getWordIndex())) + if (transition->getName() == "ADDCHARTOWORD") + { + config.printForDebug(stderr); + + auto & splitTrans = config.getAppliableSplitTransitions(); + fmt::print(stderr, "splitTrans.size() = {}\n", splitTrans.size()); + for (auto & trans : splitTrans) + fmt::print(stderr, "cost {} : '{}'\n", trans->getCost(config), trans->getName()); + util::myThrow(fmt::format("Transition should have been a split")); + } + if (transition->getName() == "ENDWORD") + if (config.getAsFeature("FORM",config.getWordIndex()) != config.getConst("FORM",config.getWordIndex(),0)) + { + config.printForDebug(stderr); + util::myThrow(fmt::format("Words don't match")); + } + std::vector<std::vector<long>> context; try