From 7c5c406ede376a1991472c8f4fae7d0e3dc978cb Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sat, 7 Mar 2020 18:19:04 +0100
Subject: [PATCH] Added transitions for tokenizer

---
Review notes (ignored by git am; not part of the commit message):
 * This patch had been collapsed onto a few lines. Line structure was
   rebuilt from the hunk headers; leading indentation was reconstructed
   assuming two-space indents. If context lines do not match the tree
   exactly, apply with whitespace-insensitive context matching.
 * Action::consumeCharacterIndex: undo negated an unsigned size_t
   (-consumed.size()); it now negates a signed value. The action is also
   tagged Type::MoveChar instead of Type::MoveWord, since it only moves the
   character index (same tag as ignoreCurrentCharacter).
 * Transition: the SPLITWORD regex used "(:?" where a non-capturing group
   "(?:" was intended. Capture groups 1 and 2 are unchanged.
 * Dropped names of unused Action parameters in the new lambdas.
 * Modified '+' lines no longer correspond to the blob ids on the index lines.

 reading_machine/include/Action.hpp     |   5 +
 reading_machine/include/Config.hpp     |   3 +
 reading_machine/include/Transition.hpp |   4 +
 reading_machine/src/Action.cpp         | 121 ++++++++++++++++++++++-
 reading_machine/src/BaseConfig.cpp     |   3 +
 reading_machine/src/Config.cpp         |  14 ++-
 reading_machine/src/Transition.cpp     | 131 ++++++++++++++++++++-----
 trainer/src/macaon_train.cpp           |   2 +-
 8 files changed, 253 insertions(+), 30 deletions(-)

diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp
index 994fe9c..7900c7b 100644
--- a/reading_machine/include/Action.hpp
+++ b/reading_machine/include/Action.hpp
@@ -57,7 +57,12 @@ class Action
   static Action emptyStack();
   static Action setRoot();
   static Action updateIds();
+  static Action endWord();
   static Action attach(Object governorObject, int governorIndex, Object dependentObject, int dependentIndex);
+  static Action addCurCharToCurWord();
+  static Action ignoreCurrentCharacter();
+  static Action consumeCharacterIndex(std::string consumed);
+  static Action setMultiwordId(int multiwordSize);
 };
 
 #endif
diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp
index 833a2f9..de65870 100644
--- a/reading_machine/include/Config.hpp
+++ b/reading_machine/include/Config.hpp
@@ -33,6 +33,7 @@ class Config
   std::vector<String> lines;
   std::set<std::string> predicted;
   int lastPoppedStack{-1};
+  int currentWordId{0};
 
   protected :
 
@@ -113,6 +114,8 @@ class Config
   void addPredicted(const std::set<std::string> & predicted);
   bool isPredicted(const std::string & colName) const;
   int getLastPoppedStack() const;
+  int getCurrentWordId() const;
+  void setCurrentWordId(int currentWordId);
 };
 
 #endif
diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp
index c7309a6..0c1ccd7 100644
--- a/reading_machine/include/Transition.hpp
+++ b/reading_machine/include/Transition.hpp
@@ -24,6 +24,10 @@ class Transition
   void initReduce();
   void initEOS();
   void initNothing();
+  void initIgnoreChar();
+  void initEndWord();
+  void initAddCharToWord();
+  void initSplitWord(std::vector<std::string> words);
 
   public :
 
diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp
index 997f251..433adad 100644
--- a/reading_machine/src/Action.cpp
+++ b/reading_machine/src/Action.cpp
@@ -12,7 +12,8 @@ Action Action::addLinesIfNeeded(int nbLines)
 {
   auto apply = [nbLines](Config & config, Action &)
   {
-    config.addLines(1);
+    while (!config.has(0, config.getWordIndex()+nbLines, 0))
+      config.addLines(1);
   };
 
   auto undo = [](Config &, Action &)
@@ -47,6 +48,53 @@ Action Action::moveWordIndex(int movement)
   return {Type::MoveWord, apply, undo, appliable};
 }
 
