From eaeb5040d8f044a825e4368359c1570dcfd82e1f Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 18 Sep 2019 18:04:53 +0200 Subject: [PATCH] Added tokenizer for training --- maca_common/include/util.hpp | 2 + maca_common/src/util.cpp | 30 ++++++ trainer/src/TrainInfos.cpp | 4 + trainer/src/Trainer.cpp | 25 +---- transition_machine/include/ActionBank.hpp | 18 ++++ transition_machine/include/Config.hpp | 12 +++ transition_machine/src/ActionBank.cpp | 106 ++++++++++++++++++++++ transition_machine/src/Config.cpp | 76 ++++++++++++++-- transition_machine/src/FeatureBank.cpp | 40 ++++++++ transition_machine/src/Oracle.cpp | 71 ++++++++++++++- 10 files changed, 354 insertions(+), 30 deletions(-) diff --git a/maca_common/include/util.hpp b/maca_common/include/util.hpp index 6770d97..7833af1 100644 --- a/maca_common/include/util.hpp +++ b/maca_common/include/util.hpp @@ -204,7 +204,9 @@ float getRandomValueInRange(int range); int getNbLines(const std::string & filename); int getStartIndexOfNthSymbol(const std::string & s, int n); +int getStartIndexOfNthSymbolFrom(const std::string::iterator & s, const std::string::iterator & end, int n); int getEndIndexOfNthSymbol(const std::string & s, int n); +int getEndIndexOfNthSymbolFrom(const std::string::iterator & s, const std::string::iterator & end, int n); unsigned int getNbSymbols(const std::string & s); std::string shrinkString(const std::string & base, int maxSize, const std::string token); diff --git a/maca_common/src/util.cpp b/maca_common/src/util.cpp index ea26bd6..8682941 100644 --- a/maca_common/src/util.cpp +++ b/maca_common/src/util.cpp @@ -451,6 +451,26 @@ int getStartIndexOfNthSymbol(const std::string & s, int n) return it - s.begin(); } +int getStartIndexOfNthSymbolFrom(const std::string::iterator & s, const std::string::iterator & end, int n) +{ + if (n >= 0) + { + auto it = s; + for (int i = 0; i < n; i++) + try {utf8::next(it, end);} + catch (utf8::not_enough_room &) {return -1;} + + return it - s; + } + + auto it = s; + for (int i = 0; i < -n; i++) + try {utf8::prior(it, end);} + catch (utf8::not_enough_room &) {return 1;} + + return it - s; +} + int getEndIndexOfNthSymbol(const std::string & s, int n) { auto it = s.begin(); @@ -461,6 +481,16 @@ int getEndIndexOfNthSymbol(const std::string & s, int n) return (it-1) - s.begin(); } +int getEndIndexOfNthSymbolFrom(const std::string::iterator & s, const std::string::iterator & end, int n) +{ + auto it = s; + for (int i = 0; i < n+1; i++) + try {utf8::next(it, end);} + catch (utf8::not_enough_room &) {return i == n ? end - s - 1 : -1;} + + return (it-1) - s; +} + unsigned int getNbSymbols(const std::string & s) { return utf8::distance(s.begin(), s.end()); diff --git a/trainer/src/TrainInfos.cpp b/trainer/src/TrainInfos.cpp index aa84078..8677127 100644 --- a/trainer/src/TrainInfos.cpp +++ b/trainer/src/TrainInfos.cpp @@ -161,6 +161,8 @@ void TrainInfos::computeTrainScores(Config & c) addTrainScore(it.first, computeScoreOnTapes(c, {"GOV", "LABEL"}, 0, c.getHead())); else if (it.first == "Tagger") addTrainScore(it.first, computeScoreOnTapes(c, {"POS"}, 0, c.getHead())); + else if (it.first == "Tokenizer") + addTrainScore(it.first, computeScoreOnTapes(c, {"FORM"}, 0, c.getHead())); else if (it.first == "Morpho") addTrainScore(it.first, computeScoreOnTapes(c, {"MORPHO"}, 0, c.getHead())); else if (it.first == "Lemmatizer_Rules") @@ -183,6 +185,8 @@ void TrainInfos::computeDevScores(Config & c) addDevScore(it.first, computeScoreOnTapes(c, {"GOV", "LABEL"}, 0, c.getHead())); else if (it.first == "Parser") addDevScore(it.first, computeScoreOnTapes(c, {"GOV", "LABEL"}, 0, c.getHead())); + else if (it.first == "Tokenizer") + addDevScore(it.first, computeScoreOnTapes(c, {"FORM"}, 0, c.getHead())); else if (it.first == "Tagger") addDevScore(it.first, computeScoreOnTapes(c, {"POS"}, 0, c.getHead())); else if (it.first == "Morpho") diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 56a2a49..988838f 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -239,7 +239,6 @@ void Trainer::doStepTrain() std::string pAction = ""; std::string oAction = ""; - bool pActionIsZeroCost = false; std::string actionName = ""; float loss = 0.0; @@ -253,19 +252,10 @@ void Trainer::doStepTrain() for (auto & it : weightedActions) if (it.first) - { if (pAction == "") pAction = it.second.second; - if (tm.getCurrentClassifier()->getActionCost(trainConfig, it.second.second) == 0) - { - oAction = it.second.second; - break; - } - } - - if (pAction == oAction) - pActionIsZeroCost = true; + oAction = tm.getCurrentClassifier()->getZeroCostActions(trainConfig)[0]; } else { @@ -311,17 +301,8 @@ void Trainer::doStepTrain() } else { - if (pActionIsZeroCost) - { - actionName = pAction; - TI.lastActionWasPredicted[trainConfig.getCurrentStateName()] = true; - } - else - { - actionName = oAction; - TI.lastActionWasPredicted[trainConfig.getCurrentStateName()] = false; - } - + actionName = oAction; + TI.lastActionWasPredicted[trainConfig.getCurrentStateName()] = oAction == pAction; } if (ProgramParameters::debug) diff --git a/transition_machine/include/ActionBank.hpp b/transition_machine/include/ActionBank.hpp index bb44dc6..8fe5817 100644 --- a/transition_machine/include/ActionBank.hpp +++ b/transition_machine/include/ActionBank.hpp @@ -80,6 +80,10 @@ class ActionBank /// @param relativeIndex The index of the column that will be read and written into, relatively to the head of the Config. static void writeRuleResult(Config & config, const std::string & fromTapeName, const std::string & targetTapeName, const std::string & rule, int relativeIndex); + static void addCharToBuffer(Config & config, const std::string & tapeName, int relativeIndex); + + static void removeCharFromBuffer(Config & config, const std::string & tapeName, int relativeIndex); + /// \brief Write something on the buffer /// /// \param tapeName The tape we will write to @@ -96,6 +100,20 @@ class ActionBank /// \return A BasicAction moving the head static Action::BasicAction moveHead(int movement); + /// \brief Move the raw input head + /// + /// \param movement The relative movement of the raw input head + /// + /// \return A BasicAction moving the head + static Action::BasicAction moveRawInputHead(int movement); + + /// \brief Verify if rawInput begins with word + /// + /// \param word The word to verify + /// + /// \return A BasicAction only appliable if word is the prefix of rawInput. + static Action::BasicAction rawInputBeginsWith(std::string word); + /// \brief Write something on the buffer /// /// \param tapeName The tape we will write to diff --git a/transition_machine/include/Config.hpp b/transition_machine/include/Config.hpp index fc7eefc..9d16f90 100644 --- a/transition_machine/include/Config.hpp +++ b/transition_machine/include/Config.hpp @@ -181,6 +181,12 @@ class Config LimitedStack< std::pair<std::string, Action> > pastActions; /// @brief The last action that have been undone. std::pair<std::string, int> lastUndoneAction; + /// @brief The input before tokenization. + std::string rawInput; + /// @brief Head of the raw input. + int rawInputHead; + /// @brief Index of the rawInputHead in term of bytes. + int rawInputHeadIndex; public : @@ -231,6 +237,10 @@ class Config /// /// @param mvt The relative increment in the position of the head. void moveHead(int mvt); + /// @brief Move the rawInputHead of this Config. + /// + /// @param mvt The relative increment in the position of the rawInputHead. + void moveRawInputHead(int mvt); /// @brief Whether or not this Config is terminal. /// /// A Config is terminal when the head is at the end of the multi-tapes buffer and the stack is empty. @@ -340,6 +350,8 @@ class Config /// /// @return True if the head is at the end of the tapes. bool endOfTapes() const; + /// @brief Update rawInput according to the tape TEXT. + void updateRawInput(); /// @brief Set the output file. void setOutputFile(FILE * outputFile); /// @brief Print the cells that have not been printed. diff --git a/transition_machine/src/ActionBank.cpp b/transition_machine/src/ActionBank.cpp index 8d886a8..22bdde4 100644 --- a/transition_machine/src/ActionBank.cpp +++ b/transition_machine/src/ActionBank.cpp @@ -17,6 +17,43 @@ Action::BasicAction ActionBank::moveHead(int movement) return basicAction; } +Action::BasicAction ActionBank::moveRawInputHead(int movement) +{ + auto apply = [movement](Config & c, Action::BasicAction &) + {c.moveRawInputHead(movement);}; + auto undo = [movement](Config & c, Action::BasicAction &) + {c.moveRawInputHead(-movement);}; + auto appliable = [movement](Config & c, Action::BasicAction &) + {return c.rawInputHeadIndex+movement < (int)c.rawInput.size();}; + Action::BasicAction basicAction = + {Action::BasicAction::Type::MoveHead, "", apply, undo, appliable}; + + return basicAction; +} + +Action::BasicAction ActionBank::rawInputBeginsWith(std::string word) +{ + auto apply = [](Config &, Action::BasicAction &) + {}; + auto undo = [](Config &, Action::BasicAction &) + {}; + auto appliable = [word](Config & c, Action::BasicAction &) + { + if (c.rawInputHeadIndex+word.size() >= c.rawInput.size()) + return false; + + for (unsigned int i = 0; i < word.size(); i++) + if (c.rawInput[c.rawInputHeadIndex+i] != word[i]) + return false; + + return true; + }; + Action::BasicAction basicAction = + {Action::BasicAction::Type::Write, "", apply, undo, appliable}; + + return basicAction; +} + Action::BasicAction ActionBank::bufferWrite(std::string tapeName, std::string value, int relativeIndex) { auto apply = [tapeName, value, relativeIndex](Config & c, Action::BasicAction &) @@ -192,6 +229,50 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na sequence.emplace_back(moveHead(movement)); } + else if(std::string(b1) == "IGNORECHAR") + { + sequence.emplace_back(moveRawInputHead(1)); + } + else if(std::string(b1) == "ENDWORD") + { + } + else if(std::string(b1) == "ADDCHARTOWORD") + { + auto apply = [](Config & c, Action::BasicAction &) + {addCharToBuffer(c, "FORM", 0);}; + auto undo = [](Config & c, Action::BasicAction &) + {removeCharFromBuffer(c, "FORM", 0);}; + auto appliable = [](Config & , Action::BasicAction &) + {return true;}; + Action::BasicAction basicAction = + {Action::BasicAction::Type::Write, "", apply, undo, appliable}; + + sequence.emplace_back(basicAction); + sequence.emplace_back(moveRawInputHead(1)); + } + else if(std::string(b1) == "SPLITWORD") + { + if (sscanf(name.c_str(), "SPLITWORD %s", b2) != 1) + invalidNameAndAbort(ERRINFO); + + auto splited = split(b2, '@'); + int nbSymbols = getNbSymbols(splited[0]); + + sequence.emplace_back(rawInputBeginsWith(splited[0])); + + sequence.emplace_back(moveRawInputHead(nbSymbols)); + + for (unsigned int i = 1; i < splited.size(); i++) + sequence.emplace_back(bufferWrite("FORM", splited[i], i-1)); + } + else if(std::string(b1) == "MOVERAW") + { + int movement; + if (sscanf(name.c_str(), "MOVERAW %d", &movement) != 1) + invalidNameAndAbort(ERRINFO); + + sequence.emplace_back(moveRawInputHead(movement)); + } else if(std::string(b1) == "ERROR") { auto apply = [](Config &, Action::BasicAction &) @@ -675,6 +756,31 @@ void ActionBank::writeRuleResult(Config & config, const std::string & fromTapeNa toTape.setHyp(relativeIndex, applyRule(from, rule)); } +void ActionBank::addCharToBuffer(Config & config, const std::string & tapeName, int relativeIndex) +{ + auto & tape = config.getTape(tapeName); + auto & from = tape.getHyp(relativeIndex); + + int nbChar = getEndIndexOfNthSymbolFrom(config.rawInput.begin()+config.rawInputHeadIndex,config.rawInput.end(), 0)+1; + + std::string suffix = std::string(config.rawInput.begin()+config.rawInputHeadIndex, config.rawInput.begin()+config.rawInputHeadIndex+nbChar); + + tape.setHyp(relativeIndex, from+suffix); +} + +void ActionBank::removeCharFromBuffer(Config & config, const std::string & tapeName, int relativeIndex) +{ + auto & tape = config.getTape(tapeName); + auto from = tape.getRef(relativeIndex); + + std::string suffix = std::string(config.rawInput.begin()+config.rawInputHeadIndex, config.rawInput.begin()+config.rawInputHeadIndex+getEndIndexOfNthSymbolFrom(config.rawInput.begin()+config.rawInputHeadIndex,config.rawInput.end(), 0)); + + for (char c : suffix) + from.pop_back(); + + tape.setHyp(relativeIndex, from); +} + int ActionBank::getLinkLength(const Config & c, const std::string & action) { auto splitted = split(action, ' '); diff --git a/transition_machine/src/Config.cpp b/transition_machine/src/Config.cpp index 10fe33b..69e37ea 100644 --- a/transition_machine/src/Config.cpp +++ b/transition_machine/src/Config.cpp @@ -10,9 +10,11 @@ Config::Config(BD & bd, const std::string inputFilename) : bd(bd), hashHistory(H { this->outputFile = nullptr; this->stackHistory = -1; - this->inputFilename = inputFilename; this->lastIndexPrinted = -1; + this->inputFilename = inputFilename; head = 0; + rawInputHead = 0; + rawInputHeadIndex = 0; inputAllRead = false; for(int i = 0; i < bd.getNbLines(); i++) tapes.emplace_back(bd.getNameOfLine(i), bd.lineIsKnown(i)); @@ -31,6 +33,9 @@ Config::Config(const Config & other) : bd(other.bd), hashHistory(other.hashHisto this->lastIndexPrinted = other.lastIndexPrinted; this->tapes = other.tapes; this->totalEntropy = other.totalEntropy; + this->rawInputHead = other.rawInputHead; + this->rawInputHeadIndex = other.rawInputHeadIndex; + this->rawInput = other.rawInput; this->inputFilename = other.inputFilename; this->inputAllRead = other.inputAllRead; @@ -137,6 +142,9 @@ void Config::readInput() tape.addToHyp(""); } } + + if (hasTape("TEXT")) + updateRawInput(); } void Config::printForDebug(FILE * output) @@ -174,6 +182,21 @@ void Config::printForDebug(FILE * output) for(int i = 0; i < 80; i++) fprintf(output, "-%s", i == 80-1 ? "\n" : ""); + if (!rawInput.empty()) + { + int rawWindow = 30; + int relativeHeadIndex = getEndIndexOfNthSymbolFrom(rawInput.begin()+rawInputHeadIndex, rawInput.end(), rawWindow); + auto endIter = rawInput.begin() + rawInputHeadIndex + relativeHeadIndex + 1; + if (relativeHeadIndex < 0) + endIter = rawInput.end(); + + std::string toPrint(rawInput.begin()+rawInputHeadIndex, endIter); + fprintf(stderr, "%s\n", toPrint.c_str()); + + for(int i = 0; i < 80; i++) + fprintf(output, "-%s", i == 80-1 ? "\n" : ""); + } + printColumns(output, cols, 3); fprintf(output, "Stack : "); @@ -227,6 +250,28 @@ void Config::moveHead(int mvt) } } +void Config::moveRawInputHead(int mvt) +{ + if (mvt >= 0) + { + int relativeIndexMvt = getStartIndexOfNthSymbolFrom(rawInput.begin()+rawInputHeadIndex, rawInput.end(), mvt); + if (relativeIndexMvt > 0) + { + rawInputHead += mvt; + rawInputHeadIndex += relativeIndexMvt; + } + } + else + { + int relativeIndexMvt = getStartIndexOfNthSymbolFrom(rawInput.begin()+rawInputHeadIndex, rawInput.begin(), mvt); + if (relativeIndexMvt < 0) + { + rawInputHeadIndex += relativeIndexMvt; + rawInputHead += mvt; + } + } +} + bool Config::isFinal() { return endOfTapes() && stack.empty(); @@ -248,6 +293,8 @@ void Config::reset() inputAllRead = false; head = 0; + rawInputHead = 0; + rawInputHeadIndex = 0; file.reset(); while (tapes[0].size() < ProgramParameters::readSize*4 && !inputAllRead) @@ -327,16 +374,17 @@ LimitedStack<float> & Config::getCurrentStateEntropyHistory() void Config::shuffle(const std::string & delimiterTape, const std::string & delimiter) { - std::vector< std::pair<unsigned int, unsigned int> > delimiters; + struct Trio{unsigned int a; unsigned int b; unsigned int c; Trio(unsigned int a, unsigned int b, unsigned int c): a(a), b(b), c(c){}}; + std::vector<Trio> delimiters; if (delimiterTape == "0") { unsigned int previousIndex = 0; for (int i = 0; i < tapes[0].refSize(); i++) { - delimiters.emplace_back(previousIndex, i); + delimiters.emplace_back(previousIndex, i, delimiters.size()); previousIndex = i+1; - } + } } else { @@ -345,7 +393,7 @@ void Config::shuffle(const std::string & delimiterTape, const std::string & deli for (int i = 0; i < tape.refSize(); i++) if (tape.getRef(i-head) == delimiter) { - delimiters.emplace_back(previousIndex, i); + delimiters.emplace_back(previousIndex, i, delimiters.size()); previousIndex = i+1; } } @@ -356,7 +404,7 @@ void Config::shuffle(const std::string & delimiterTape, const std::string & deli return; } - std::pair<unsigned int, unsigned int> suffix = {delimiters.back().second+1, tapes[0].refSize()-1}; + std::pair<unsigned int, unsigned int> suffix = {delimiters.back().b+1, tapes[0].refSize()-1}; std::random_shuffle(delimiters.begin(), delimiters.end()); @@ -367,13 +415,16 @@ void Config::shuffle(const std::string & delimiterTape, const std::string & deli newTapes[tape].clearDataForCopy(); for (auto & delimiter : delimiters) - newTapes[tape].copyPart(tapes[tape], delimiter.first, delimiter.second+1); + newTapes[tape].copyPart(tapes[tape], delimiter.a, delimiter.b+1); if (suffix.first <= suffix.second) newTapes[tape].copyPart(tapes[tape], suffix.first, suffix.second+1); } tapes = newTapes; + + if (!rawInput.empty()) + updateRawInput(); } int Config::stackGetElem(int index) const @@ -658,3 +709,14 @@ float Config::Tape::getScore(int from, int to) return 100.0*res / (1+to-from); } +void Config::updateRawInput() +{ + rawInput = ""; + auto & textTape = getTape("TEXT"); + for (int i = 0; i < textTape.size(); i++) + { + if (textTape[i] != "_") + rawInput += (rawInput.empty() ? std::string("") : std::string(" ")) + textTape[i]; + } +} + diff --git a/transition_machine/src/FeatureBank.cpp b/transition_machine/src/FeatureBank.cpp index ab40482..ccef787 100644 --- a/transition_machine/src/FeatureBank.cpp +++ b/transition_machine/src/FeatureBank.cpp @@ -212,6 +212,46 @@ FeatureModel::FeatureValue getDistance(int index1, int index2, const std::string std::function<FeatureModel::FeatureValue(Config &)> FeatureBank::str2func(const std::string & s) { + if (split(s,'.')[0] == "raw") + { + int relativeIndex; + try {relativeIndex = std::stoi(split(s, '.')[1]);} + catch (std::exception &) + { + fprintf(stderr, "ERROR (%s) : invalid feature format \'%s\'. Relative index must be an integer. Aborting.\n", ERRINFO, s.c_str()); + exit(1); + } + return [relativeIndex, s](Config & c) + { + int relativeCharIndex = getStartIndexOfNthSymbolFrom(c.rawInput.begin()+c.rawInputHeadIndex, relativeIndex >= 0 ? c.rawInput.end() : c.rawInput.begin(), relativeIndex); + + Dict * dict = Dict::getDict("letters"); + auto policy = dictPolicy2FeaturePolicy(dict->policy); + + if (relativeCharIndex >= 0 && relativeIndex < 0) + return FeatureModel::FeatureValue({dict, s, Dict::nullValueStr, policy}); + if (relativeCharIndex < 0 && relativeIndex >= 0) + return FeatureModel::FeatureValue({dict, s, Dict::nullValueStr, policy}); + + int endIndex = getEndIndexOfNthSymbolFrom(c.rawInput.begin()+c.rawInputHeadIndex+relativeCharIndex, c.rawInput.end(), 0); + + auto a = c.rawInput.begin()+c.rawInputHeadIndex+relativeCharIndex; + auto b = a + endIndex + 1; + + std::string value; + + if (a <= b) + value = std::string(a,b); + else + value = std::string(b,a); + + if (value.empty()) + return FeatureModel::FeatureValue({dict, s, Dict::nullValueStr, policy}); + + return FeatureModel::FeatureValue({dict, s, value, policy}); + }; + } + auto splited = split(s, '#'); if (splited.size() == 1) diff --git a/transition_machine/src/Oracle.cpp b/transition_machine/src/Oracle.cpp index aeaf7e0..769684e 100644 --- a/transition_machine/src/Oracle.cpp +++ b/transition_machine/src/Oracle.cpp @@ -211,7 +211,40 @@ void Oracle::createDatabase() }, [](Config & c, Oracle *, const std::string & action) { - return (action == "WRITE b.0 BIO " + c.getTape("BIO").getRef(0) || c.endOfTapes()) ? 0 : 1; + auto & currentWordRef = c.getTape("FORM").getRef(0); + auto & currentWordHyp = c.getTape("FORM").getHyp(0); + + auto splited = split(split(action, ' ').back(),'@'); + + if (splited.size() > 2) + { + if (c.rawInput.begin() + splited[0].size() >= c.rawInput.end()) + return 1; + + for (unsigned int i = 0; i < splited[0].size(); i++) + if (splited[0][i] != c.rawInput[c.rawInputHeadIndex+i]) + return 1; + + for (unsigned int i = 1; i < splited.size(); i++) + if (c.getTape("FORM").getRef(i-1) != splited[i]) + return 1; + + return 0; + } + + if (currentWordRef == currentWordHyp) + if (action == "ENDWORD") + return 0; + + if (action == "ADDCHARTOWORD" && currentWordRef.size() > currentWordHyp.size()) + { + for (unsigned int i = 0; i < (currentWordRef.size()-currentWordHyp.size()); i++) + if (currentWordRef[currentWordHyp.size()+i] != c.rawInput[c.rawInputHeadIndex+i]) + return 1; + return 0; + } + + return 1; }))); str2oracle.emplace("eos", std::unique_ptr<Oracle>(new Oracle( @@ -288,6 +321,42 @@ void Oracle::createDatabase() return 0; }))); + str2oracle.emplace("strategy_tokenizer,tagger", std::unique_ptr<Oracle>(new Oracle( + [](Oracle *) + { + }, + [](Config & c, Oracle *) + { + if (c.pastActions.size() == 0) + return std::string("MOVE tokenizer 0"); + + std::string previousState = noAccentLower(c.pastActions.getElem(0).first); + std::string previousAction = noAccentLower(c.pastActions.getElem(0).second.name); + std::string newState; + int movement = 0; + + if (previousState == "signature") + newState = "tagger"; + else if (previousState == "tokenizer") + { + if (split(previousAction, ' ')[0] == "splitword" || split(previousAction, ' ')[0] == "endword") + newState = "signature"; + else + newState = "tokenizer"; + } + else if (previousState == "tagger" || previousState == "error_tagger") + { + newState = "tokenizer"; + movement = 1; + } + + return "MOVE " + newState + " " + std::to_string(movement); + }, + [](Config &, Oracle *, const std::string &) + { + return 0; + }))); + str2oracle.emplace("strategy_parser", std::unique_ptr<Oracle>(new Oracle( [](Oracle *) { -- GitLab