From d438d3c34b1da1736f040556944c53b510eb90d1 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Thu, 2 Apr 2020 22:44:23 +0200 Subject: [PATCH] added transition split --- decoder/src/Decoder.cpp | 2 ++ reading_machine/include/Action.hpp | 1 + reading_machine/include/Config.hpp | 6 +++++ reading_machine/include/ReadingMachine.hpp | 3 +++ reading_machine/include/Transition.hpp | 1 + reading_machine/include/TransitionSet.hpp | 1 + reading_machine/src/Action.cpp | 29 ++++++++++++++++++++++ reading_machine/src/Config.cpp | 10 ++++++++ reading_machine/src/ReadingMachine.cpp | 10 ++++++++ reading_machine/src/Transition.cpp | 17 +++++++++++++ reading_machine/src/TransitionSet.cpp | 12 ++++++++- torch_modules/include/LSTMNetwork.hpp | 1 + torch_modules/src/LSTMNetwork.cpp | 24 +++++++++++++++--- trainer/src/Trainer.cpp | 2 ++ 14 files changed, 114 insertions(+), 5 deletions(-) diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index a24f6a2..7d563c8 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -24,6 +24,8 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool if (debug) config.printForDebug(stderr); + config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); + auto dictState = machine.getDict(config.getState()).getState(); auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState())).back(); machine.getDict(config.getState()).setState(dictState); diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp index 012af61..5561a81 100644 --- a/reading_machine/include/Action.hpp +++ b/reading_machine/include/Action.hpp @@ -64,6 +64,7 @@ class Action static Action ignoreCurrentCharacter(); static Action consumeCharacterIndex(util::utf8string consumed); static Action setMultiwordIds(int multiwordSize); + static Action split(int index); }; #endif diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index 1e6833e..0bc46f5 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -9,6 +9,8 @@ #include "util.hpp" #include "Dict.hpp" +class Transition; + class Config { public : @@ -20,6 +22,7 @@ class Config static constexpr const char * deprelColName = "DEPREL"; static constexpr const char * idColName = "ID"; static constexpr int nbHypothesesMax = 1; + static constexpr int maxNbAppliableSplitTransitions = 3; public : @@ -34,6 +37,7 @@ class Config std::set<std::string> predicted; int lastPoppedStack{-1}; int currentWordId{0}; + std::vector<Transition *> appliableSplitTransitions; protected : @@ -124,6 +128,8 @@ class Config void setCurrentWordId(int currentWordId); void addMissingColumns(); void addComment(); + void setAppliableSplitTransitions(const std::vector<Transition *> & appliableSplitTransitions); + const std::vector<Transition *> & getAppliableSplitTransitions() const; }; #endif diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index cc56ec2..dd0098d 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -26,6 +26,8 @@ class ReadingMachine std::map<std::string, Dict> dicts; std::set<std::string> predicted; + std::unique_ptr<TransitionSet> splitWordTransitionSet; + private : void readFromFile(std::filesystem::path path); @@ -36,6 +38,7 @@ class ReadingMachine ReadingMachine(std::filesystem::path path); ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts); TransitionSet & getTransitionSet(); + TransitionSet & getSplitWordTransitionSet(); Strategy & getStrategy(); Dict & getDict(const std::string & state); std::map<std::string, Dict> & getDicts(); diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp index 0c1ccd7..87531db 100644 --- a/reading_machine/include/Transition.hpp +++ b/reading_machine/include/Transition.hpp @@ -28,6 +28,7 @@ class Transition void initEndWord(); void initAddCharToWord(); void initSplitWord(std::vector<std::string> words); + void initSplit(int index); public : diff --git a/reading_machine/include/TransitionSet.hpp b/reading_machine/include/TransitionSet.hpp index 2daaa2e..a1bc2c1 100644 --- a/reading_machine/include/TransitionSet.hpp +++ b/reading_machine/include/TransitionSet.hpp @@ -17,6 +17,7 @@ class TransitionSet TransitionSet(const std::string & filename); std::vector<std::pair<Transition*, int>> getAppliableTransitionsCosts(const Config & c); Transition * getBestAppliableTransition(const Config & c); + std::vector<Transition *> getNAppliableTransitions(const Config & c, int n); std::size_t getTransitionIndex(const Transition * transition) const; Transition * getTransition(std::size_t index); Transition * getTransition(const std::string & name); diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index e425740..889aaa7 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -1,4 +1,5 @@ #include "Action.hpp" +#include "Transition.hpp" Action::Action(Action::Type type, std::function<void(Config & config, Action & action)> apply, std::function<void(Config & config, Action & action)> undo, std::function<bool(const Config & config, const Action & action)> appliable) { @@ -619,6 +620,33 @@ Action Action::attach(Object governorObject, int governorIndex, Object dependent return {Type::Write, apply, undo, appliable}; } +Action Action::split(int index) +{ + auto apply = [index](Config & config, Action &) + { + Transition * t = config.getAppliableSplitTransitions()[index]; + t->apply(config); + }; + + auto undo = [](Config &, Action &) + { + //TODO : undo this + }; + + auto appliable = [index](const Config & config, const Action &) + { + auto & transitions = config.getAppliableSplitTransitions(); + + if (index < 0 or index >= (int)transitions.size()) + return false; + + Transition * t = transitions[index]; + return t->appliable(config); + }; + + return {Type::Write, apply, undo, appliable}; +} + Action::Object Action::str2object(const std::string & s) { if (s == "b") @@ -629,3 +657,4 @@ Action::Object Action::str2object(const std::string & s) util::myThrow(fmt::format("Invalid object '{}'", s)); return Object::Buffer; } + diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 6aec0fb..bb2366c 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -601,3 +601,13 @@ long Config::getRelativeWordIndex(int relativeIndex) const return -1; } +void Config::setAppliableSplitTransitions(const std::vector<Transition *> & appliableSplitTransitions) +{ + this->appliableSplitTransitions = appliableSplitTransitions; +} + +const std::vector<Transition *> & Config::getAppliableSplitTransitions() const +{ + return appliableSplitTransitions; +} + diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 50ce655..96b5e71 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -63,6 +63,11 @@ void ReadingMachine::readFromFile(std::filesystem::path path) --curLine; + util::doIfNameMatch(std::regex("Splitwords : (.+)"), lines[curLine++], [this,path](auto sm) + { + this->splitWordTransitionSet.reset(new TransitionSet(path.parent_path() / sm.str(1))); + }); + if (!util::doIfNameMatch(std::regex("Predictions : (.+)"), lines[curLine++], [this](auto sm) { auto predictions = sm.str(1); @@ -84,6 +89,11 @@ TransitionSet & ReadingMachine::getTransitionSet() return classifier->getTransitionSet(); } +TransitionSet & ReadingMachine::getSplitWordTransitionSet() +{ + return *splitWordTransitionSet; +} + Strategy & ReadingMachine::getStrategy() { return *strategy; diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index 0a9bb89..d45a794 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -27,6 +27,8 @@ Transition::Transition(const std::string & name) [this](auto){initEndWord();}}, {std::regex("ADDCHARTOWORD"), [this](auto){initAddCharToWord();}}, + {std::regex("SPLIT (.+)"), + [this](auto sm){(initSplit(std::stoi(sm.str(1))));}}, {std::regex("SPLITWORD ([^@]+)(:?(:?@[^@]+)+)"), [this](auto sm) { @@ -210,6 +212,21 @@ void Transition::initSplitWord(std::vector<std::string> words) }; } +void Transition::initSplit(int index) +{ + sequence.emplace_back(Action::split(index)); + + cost = [index](const Config & config) + { + auto & transitions = config.getAppliableSplitTransitions(); + + if (index < 0 or index >= (int)transitions.size()) + return std::numeric_limits<int>::max(); + + return transitions[index]->getCost(config); + }; +} + void Transition::initShift() { sequence.emplace_back(Action::pushWordIndexOnStack()); diff --git a/reading_machine/src/TransitionSet.cpp b/reading_machine/src/TransitionSet.cpp index d5b9716..8fbb6a5 100644 --- a/reading_machine/src/TransitionSet.cpp +++ b/reading_machine/src/TransitionSet.cpp @@ -42,6 +42,17 @@ std::vector<std::pair<Transition*, int>> TransitionSet::getAppliableTransitionsC return appliableTransitions; } +std::vector<Transition *> TransitionSet::getNAppliableTransitions(const Config & c, int n) +{ + std::vector<Transition *> result; + + for (unsigned int i = 0; i < transitions.size() && result.size() < n; i++) + if (transitions[i].appliable(c)) + result.emplace_back(&transitions[i]); + + return result; +} + Transition * TransitionSet::getBestAppliableTransition(const Config & c) { Transition * result = nullptr; @@ -94,4 +105,3 @@ Transition * TransitionSet::getTransition(const std::string & name) return nullptr; } - diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp index 0e1ad1c..8c14e76 100644 --- a/torch_modules/include/LSTMNetwork.hpp +++ b/torch_modules/include/LSTMNetwork.hpp @@ -22,6 +22,7 @@ class LSTMNetworkImpl : public NeuralNetworkImpl torch::nn::Linear linear2{nullptr}; torch::nn::LSTM contextLSTM{nullptr}; torch::nn::LSTM rawInputLSTM{nullptr}; + torch::nn::LSTM splitTransLSTM{nullptr}; std::vector<torch::nn::LSTM> lstms; public : diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp index 0ae4cd1..201feea 100644 --- a/torch_modules/src/LSTMNetwork.cpp +++ b/torch_modules/src/LSTMNetwork.cpp @@ -1,4 +1,5 @@ #include "LSTMNetwork.hpp" +#include "Transition.hpp" LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput) { @@ -29,8 +30,9 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std:: lstmDropout = register_module("lstm_dropout", torch::nn::Dropout(0.3)); hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3)); contextLSTM = register_module("contextLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(columns.size()*embeddingsSize, contextLSTMSize).batch_first(true).bidirectional(true))); + splitTransLSTM = register_module("splitTransLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, embeddingsSize).batch_first(true).bidirectional(true))); - int totalLSTMOutputSize = contextLSTM->options.hidden_size() * (contextLSTM->options.bidirectional() ? 4 : 1) + rawInputLSTMOutputSize; + int totalLSTMOutputSize = contextLSTM->options.hidden_size() * (contextLSTM->options.bidirectional() ? 4 : 1) + rawInputLSTMOutputSize + (Config::maxNbAppliableSplitTransitions * splitTransLSTM->options.hidden_size() * (splitTransLSTM->options.bidirectional() ? 2 : 1)); for (auto & col : focusedColumns) { @@ -49,21 +51,28 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input) auto embeddings = embeddingsDropout(wordEmbeddings(input)); - auto context = embeddings.narrow(1, rawInputSize, getContextSize()); + auto splitTrans = embeddings.narrow(1, 0, Config::maxNbAppliableSplitTransitions); + + auto context = embeddings.narrow(1, splitTrans.size(1)+rawInputSize, getContextSize()); context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}); - auto elementsEmbeddings = embeddings.narrow(1, rawInputSize+context.size(1), input.size(1)-(rawInputSize+context.size(1))); + auto elementsEmbeddings = embeddings.narrow(1, splitTrans.size(1)+rawInputSize+context.size(1), input.size(1)-(splitTrans.size(1)+rawInputSize+context.size(1))); std::vector<torch::Tensor> lstmOutputs; if (rawInputSize != 0) { - auto rawLetters = embeddings.narrow(1, 0, rawInputSize); + auto rawLetters = embeddings.narrow(1, splitTrans.size(1), rawInputSize); auto lstmOut = rawInputLSTM(rawLetters).output; lstmOutputs.emplace_back(lstmOut.reshape({lstmOut.size(0), -1})); } + { + auto lstmOut = splitTransLSTM(splitTrans).output; + lstmOutputs.emplace_back(lstmOut.reshape({lstmOut.size(0), -1})); + } + auto curIndex = 0; for (unsigned int i = 0; i < focusedColumns.size(); i++) { @@ -101,6 +110,13 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, std::vector<std::vector<long>> context; context.emplace_back(); + auto & splitTransitions = config.getAppliableSplitTransitions(); + for (int i = 0; i < Config::maxNbAppliableSplitTransitions; i++) + if (i < (int)splitTransitions.size()) + context.back().emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName())); + else + context.back().emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); + if (rawInputSize > 0) { for (int i = 0; i < leftWindowRawInput; i++) diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index ccbad4e..621441c 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -41,6 +41,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch: if (debug) config.printForDebug(stderr); + config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); + auto * transition = machine.getTransitionSet().getBestAppliableTransition(config); if (!transition) { -- GitLab