+Action Action::setMultiwordId(int multiwordSize)
+{
+  auto apply = [multiwordSize](Config & config, Action & a)
+  {
+    addHypothesisRelative(Config::idColName, Object::Buffer, 0, fmt::format("{}-{}", config.getCurrentWordId()+1, config.getCurrentWordId()+multiwordSize)).apply(config, a);
+  };
+
+  auto undo = [](Config & config, Action &)
+  {
+    config.getLastNotEmpty(Config::idColName, config.getWordIndex()) = "";
+  };
+
+  auto appliable = [](const Config &, const Action &)
+  {
+    return true;
+  };
+
+  return {Type::Write, apply, undo, appliable};
+}
+
+Action Action::consumeCharacterIndex(std::string consumed)
+{
+  auto apply = [consumed](Config & config, Action &)
+  {
+    config.moveCharacterIndex(consumed.size());
+  };
+
+  auto undo = [consumed](Config & config, Action &)
+  {
+    config.moveCharacterIndex(-static_cast<int>(consumed.size()));
+  };
+
+  auto appliable = [consumed](const Config & config, const Action &)
+  {
+    if (!config.canMoveCharacterIndex(consumed.size()))
+      return false;
+
+    for (unsigned int i = 0; i < consumed.size(); i++)
+      if (!config.hasCharacter(config.getCharacterIndex()+i) or config.getLetter(config.getCharacterIndex()+i) != consumed[i])
+        return false;
+
+    return true;
+  };
+
+  return {Type::MoveChar, apply, undo, appliable};
+}
+
 Action Action::moveCharacterIndex(int movement)
 {
   auto apply = [movement](Config & config, Action &)
@@ -245,6 +293,31 @@ Action Action::popStack()
   return {Type::Pop, apply, undo, appliable};
 }
 
+Action Action::endWord()
+{
+  auto apply = [](Config & config, Action & a)
+  {
+    config.setCurrentWordId(config.getCurrentWordId()+1);
+    addHypothesisRelative(Config::idColName, Object::Buffer, 0, std::to_string(config.getCurrentWordId())).apply(config, a);
+
+    if (!config.rawInputOnlySeparatorsLeft() and !config.has(0,config.getWordIndex()+1,0))
+      config.addLines(1);
+  };
+
+  auto undo = [](Config & config, Action &)
+  {
+    config.setCurrentWordId(config.getCurrentWordId()-1);
+    config.getLastNotEmpty(Config::idColName, config.getWordIndex()) = "";
+  };
+
+  auto appliable = [](const Config & config, const Action &)
+  {
+    return !util::isEmpty(config.getAsFeature("FORM", config.getWordIndex()));
+  };
+
+  return {Type::Write, apply, undo, appliable};
+}
+
 Action Action::emptyStack()
 {
   auto apply = [](Config & config, Action & a)
@@ -273,6 +346,52 @@ Action Action::emptyStack()
   return {Type::Pop, apply, undo, appliable};
 }
 
+Action Action::ignoreCurrentCharacter()
+{
+  auto apply = [](Config & config, Action &)
+  {
+    config.moveCharacterIndex(1);
+  };
+
+  auto undo = [](Config & config, Action &)
+  {
+    config.moveCharacterIndex(-1);
+  };
+
+  auto appliable = [](const Config & config, const Action &)
+  {
+    return config.hasCharacter(config.getCharacterIndex()) and util::isSeparator(config.getLetter(config.getCharacterIndex())) and config.canMoveCharacterIndex(1);
+  };
+
+  return {Type::MoveChar, apply, undo, appliable};
+}
+
+Action Action::addCurCharToCurWord()
+{
+  auto apply = [](Config & config, Action &)
+  {
+    auto & curWord = config.getLastNotEmptyHyp("FORM", config.getWordIndex());
+    curWord = fmt::format("{}{}", curWord, config.getLetter(config.getCharacterIndex()));
+  };
+
+  auto undo = [](Config & config, Action &)
+  {
+    auto & curWord = config.getLastNotEmptyHyp("FORM", config.getWordIndex());
+    std::string newWord = curWord;
+    unsigned int nbToPop = fmt::format("{}", config.getLetter(config.getCharacterIndex())).size();
+    for (unsigned int i = 0; i < nbToPop; i++)
+      newWord.pop_back();
+    curWord = newWord;
+  };
+
+  auto appliable = [](const Config & config, const Action &)
+  {
+    return config.hasCharacter(config.getCharacterIndex()) and !util::isSeparator(config.getLetter(config.getCharacterIndex()));
+  };
+
+  return {Type::Write, apply, undo, appliable};
+}
+
 Action Action::setRoot()
 {
   auto apply = [](Config & config, Action & a)
diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp
index b197250..b067b91 100644
--- a/reading_machine/src/BaseConfig.cpp
+++ b/reading_machine/src/BaseConfig.cpp
@@ -155,6 +155,9 @@ BaseConfig::BaseConfig(std::string_view mcdFilename, std::string_view tsvFilenam
   if (not tsvFilename.empty())
     readTSVInput(tsvFilename);
 
+  if (!has(0,wordIndex,0))
+    addLines(1);
+
   if (isComment(wordIndex))
     moveWordIndex(1);
 }
diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index 5cc4a88..9af5d8c 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -378,7 +378,7 @@ bool Config::moveWordIndex(int relativeMovement)
         return false;
       }
     }
