From d377af892116869d0bc81d7d4127c09385719212 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sun, 3 May 2020 16:10:22 +0200 Subject: [PATCH] modernized EOS --- reading_machine/include/Action.hpp | 17 +-- reading_machine/include/Config.hpp | 14 ++- reading_machine/include/Strategy.hpp | 2 +- reading_machine/include/Transition.hpp | 2 +- reading_machine/src/Action.cpp | 164 ++++++++----------------- reading_machine/src/BaseConfig.cpp | 5 + reading_machine/src/Config.cpp | 27 ++++ reading_machine/src/ReadingMachine.cpp | 2 +- reading_machine/src/Strategy.cpp | 3 +- reading_machine/src/Transition.cpp | 67 +++------- trainer/include/Trainer.hpp | 4 +- trainer/src/MacaonTrain.cpp | 2 +- trainer/src/Trainer.cpp | 11 +- 13 files changed, 136 insertions(+), 184 deletions(-) diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp index 5561a81..2dc3191 100644 --- a/reading_machine/include/Action.hpp +++ b/reading_machine/include/Action.hpp @@ -21,12 +21,6 @@ class Action Check }; - enum Object - { - Buffer, - Stack - }; - private : Type type; @@ -44,22 +38,21 @@ class Action public : - static Object str2object(const std::string & s); static Action addLinesIfNeeded(int nbLines); static Action moveWordIndex(int movement); static Action moveCharacterIndex(int movement); static Action addHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & hypothesis); static Action addToHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & addition); - static Action addHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & hypothesis); - static Action addToHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & addition); + static Action addHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis); + static Action addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition); static Action pushWordIndexOnStack(); static Action popStack(); static Action emptyStack(); - static Action setRoot(); - static Action updateIds(); + static Action setRoot(int bufferIndex); + static Action updateIds(int bufferIndex); static Action endWord(); static Action assertIsEmpty(const std::string & colName); - static Action attach(Object governorObject, int governorIndex, Object dependentObject, int dependentIndex); + static Action attach(Config::Object governorObject, int governorIndex, Config::Object dependentObject, int dependentIndex); static Action addCurCharToCurWord(); static Action ignoreCurrentCharacter(); static Action consumeCharacterIndex(util::utf8string consumed); diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index 3ae84fc..f73bba4 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -21,11 +21,18 @@ class Config static constexpr const char * headColName = "HEAD"; static constexpr const char * deprelColName = "DEPREL"; static constexpr const char * idColName = "ID"; + static constexpr const char * sentIdColName = "SENTID"; static constexpr const char * isMultiColName = "MULTI"; static constexpr const char * childsColName = "CHILDS"; static constexpr int nbHypothesesMax = 1; static constexpr int maxNbAppliableSplitTransitions = 8; + enum Object + { + Buffer, + Stack + }; + public : using String = boost::flyweight<std::string>; @@ -56,6 +63,8 @@ class Config public : + static Object str2object(const std::string & s); + virtual std::size_t getNbColumns() const = 0; virtual std::size_t getColIndex(const std::string & colName) const = 0; virtual bool hasColIndex(const std::string & colName) const = 0; @@ -78,6 +87,8 @@ class Config ValueIterator getIterator(int colIndex, int lineIndex, int hypothesisIndex); ConstValueIterator getConstIterator(int colIndex, int lineIndex, int hypothesisIndex) const; + long getRelativeWordIndex(int relativeIndex) const; + public : virtual ~Config() {} @@ -116,7 +127,8 @@ class Config bool rawInputOnlySeparatorsLeft() const; std::size_t getWordIndex() const; std::size_t getCharacterIndex() const; - long getRelativeWordIndex(int relativeIndex) const; + long getRelativeWordIndex(Object object, int relativeIndex) const; + bool hasRelativeWordIndex(Object object, int relativeIndex) const; const String & getHistory(int relativeIndex) const; std::size_t getStack(int relativeIndex) const; bool hasHistory(int relativeIndex) const; diff --git a/reading_machine/include/Strategy.hpp b/reading_machine/include/Strategy.hpp index 9c4edd5..706d0fa 100644 --- a/reading_machine/include/Strategy.hpp +++ b/reading_machine/include/Strategy.hpp @@ -31,7 +31,7 @@ class Strategy public : - Strategy(const std::vector<std::string_view> & lines); + Strategy(std::vector<std::string> lines); std::pair<std::string, int> getMovement(const Config & c, const std::string & transition); const std::string getInitialState() const; void reset(); diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp index 87531db..487f34e 100644 --- a/reading_machine/include/Transition.hpp +++ b/reading_machine/include/Transition.hpp @@ -22,7 +22,7 @@ class Transition void initLeft(std::string label); void initRight(std::string label); void initReduce(); - void initEOS(); + void initEOS(int bufferIndex); void initNothing(); void initIgnoreChar(); void initEndWord(); diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index 324acb3..9587154 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -53,11 +53,11 @@ Action Action::setMultiwordIds(int multiwordSize) { auto apply = [multiwordSize](Config & config, Action & a) { - addHypothesisRelative(Config::idColName, Object::Buffer, 0, fmt::format("{}-{}", config.getCurrentWordId()+1, config.getCurrentWordId()+multiwordSize)).apply(config, a); + addHypothesisRelative(Config::idColName, Config::Object::Buffer, 0, fmt::format("{}-{}", config.getCurrentWordId()+1, config.getCurrentWordId()+multiwordSize)).apply(config, a); for (int i = 0; i < multiwordSize; i++) { - addHypothesisRelative(Config::idColName, Object::Buffer, i+1, fmt::format("{}", config.getCurrentWordId()+1+i)).apply(config, a); - addHypothesisRelative(Config::isMultiColName, Object::Buffer, i+1, Config::EOSSymbol1).apply(config, a); + addHypothesisRelative(Config::idColName, Config::Object::Buffer, i+1, fmt::format("{}", config.getCurrentWordId()+1+i)).apply(config, a); + addHypothesisRelative(Config::isMultiColName, Config::Object::Buffer, i+1, Config::EOSSymbol1).apply(config, a); } }; @@ -176,80 +176,58 @@ Action Action::addToHypothesis(const std::string & colName, std::size_t lineInde return {Type::Write, apply, undo, appliable}; } -Action Action::addToHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & addition) +Action Action::addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition) { auto apply = [colName, object, relativeIndex, addition](Config & config, Action & a) { - int lineIndex = 0; - if (object == Object::Buffer) - lineIndex = config.getWordIndex() + relativeIndex; - else - lineIndex = config.getStack(relativeIndex); + int lineIndex = config.getRelativeWordIndex(object, relativeIndex); return addToHypothesis(colName, lineIndex, addition).apply(config, a); }; auto undo = [colName, object, relativeIndex](Config & config, Action & a) { - int lineIndex = 0; - if (object == Object::Buffer) - lineIndex = config.getWordIndex() + relativeIndex; - else - lineIndex = config.getStack(relativeIndex); + int lineIndex = config.getRelativeWordIndex(object, relativeIndex); return addToHypothesis(colName, lineIndex, "").undo(config, a); }; auto appliable = [colName, object, relativeIndex, addition](const Config & config, const Action & a) { - int lineIndex = 0; - if (object == Object::Buffer) - lineIndex = config.getWordIndex() + relativeIndex; - else if (config.hasStack(relativeIndex)) - lineIndex = config.getStack(relativeIndex); - else + if (!config.hasRelativeWordIndex(object, relativeIndex)) return false; + int lineIndex = config.getRelativeWordIndex(object, relativeIndex); + return addToHypothesis(colName, lineIndex, addition).appliable(config, a); }; return {Type::Write, apply, undo, appliable}; } -Action Action::addHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & hypothesis) +Action Action::addHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis) { auto apply = [colName, object, relativeIndex, hypothesis](Config & config, Action & a) { - int lineIndex = 0; - if (object == Object::Buffer) - lineIndex = config.getWordIndex() + relativeIndex; - else - lineIndex = config.getStack(relativeIndex); + int lineIndex = config.getRelativeWordIndex(object, relativeIndex); return addHypothesis(colName, lineIndex, hypothesis).apply(config, a); }; auto undo = [colName, object, relativeIndex](Config & config, Action & a) { - int lineIndex = 0; - if (object == Object::Buffer) - lineIndex = config.getWordIndex() + relativeIndex; - else - lineIndex = config.getStack(relativeIndex); + int lineIndex = config.getRelativeWordIndex(object, relativeIndex); return addHypothesis(colName, lineIndex, "").undo(config, a); }; auto appliable = [colName, object, relativeIndex](const Config & config, const Action & a) { - int lineIndex = 0; - if (object == Object::Buffer) - lineIndex = config.getWordIndex() + relativeIndex; - else if (config.hasStack(relativeIndex)) - lineIndex = config.getStack(relativeIndex); - else + if (!config.hasRelativeWordIndex(object, relativeIndex)) return false; + int lineIndex = config.getRelativeWordIndex(object, relativeIndex); + return addHypothesis(colName, lineIndex, "").appliable(config, a); }; @@ -317,7 +295,7 @@ Action Action::endWord() auto apply = [](Config & config, Action & a) { config.setCurrentWordId(config.getCurrentWordId()+1); - addHypothesisRelative(Config::idColName, Object::Buffer, 0, std::to_string(config.getCurrentWordId())).apply(config, a); + addHypothesisRelative(Config::idColName, Config::Object::Buffer, 0, std::to_string(config.getCurrentWordId())).apply(config, a); if (!config.rawInputOnlySeparatorsLeft() and !config.has(0,config.getWordIndex()+1,0)) config.addLines(1); @@ -442,14 +420,14 @@ Action Action::addCurCharToCurWord() return {Type::Write, apply, undo, appliable}; } -Action Action::setRoot() +Action Action::setRoot(int bufferIndex) { - auto apply = [](Config & config, Action & a) + auto apply = [bufferIndex](Config & config, Action & a) { + int lineIndex = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex); int rootIndex = -1; - int firstSentenceIndex = -1; - for (int i = config.getStack(0); true; --i) + for (int i = lineIndex; true; --i) { if (!config.has(0, i, 0)) { @@ -460,19 +438,17 @@ Action Action::setRoot() if (!config.isTokenPredicted(i)) continue; - if (config.getLastNotEmptyHypConst(Config::EOSColName, i) == Config::EOSSymbol1) + if (config.getAsFeature(Config::EOSColName, i) == Config::EOSSymbol1) break; - firstSentenceIndex = i; - - if (util::isEmpty(config.getLastNotEmptyHypConst(Config::headColName, i))) + if (util::isEmpty(config.getAsFeature(Config::headColName, i))) { rootIndex = i; a.data.push_back(std::to_string(i)); } } - for (int i = config.getStack(0); true; --i) + for (int i = lineIndex; true; --i) { if (!config.has(0, i, 0)) { @@ -483,10 +459,10 @@ Action Action::setRoot() if (!config.isTokenPredicted(i)) continue; - if (config.getLastNotEmptyHypConst(Config::EOSColName, i) == Config::EOSSymbol1) + if (config.getAsFeature(Config::EOSColName, i) == Config::EOSSymbol1) break; - if (util::isEmpty(config.getLastNotEmptyHypConst(Config::headColName, i))) + if (util::isEmpty(config.getAsFeature(Config::headColName, i))) { if (i == rootIndex) { @@ -498,11 +474,6 @@ Action Action::setRoot() config.getFirstEmpty(Config::headColName, i) = std::to_string(rootIndex); } } - else - { - if (std::stoi(config.getLastNotEmptyHypConst(Config::headColName, i)) < firstSentenceIndex) - config.getFirstEmpty(Config::headColName, i) = std::to_string(rootIndex); - } } }; @@ -516,20 +487,23 @@ Action Action::setRoot() } }; - auto appliable = [](const Config & config, const Action &) + auto appliable = [bufferIndex](const Config & config, const Action &) { - return config.hasStack(0) and config.isTokenPredicted(config.getStack(0)) and config.getLastNotEmptyConst(Config::isMultiColName, config.getStack(0)) != Config::EOSSymbol1; + int lineIndex = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex); + return config.has(0,lineIndex,0) and config.isTokenPredicted(lineIndex) and config.getAsFeature(Config::isMultiColName, lineIndex) != Config::EOSSymbol1; }; return {Type::Write, apply, undo, appliable}; } -Action Action::updateIds() +Action Action::updateIds(int bufferIndex) { - auto apply = [](Config & config, Action & a) + auto apply = [bufferIndex](Config & config, Action & a) { + int lineIndex = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex); int firstIndexOfSentence = -1; - for (int i = config.getStack(0); true; --i) + int lastSentId = -1; + for (int i = lineIndex; true; --i) { if (!config.has(0, i, 0)) { @@ -541,7 +515,10 @@ Action Action::updateIds() continue; if (config.getLastNotEmptyHypConst(Config::EOSColName, i) == Config::EOSSymbol1) + { + lastSentId = std::stoi(config.getAsFeature(Config::sentIdColName, i)); break; + } firstIndexOfSentence = i; } @@ -549,18 +526,17 @@ Action Action::updateIds() if (firstIndexOfSentence < 0) util::myThrow("could not find any token in current sentence"); - for (unsigned int i = firstIndexOfSentence, currentId = 1; i <= config.getStack(0); ++i) + for (int i = firstIndexOfSentence, currentId = 1; i <= lineIndex; ++i) { if (config.isComment(i) || config.isEmptyNode(i)) continue; - if (config.getLastNotEmptyHypConst(Config::EOSColName, i) == Config::EOSSymbol1) - break; - if (config.isMultiwordPredicted(i)) config.getFirstEmpty(Config::idColName, i) = fmt::format("{}-{}", currentId, currentId+config.getMultiwordSizePredicted(i)); else config.getFirstEmpty(Config::idColName, i) = fmt::format("{}", currentId++); + + config.getFirstEmpty(Config::sentIdColName, i) = fmt::format("{}", lastSentId+1); } }; @@ -577,20 +553,12 @@ Action Action::updateIds() return {Type::Write, apply, undo, appliable}; } -Action Action::attach(Object governorObject, int governorIndex, Object dependentObject, int dependentIndex) +Action Action::attach(Config::Object governorObject, int governorIndex, Config::Object dependentObject, int dependentIndex) { auto apply = [governorObject, governorIndex, dependentObject, dependentIndex](Config & config, Action & a) { - int lineIndex = 0; - if (governorObject == Object::Buffer) - lineIndex = config.getWordIndex() + governorIndex; - else - lineIndex = config.getStack(governorIndex); - int depIndex = 0; - if (dependentObject == Object::Buffer) - depIndex = config.getWordIndex() + dependentIndex; - else - depIndex = config.getStack(dependentIndex); + long lineIndex = config.getRelativeWordIndex(governorObject, governorIndex); + long depIndex = config.getRelativeWordIndex(dependentObject, dependentIndex); addHypothesisRelative(Config::headColName, dependentObject, dependentIndex, std::to_string(lineIndex)).apply(config, a); addToHypothesisRelative(Config::childsColName, governorObject, governorIndex, std::to_string(depIndex)).apply(config, a); @@ -604,43 +572,24 @@ Action Action::attach(Object governorObject, int governorIndex, Object dependent auto appliable = [governorObject, governorIndex, dependentObject, dependentIndex](const Config & config, const Action & action) { - int govLineIndex = 0; - if (governorObject == Object::Buffer) - { - govLineIndex = config.getWordIndex() + governorIndex; - if (!config.has(0, govLineIndex, 0)) - return false; - } - else - { - if (!config.hasStack(governorIndex)) - return false; - govLineIndex = config.getStack(governorIndex); - } - - int depLineIndex = 0; - if (dependentObject == Object::Buffer) - { - depLineIndex = config.getWordIndex() + dependentIndex; - if (!config.has(0, depLineIndex, 0)) - return false; - } - else - { - if (!config.hasStack(dependentIndex)) - return false; - depLineIndex = config.getStack(dependentIndex); - } + if (!config.hasRelativeWordIndex(governorObject, governorIndex) or !config.hasRelativeWordIndex(dependentObject, dependentIndex)) + return false; + long govLineIndex = config.getRelativeWordIndex(governorObject, governorIndex); + long depLineIndex = config.getRelativeWordIndex(dependentObject, dependentIndex); if (!config.isTokenPredicted(govLineIndex) or !config.isTokenPredicted(depLineIndex)) return false; + // Check if dep and head belongs to the same sentence + if (config.getAsFeature(Config::sentIdColName, govLineIndex) != config.getAsFeature(Config::sentIdColName, depLineIndex)) + return false; + // Check for cycles while (govLineIndex != depLineIndex) { try { - govLineIndex = std::stoi(config.getLastNotEmptyHypConst(Config::headColName, govLineIndex)); + govLineIndex = std::stoi(config.getAsFeature(Config::headColName, govLineIndex)); } catch(std::exception &) {return true;} } @@ -677,14 +626,3 @@ Action Action::split(int index) return {Type::Write, apply, undo, appliable}; } -Action::Object Action::str2object(const std::string & s) -{ - if (s == "b") - return Object::Buffer; - if (s == "s") - return Object::Stack; - - util::myThrow(fmt::format("Invalid object '{}'", s)); - return Object::Buffer; -} - diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp index 8ca5ab3..6632c75 100644 --- a/reading_machine/src/BaseConfig.cpp +++ b/reading_machine/src/BaseConfig.cpp @@ -38,6 +38,11 @@ void BaseConfig::readMCD(std::string_view mcdFilename) colIndex2Name.emplace_back(childsColName); colName2Index.emplace(childsColName, colIndex2Name.size()-1); + if (colName2Index.count(sentIdColName)) + util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcdFilename, sentIdColName)); + colIndex2Name.emplace_back(sentIdColName); + colName2Index.emplace(sentIdColName, colIndex2Name.size()-1); + if (colName2Index.count(EOSColName)) util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcdFilename, EOSColName)); colIndex2Name.emplace_back(EOSColName); diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 1ecf072..9526a87 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -637,6 +637,22 @@ long Config::getRelativeWordIndex(int relativeIndex) const return -1; } +long Config::getRelativeWordIndex(Object object, int relativeIndex) const +{ + if (object == Object::Buffer) + return getRelativeWordIndex(relativeIndex); + + return getStack(relativeIndex); +} + +bool Config::hasRelativeWordIndex(Object object, int relativeIndex) const +{ + if (object == Object::Buffer) + return has(0,getRelativeWordIndex(relativeIndex),0); + + return hasStack(relativeIndex); +} + void Config::setAppliableSplitTransitions(const std::vector<Transition *> & appliableSplitTransitions) { this->appliableSplitTransitions = appliableSplitTransitions; @@ -647,3 +663,14 @@ const std::vector<Transition *> & Config::getAppliableSplitTransitions() const return appliableSplitTransitions; } +Config::Object Config::str2object(const std::string & s) +{ + if (s == "b") + return Object::Buffer; + if (s == "s") + return Object::Stack; + + util::myThrow(fmt::format("Invalid object '{}'", s)); + return Object::Buffer; +} + diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 2900e4d..e172acb 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -104,7 +104,7 @@ void ReadingMachine::readFromFile(std::filesystem::path path) })) util::myThrow("No predictions specified"); - auto restOfFile = std::vector<std::string_view>(lines.begin()+curLine, lines.end()); + auto restOfFile = std::vector<std::string>(lines.begin()+curLine, lines.end()); strategy.reset(new Strategy(restOfFile)); diff --git a/reading_machine/src/Strategy.cpp b/reading_machine/src/Strategy.cpp index b540975..7a82a8b 100644 --- a/reading_machine/src/Strategy.cpp +++ b/reading_machine/src/Strategy.cpp @@ -1,6 +1,6 @@ #include "Strategy.hpp" -Strategy::Strategy(const std::vector<std::string_view> & lines) +Strategy::Strategy(std::vector<std::string> lines) { if (!util::doIfNameMatch(std::regex("Strategy : ((incremental)|(sequential))"), lines[0], [this](auto sm) {type = sm[1] == "sequential" ? Type::Sequential : Type::Incremental;})) @@ -8,6 +8,7 @@ Strategy::Strategy(const std::vector<std::string_view> & lines) for (unsigned int i = 1; i < lines.size(); i++) { + std::replace(lines[i].begin(), lines[i].end(), '\t', ' '); auto splited = util::split(lines[i], ' '); std::pair<std::string, std::string> key; std::string value; diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index b0fe02c..6778bd1 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -17,8 +17,8 @@ Transition::Transition(const std::string & name) [this](auto sm){(initLeft(sm[1]));}}, {std::regex("RIGHT (.+)"), [this](auto sm){(initRight(sm[1]));}}, - {std::regex("EOS"), - [this](auto){initEOS();}}, + {std::regex("EOS b\\.(.+)"), + [this](auto sm){initEOS(std::stoi(sm[1]));}}, {std::regex("NOTHING"), [this](auto){initNothing();}}, {std::regex("IGNORECHAR"), @@ -89,18 +89,14 @@ const std::string & Transition::getName() const void Transition::initWrite(std::string colName, std::string object, std::string index, std::string value) { - auto objectValue = Action::str2object(object); + auto objectValue = Config::str2object(object); int indexValue = std::stoi(index); sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value)); cost = [colName, objectValue, indexValue, value](const Config & config) { - int lineIndex = 0; - if (objectValue == Action::Object::Buffer) - lineIndex = config.getWordIndex() + indexValue; - else - lineIndex = config.getStack(indexValue); + int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); if (config.getConst(colName, lineIndex, 0) == value) return 0; @@ -111,18 +107,14 @@ void Transition::initWrite(std::string colName, std::string object, std::string void Transition::initAdd(std::string colName, std::string object, std::string index, std::string value) { - auto objectValue = Action::str2object(object); + auto objectValue = Config::str2object(object); int indexValue = std::stoi(index); sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value)); cost = [colName, objectValue, indexValue, value](const Config & config) { - int lineIndex = 0; - if (objectValue == Action::Object::Buffer) - lineIndex = config.getWordIndex() + indexValue; - else - lineIndex = config.getStack(indexValue); + int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); auto gold = util::split(config.getConst(colName, lineIndex, 0).get(), '|'); @@ -198,7 +190,7 @@ void Transition::initSplitWord(std::vector<std::string> words) sequence.emplace_back(Action::addLinesIfNeeded(words.size())); sequence.emplace_back(Action::consumeCharacterIndex(consumedWord)); for (unsigned int i = 0; i < words.size(); i++) - sequence.emplace_back(Action::addHypothesisRelative("FORM", Action::Object::Buffer, i, words[i])); + sequence.emplace_back(Action::addHypothesisRelative("FORM", Config::Object::Buffer, i, words[i])); sequence.emplace_back(Action::setMultiwordIds(words.size()-1)); cost = [words](const Config & config) @@ -266,8 +258,8 @@ void Transition::initShift() void Transition::initLeft(std::string label) { - sequence.emplace_back(Action::attach(Action::Object::Buffer, 0, Action::Object::Stack, 0)); - sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Action::Object::Stack, 0, label)); + sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0)); + sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label)); sequence.emplace_back(Action::popStack()); cost = [label](const Config & config) @@ -308,8 +300,8 @@ void Transition::initLeft(std::string label) void Transition::initRight(std::string label) { - sequence.emplace_back(Action::attach(Action::Object::Stack, 0, Action::Object::Buffer, 0)); - sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Action::Object::Buffer, 0, label)); + sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0)); + sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label)); sequence.emplace_back(Action::pushWordIndexOnStack()); cost = [label](const Config & config) @@ -395,40 +387,19 @@ void Transition::initReduce() }; } -void Transition::initEOS() +void Transition::initEOS(int bufferIndex) { - sequence.emplace_back(Action::setRoot()); - sequence.emplace_back(Action::updateIds()); - sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Action::Object::Stack, 0, Config::EOSSymbol1)); - sequence.emplace_back(Action::emptyStack()); + sequence.emplace_back(Action::setRoot(bufferIndex)); + sequence.emplace_back(Action::updateIds(bufferIndex)); + sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Config::Object::Buffer, bufferIndex, Config::EOSSymbol1)); - cost = [](const Config & config) + cost = [bufferIndex](const Config & config) { - if (!config.has(0, config.getStack(0), 0)) - return std::numeric_limits<int>::max(); - - if (!config.isToken(config.getStack(0))) - return std::numeric_limits<int>::max(); - - if (config.getConst(Config::EOSColName, config.getStack(0), 0) != Config::EOSSymbol1) + int lineIndex = config.getRelativeWordIndex(Config::Buffer, bufferIndex); + if (config.getConst(Config::EOSColName, lineIndex, 0) != Config::EOSSymbol1) return std::numeric_limits<int>::max(); - int cost = 0; - - --cost; - for (int i = 0; config.hasStack(i); ++i) - { - if (!config.has(0, config.getStack(i), 0)) - continue; - - auto otherStackIndex = config.getStack(i); - auto otherStackGovPred = config.getAsFeature(Config::headColName, otherStackIndex); - - if (util::isEmpty(otherStackGovPred)) - ++cost; - } - - return cost; + return 0; }; } diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index 9c08f68..29485d2 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -43,14 +43,14 @@ class Trainer void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval); float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples); - void fillDicts(SubConfig & config); + void fillDicts(SubConfig & config, bool debug); public : Trainer(ReadingMachine & machine, int batchSize); void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval); void createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval); - void fillDicts(BaseConfig & goldConfig); + void fillDicts(BaseConfig & goldConfig, bool debug); float epoch(bool printAdvancement); float evalOnDev(bool printAdvancement); }; diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 7d8a52b..f9ded4f 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -114,7 +114,7 @@ int MacaonTrain::main() if (machine.dictsAreNew()) { - trainer.fillDicts(goldConfig); + trainer.fillDicts(goldConfig, debug); for (auto & it : machine.getDicts()) { std::size_t originalSize = it.second.size(); diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index b928da8..f257f05 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -270,7 +270,7 @@ void Trainer::Examples::addClass(int goldIndex) classes.emplace_back(gold); } -void Trainer::fillDicts(BaseConfig & goldConfig) +void Trainer::fillDicts(BaseConfig & goldConfig, bool debug) { SubConfig config(goldConfig, goldConfig.getNbLines()); @@ -280,13 +280,13 @@ void Trainer::fillDicts(BaseConfig & goldConfig) machine.trainMode(false); machine.setDictsState(Dict::State::Open); - fillDicts(config); + fillDicts(config, debug); for (auto & it : machine.getDicts()) it.second.countOcc(false); } -void Trainer::fillDicts(SubConfig & config) +void Trainer::fillDicts(SubConfig & config, bool debug) { torch::AutoGradMode useGrad(false); @@ -297,6 +297,9 @@ void Trainer::fillDicts(SubConfig & config) while (true) { + if (debug) + config.printForDebug(stderr); + if (machine.hasSplitWordTransitionSet()) config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); @@ -321,6 +324,8 @@ void Trainer::fillDicts(SubConfig & config) config.addToHistory(goldTransition->getName()); auto movement = machine.getStrategy().getMovement(config, goldTransition->getName()); + if (debug) + fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", goldTransition->getName(), movement.first, movement.second); if (movement == Strategy::endMovement) break; -- GitLab