-    while (!isToken(wordIndex));
+    while (isComment(wordIndex));
     nbMovements += relativeMovement > 0 ? 1 : -1;
   }
 
@@ -397,7 +397,7 @@ bool Config::canMoveWordIndex(int relativeMovement) const
       if (!has(0,oldVal,0))
         return false;
     }
-    while (!isToken(oldVal));
+    while (isComment(oldVal));
     nbMovements += relativeMovement > 0 ? 1 : -1;
   }
 
@@ -494,3 +494,13 @@ int Config::getLastPoppedStack() const
   return lastPoppedStack;
 }
 
+int Config::getCurrentWordId() const
+{
+  return currentWordId;
+}
+
+void Config::setCurrentWordId(int currentWordId)
+{
+  this->currentWordId = currentWordId;
+}
+
diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp
index 259e7f4..a0edfc3 100644
--- a/reading_machine/src/Transition.cpp
+++ b/reading_machine/src/Transition.cpp
@@ -3,41 +3,53 @@
 
 Transition::Transition(const std::string & name)
 {
-  std::regex nameRegex("(<(.+)> )?(.+)");
-  std::regex writeRegex("WRITE ([bs])\\.(.+) (.+) (.+)");
-  std::regex addRegex("ADD ([bs])\\.(.+) (.+) (.+)");
-  std::regex shiftRegex("SHIFT");
-  std::regex reduceRegex("REDUCE");
-  std::regex leftRegex("LEFT (.+)");
-  std::regex rightRegex("RIGHT (.+)");
-  std::regex eosRegex("EOS");
-  std::regex nothingRegex("NOTHING");
+  std::vector<std::pair<std::regex, std::function<void(const std::smatch &)>>> inits
+  {
+    {std::regex("WRITE ([bs])\\.(.+) (.+) (.+)"),
+      [this](auto sm){(initWrite(sm[3], sm[1], sm[2], sm[4]));}},
+    {std::regex("ADD ([bs])\\.(.+) (.+) (.+)"),
+      [this](auto sm){(initAdd(sm[3], sm[1], sm[2], sm[4]));}},
+    {std::regex("SHIFT"),
+      [this](auto){initShift();}},
+    {std::regex("REDUCE"),
+      [this](auto){initReduce();}},
+    {std::regex("LEFT (.+)"),
+      [this](auto sm){(initLeft(sm[1]));}},
+    {std::regex("RIGHT (.+)"),
+      [this](auto sm){(initRight(sm[1]));}},
+    {std::regex("EOS"),
+      [this](auto){initEOS();}},
+    {std::regex("NOTHING"),
+      [this](auto){initNothing();}},
+    {std::regex("IGNORECHAR"),
+      [this](auto){initIgnoreChar();}},
+    {std::regex("ENDWORD"),
+      [this](auto){initEndWord();}},
+    {std::regex("ADDCHARTOWORD"),
+      [this](auto){initAddCharToWord();}},
+    {std::regex("SPLITWORD ([^@]+)((?:@[^@]+)+)"),
+      [this](auto sm)
+      {
+        std::vector<std::string> splitRes{sm[1]};
+        auto splited = util::split(std::string(sm[2]), '@');
+        for (auto & s : splited)
+          splitRes.emplace_back(s);
+        initSplitWord(splitRes);
+      }},
+  };
 
   try
   {
-    if (!util::doIfNameMatch(nameRegex, name, [this, name](auto sm)
+    if (!util::doIfNameMatch(std::regex("(<(.+)> )?(.+)"), name, [this, name](auto sm)
       {
         this->state = sm[2];
         this->name = sm[3];
       }))
       util::myThrow("doesn't match nameRegex");
 
-    if (util::doIfNameMatch(writeRegex, this->name, [this](auto sm){initWrite(sm[3], sm[1], sm[2], sm[4]);}))
-      return;
-    if (util::doIfNameMatch(addRegex, this->name, [this](auto sm){initAdd(sm[3], sm[1], sm[2], sm[4]);}))
-      return;
-    if (util::doIfNameMatch(shiftRegex, this->name, [this](auto){initShift();}))
-      return;
-    if (util::doIfNameMatch(reduceRegex, this->name, [this](auto){initReduce();}))
-      return;
-    if (util::doIfNameMatch(leftRegex, this->name, [this](auto sm){initLeft(sm[1]);}))
-      return;
-    if (util::doIfNameMatch(rightRegex, this->name, [this](auto sm){initRight(sm[1]);}))
-      return;
-    if (util::doIfNameMatch(eosRegex, this->name, [this](auto){initEOS();}))
-      return;
-    if (util::doIfNameMatch(nothingRegex, this->name, [this](auto){initNothing();}))
-      return;
+    for (auto & it : inits)
+      if (util::doIfNameMatch(it.first, this->name, it.second))
+        return;
 
     throw std::invalid_argument("no match");
 
@@ -128,6 +140,73 @@ void Transition::initNothing()
   };
 }
 
+void Transition::initIgnoreChar()
+{
+  sequence.emplace_back(Action::ignoreCurrentCharacter());
+
+  cost = [](const Config &)
+  {
+    return 0;
+  };
+}
+
+void Transition::initEndWord()
+{
+  sequence.emplace_back(Action::endWord());
+
+  cost = [](const Config & config)
+  {
+    if (config.getConst("FORM", config.getWordIndex(), 0) == config.getAsFeature("FORM", config.getWordIndex()))
+      return 0;
+    return 1;
+  };
+}
+
+void Transition::initAddCharToWord()
+{
+  sequence.emplace_back(Action::addLinesIfNeeded(0));
+  sequence.emplace_back(Action::addCurCharToCurWord());
+  sequence.emplace_back(Action::moveCharacterIndex(1));
+
+  cost = [](const Config & config)
+  {
+    if (!config.hasCharacter(config.getCharacterIndex()))
+      return std::numeric_limits<int>::max();
+
+    auto letter = fmt::format("{}", config.getLetter(config.getCharacterIndex()));
+    auto & goldWord = config.getConst("FORM", config.getWordIndex(), 0).get();
+    auto & curWord = config.getLastNotEmptyConst("FORM", config.getWordIndex()).get();
+    if (curWord.size() + letter.size() > goldWord.size())
+      return 1;
+
+    for (unsigned int i = 0; i < letter.size(); i++)
+      if (goldWord[curWord.size()+i] != letter[i])
+        return 1;
+
+    return 0;
+  };
+}
+
+void Transition::initSplitWord(std::vector<std::string> words)
+{
+  auto & consumedWord = words[0];
+  sequence.emplace_back(Action::addLinesIfNeeded(words.size()));
+  sequence.emplace_back(Action::consumeCharacterIndex(consumedWord));
+  for (unsigned int i = 0; i < words.size(); i++)
+    sequence.emplace_back(Action::addHypothesisRelative("FORM", Action::Object::Buffer, i, words[i]));
+  sequence.emplace_back(Action::setMultiwordId(words.size()-1));
+
+  cost = [words](const Config & config)
+  {
+    int cost = 0;
+    for (unsigned int i = 0; i < words.size(); i++)
+      if (!config.has("FORM", config.getWordIndex()+i, 0) or config.getConst("FORM", config.getWordIndex()+i, 0) != words[i])
+        cost++;
+
+    return cost;
+  };
+}
+
 void Transition::initShift()
 {
   sequence.emplace_back(Action::pushWordIndexOnStack());
diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp
index 6025537..e12c74c 100644
--- a/trainer/src/macaon_train.cpp
+++ b/trainer/src/macaon_train.cpp
@@ -86,7 +86,7 @@ int main(int argc, char * argv[])
   ReadingMachine machine(machinePath.string());
 
   BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
-  BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile);
+  BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
   SubConfig config(goldConfig);
 
   Trainer trainer(machine);
-- 
GitLab