From 8adb34622a1cd429004d55b514321b9adf57de04 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 8 Apr 2020 00:50:45 +0200 Subject: [PATCH] code refactoring --- reading_machine/src/Classifier.cpp | 50 ----- torch_modules/include/CNN.hpp | 2 +- torch_modules/include/CNNNetwork.hpp | 35 ---- torch_modules/include/ConcatWordsNetwork.hpp | 21 -- torch_modules/include/ContextLSTM.hpp | 30 +++ .../include/DepthLayerTreeEmbedding.hpp | 25 +++ torch_modules/include/FocusedColumnLSTM.hpp | 28 +++ torch_modules/include/LSTM.hpp | 6 +- torch_modules/include/LSTMNetwork.hpp | 23 +-- torch_modules/include/MLP.hpp | 14 -- torch_modules/include/NeuralNetwork.hpp | 23 +-- torch_modules/include/RLTNetwork.hpp | 31 --- torch_modules/include/RandomNetwork.hpp | 1 + torch_modules/include/RawInputLSTM.hpp | 26 +++ torch_modules/include/SplitTransLSTM.hpp | 26 +++ torch_modules/include/Submodule.hpp | 22 ++ torch_modules/src/CNN.cpp | 2 +- torch_modules/src/CNNNetwork.cpp | 187 ----------------- torch_modules/src/ConcatWordsNetwork.cpp | 25 --- torch_modules/src/ContextLSTM.cpp | 62 ++++++ torch_modules/src/DepthLayerTreeEmbedding.cpp | 17 ++ torch_modules/src/FocusedColumnLSTM.cpp | 94 +++++++++ torch_modules/src/LSTM.cpp | 2 +- torch_modules/src/LSTMNetwork.cpp | 104 ++++------ torch_modules/src/MLP.cpp | 8 - torch_modules/src/NeuralNetwork.cpp | 191 ------------------ torch_modules/src/RLTNetwork.cpp | 190 ----------------- torch_modules/src/RandomNetwork.cpp | 10 +- torch_modules/src/RawInputLSTM.cpp | 40 ++++ torch_modules/src/SplitTransLSTM.cpp | 33 +++ torch_modules/src/Submodule.cpp | 7 + 31 files changed, 479 insertions(+), 856 deletions(-) delete mode 100644 torch_modules/include/CNNNetwork.hpp delete mode 100644 torch_modules/include/ConcatWordsNetwork.hpp create mode 100644 torch_modules/include/ContextLSTM.hpp create mode 100644 torch_modules/include/DepthLayerTreeEmbedding.hpp create mode 100644 torch_modules/include/FocusedColumnLSTM.hpp delete mode 100644 torch_modules/include/MLP.hpp delete mode 100644 torch_modules/include/RLTNetwork.hpp create mode 100644 torch_modules/include/RawInputLSTM.hpp create mode 100644 torch_modules/include/SplitTransLSTM.hpp create mode 100644 torch_modules/include/Submodule.hpp delete mode 100644 torch_modules/src/CNNNetwork.cpp delete mode 100644 torch_modules/src/ConcatWordsNetwork.cpp create mode 100644 torch_modules/src/ContextLSTM.cpp create mode 100644 torch_modules/src/DepthLayerTreeEmbedding.cpp create mode 100644 torch_modules/src/FocusedColumnLSTM.cpp delete mode 100644 torch_modules/src/MLP.cpp delete mode 100644 torch_modules/src/RLTNetwork.cpp create mode 100644 torch_modules/src/RawInputLSTM.cpp create mode 100644 torch_modules/src/SplitTransLSTM.cpp create mode 100644 torch_modules/src/Submodule.cpp diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 54dccff..20a5056 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -1,8 +1,5 @@ #include "Classifier.hpp" #include "util.hpp" -#include "ConcatWordsNetwork.hpp" -#include "RLTNetwork.hpp" -#include "CNNNetwork.hpp" #include "LSTMNetwork.hpp" #include "RandomNetwork.hpp" @@ -40,45 +37,6 @@ void Classifier::initNeuralNetwork(const std::string & topology) this->nn.reset(new RandomNetworkImpl(this->transitionSet->size())); } }, - { - std::regex("ConcatWords\\(\\{(.*)\\},\\{(.*)\\}\\)"), - "ConcatWords({bufferContext},{stackContext}) : Concatenate embeddings of words in context.", - [this,topology](auto sm) - { - std::vector<int> bufferContext, stackContext; - for (auto s : util::split(sm.str(1), ',')) - bufferContext.emplace_back(std::stoi(s)); - for (auto s : util::split(sm.str(2), ',')) - stackContext.emplace_back(std::stoi(s)); - this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size(), bufferContext, stackContext)); - } - }, - { - std::regex("CNN\\(([+\\-]?\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"), - "CNN(unknownValueThreshold,{bufferContext},{stackContext},{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.", - [this,topology](auto sm) - { - std::vector<int> focusedBuffer, focusedStack, maxNbElements, bufferContext, stackContext; - std::vector<std::string> focusedColumns, columns; - for (auto s : util::split(sm.str(2), ',')) - bufferContext.emplace_back(std::stoi(s)); - for (auto s : util::split(sm.str(3), ',')) - stackContext.emplace_back(std::stoi(s)); - for (auto s : util::split(sm.str(4), ',')) - columns.emplace_back(s); - for (auto s : util::split(sm.str(5), ',')) - focusedBuffer.push_back(std::stoi(s)); - for (auto s : util::split(sm.str(6), ',')) - focusedStack.push_back(std::stoi(s)); - for (auto s : util::split(sm.str(7), ',')) - focusedColumns.emplace_back(s); - for (auto s : util::split(sm.str(8), ',')) - maxNbElements.push_back(std::stoi(s)); - if (focusedColumns.size() != maxNbElements.size()) - util::myThrow("focusedColumns.size() != maxNbElements.size()"); - this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm.str(9)), std::stoi(sm.str(10)))); - } - }, { std::regex("LSTM\\(([+\\-]?\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"), "LSTM(unknownValueThreshold,{bufferContext},{stackContext},{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.", @@ -105,14 +63,6 @@ void Classifier::initNeuralNetwork(const std::string & topology) this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm.str(9)), std::stoi(sm.str(10)))); } }, - { - std::regex("RLT\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"), - "RLT(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.", - [this,topology](auto sm) - { - this->nn.reset(new RLTNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), std::stoi(sm.str(2)), std::stoi(sm.str(3)))); - } - }, }; std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology); diff --git a/torch_modules/include/CNN.hpp b/torch_modules/include/CNN.hpp index 509be85..66c405c 100644 --- a/torch_modules/include/CNN.hpp +++ b/torch_modules/include/CNN.hpp @@ -17,7 +17,7 @@ class CNNImpl : public torch::nn::Module CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize); torch::Tensor forward(torch::Tensor input); - int getOutputSize(); + std::size_t getOutputSize(); }; TORCH_MODULE(CNN); diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp deleted file mode 100644 index 6fb985a..0000000 --- a/torch_modules/include/CNNNetwork.hpp +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef CNNNETWORK__H -#define CNNNETWORK__H - -#include "NeuralNetwork.hpp" -#include "CNN.hpp" - -class CNNNetworkImpl : public NeuralNetworkImpl -{ - private : - - int unknownValueThreshold; - std::vector<std::string> focusedColumns; - std::vector<int> maxNbElements; - int leftWindowRawInput; - int rightWindowRawInput; - int rawInputSize; - - torch::nn::Embedding wordEmbeddings{nullptr}; - torch::nn::Dropout embeddingsDropout{nullptr}; - torch::nn::Dropout cnnDropout{nullptr}; - torch::nn::Dropout hiddenDropout{nullptr}; - torch::nn::Linear linear1{nullptr}; - torch::nn::Linear linear2{nullptr}; - CNN contextCNN{nullptr}; - CNN rawInputCNN{nullptr}; - std::vector<CNN> cnns; - - public : - - CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput); - torch::Tensor forward(torch::Tensor input) override; - std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override; -}; - -#endif diff --git a/torch_modules/include/ConcatWordsNetwork.hpp b/torch_modules/include/ConcatWordsNetwork.hpp deleted file mode 100644 index b2c9691..0000000 --- a/torch_modules/include/ConcatWordsNetwork.hpp +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef CONCATWORDSNETWORK__H -#define CONCATWORDSNETWORK__H - -#include "NeuralNetwork.hpp" - -class ConcatWordsNetworkImpl : public NeuralNetworkImpl -{ - private : - - torch::nn::Embedding wordEmbeddings{nullptr}; - torch::nn::Linear linear1{nullptr}; - torch::nn::Linear linear2{nullptr}; - torch::nn::Dropout dropout{nullptr}; - - public : - - ConcatWordsNetworkImpl(int nbOutputs, const std::vector<int> & bufferContext, const std::vector<int> & stackContext); - torch::Tensor forward(torch::Tensor input) override; -}; - -#endif diff --git a/torch_modules/include/ContextLSTM.hpp b/torch_modules/include/ContextLSTM.hpp new file mode 100644 index 0000000..136029c --- /dev/null +++ b/torch_modules/include/ContextLSTM.hpp @@ -0,0 +1,30 @@ +#ifndef CONTEXTLSTM__H +#define CONTEXTLSTM__H + +#include <torch/torch.h> +#include "Submodule.hpp" +#include "LSTM.hpp" + +class ContextLSTMImpl : public torch::nn::Module, public Submodule +{ + private : + + LSTM lstm{nullptr}; + std::vector<std::string> columns; + std::vector<int> bufferContext; + std::vector<int> stackContext; + int unknownValueThreshold; + std::vector<std::string> unknownValueColumns{"FORM", "LEMMA"}; + + public : + + ContextLSTMImpl(std::vector<std::string> columns, int embeddingsSize, int outEmbeddingsSize, std::vector<int> bufferContext, std::vector<int> stackContext, LSTMImpl::LSTMOptions options, int unknownValueThreshold); + torch::Tensor forward(torch::Tensor input); + std::size_t getOutputSize() override; + std::size_t getInputSize() override; + void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; +}; +TORCH_MODULE(ContextLSTM); + +#endif + diff --git a/torch_modules/include/DepthLayerTreeEmbedding.hpp b/torch_modules/include/DepthLayerTreeEmbedding.hpp new file mode 100644 index 0000000..d471e6b --- /dev/null +++ b/torch_modules/include/DepthLayerTreeEmbedding.hpp @@ -0,0 +1,25 @@ +#ifndef DEPTHLAYERTREEEMBEDDING__H +#define DEPTHLAYERTREEEMBEDDING__H + +#include <torch/torch.h> +#include "fmt/core.h" +#include "LSTM.hpp" + +class DepthLayerTreeEmbeddingImpl : public torch::nn::Module +{ + private : + + std::vector<LSTM> depthLstm; + int maxDepth; + int maxElemPerDepth; + + public : + + DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth); + torch::Tensor forward(torch::Tensor input); + int getOutputSize(); +}; +TORCH_MODULE(DepthLayerTreeEmbedding); + +#endif + diff --git a/torch_modules/include/FocusedColumnLSTM.hpp b/torch_modules/include/FocusedColumnLSTM.hpp new file mode 100644 index 0000000..6ea836a --- /dev/null +++ b/torch_modules/include/FocusedColumnLSTM.hpp @@ -0,0 +1,28 @@ +#ifndef FOCUSEDCOLUMNLSTM__H +#define FOCUSEDCOLUMNLSTM__H + +#include <torch/torch.h> +#include "Submodule.hpp" +#include "LSTM.hpp" + +class FocusedColumnLSTMImpl : public torch::nn::Module, public Submodule +{ + private : + + LSTM lstm{nullptr}; + std::vector<int> focusedBuffer, focusedStack; + std::string column; + int maxNbElements; + + public : + + FocusedColumnLSTMImpl(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options); + torch::Tensor forward(torch::Tensor input); + std::size_t getOutputSize() override; + std::size_t getInputSize() override; + void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; +}; +TORCH_MODULE(FocusedColumnLSTM); + +#endif + diff --git a/torch_modules/include/LSTM.hpp b/torch_modules/include/LSTM.hpp index eb06c45..c45cb9f 100644 --- a/torch_modules/include/LSTM.hpp +++ b/torch_modules/include/LSTM.hpp @@ -6,6 +6,10 @@ class LSTMImpl : public torch::nn::Module { + public : + + using LSTMOptions = std::tuple<bool,bool,int,float,bool>; + private : torch::nn::LSTM lstm{nullptr}; @@ -13,7 +17,7 @@ class LSTMImpl : public torch::nn::Module public : - LSTMImpl(int inputSize, int outputSize, std::tuple<bool,bool,int,float,bool> options); + LSTMImpl(int inputSize, int outputSize, LSTMOptions options); torch::Tensor forward(torch::Tensor input); int getOutputSize(int sequenceLength); }; diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp index 860f407..5762ad1 100644 --- a/torch_modules/include/LSTMNetwork.hpp +++ b/torch_modules/include/LSTMNetwork.hpp @@ -2,29 +2,28 @@ #define LSTMNETWORK__H #include "NeuralNetwork.hpp" -#include "LSTM.hpp" +#include "ContextLSTM.hpp" +#include "RawInputLSTM.hpp" +#include "SplitTransLSTM.hpp" +#include "FocusedColumnLSTM.hpp" class LSTMNetworkImpl : public NeuralNetworkImpl { private : - int unknownValueThreshold; - std::vector<std::string> focusedColumns; - std::vector<int> maxNbElements; - int leftWindowRawInput; - int rightWindowRawInput; - int rawInputSize; - torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Dropout embeddingsDropout{nullptr}; torch::nn::Dropout lstmDropout{nullptr}; torch::nn::Dropout hiddenDropout{nullptr}; torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear2{nullptr}; - LSTM contextLSTM{nullptr}; - LSTM rawInputLSTM{nullptr}; - LSTM splitTransLSTM{nullptr}; - std::vector<LSTM> lstms; + + ContextLSTM contextLSTM{nullptr}; + RawInputLSTM rawInputLSTM{nullptr}; + SplitTransLSTM splitTransLSTM{nullptr}; + std::vector<FocusedColumnLSTM> focusedLstms; + + bool hasRawInputLSTM{false}; public : diff --git a/torch_modules/include/MLP.hpp b/torch_modules/include/MLP.hpp deleted file mode 100644 index 90bde50..0000000 --- a/torch_modules/include/MLP.hpp +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef MLP__H -#define MLP__H - -#include <torch/torch.h> - -class MLPImpl : torch::nn::Module -{ - public : - - MLPImpl(const std::string & topology); -}; -TORCH_MODULE(MLP); - -#endif diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index ffc0052..be25c87 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -15,31 +15,10 @@ class NeuralNetworkImpl : public torch::nn::Module static constexpr int maxNbEmbeddings = 150000; - std::vector<std::string> columns{"FORM"}; - std::vector<int> bufferContext{-3,-2,-1,0,1}; - std::vector<int> stackContext{}; - std::vector<int> bufferFocused{}; - std::vector<int> stackFocused{}; - - protected : - - void setBufferContext(const std::vector<int> & bufferContext); - void setStackContext(const std::vector<int> & stackContext); - void setBufferFocused(const std::vector<int> & bufferFocused); - void setStackFocused(const std::vector<int> & stackFocused); - public : virtual torch::Tensor forward(torch::Tensor input) = 0; - virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const; - std::vector<long> extractContextIndexes(const Config & config) const; - std::vector<long> extractFocusedIndexes(const Config & config) const; - int getContextSize() const; - void setColumns(const std::vector<std::string> & columns); - void addAppliableSplitTransitions(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const; - void addRawInput(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, int leftWindowRawInput, int rightWindowRawInput) const; - void addContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, const std::vector<long> & contextIndexes, int unknownValueThreshold, const std::vector<std::string> & unknownValueColumns) const; - void addFocused(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, const std::vector<long> & focusedIndexes, const std::vector<std::string> & focusedColumns, const std::vector<int> & maxNbElements) const; + virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const = 0; }; TORCH_MODULE(NeuralNetwork); diff --git a/torch_modules/include/RLTNetwork.hpp b/torch_modules/include/RLTNetwork.hpp deleted file mode 100644 index 71b8aa5..0000000 --- a/torch_modules/include/RLTNetwork.hpp +++ /dev/null @@ -1,31 +0,0 @@ -#ifndef RLTNETWORK__H -#define RLTNETWORK__H - -#include "NeuralNetwork.hpp" - -class RLTNetworkImpl : public NeuralNetworkImpl -{ - private : - - static constexpr long maxNbChilds{8}; - static inline std::vector<long> focusedBufferIndexes{0,1,2}; - static inline std::vector<long> focusedStackIndexes{0,1}; - - int leftBorder, rightBorder; - - torch::nn::Embedding wordEmbeddings{nullptr}; - torch::nn::Linear linear1{nullptr}; - torch::nn::Linear linear2{nullptr}; - torch::nn::LSTM vectorBiLSTM{nullptr}; - torch::nn::LSTM treeLSTM{nullptr}; - torch::Tensor S; - torch::Tensor nullTree; - - public : - - RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements); - torch::Tensor forward(torch::Tensor input) override; - std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override; -}; - -#endif diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp index e715cc4..8f58d7b 100644 --- a/torch_modules/include/RandomNetwork.hpp +++ b/torch_modules/include/RandomNetwork.hpp @@ -13,6 +13,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl RandomNetworkImpl(long outputSize); torch::Tensor forward(torch::Tensor input) override; + std::vector<std::vector<long>> extractContext(Config &, Dict &) const override; }; #endif diff --git a/torch_modules/include/RawInputLSTM.hpp b/torch_modules/include/RawInputLSTM.hpp new file mode 100644 index 0000000..db17d6f --- /dev/null +++ b/torch_modules/include/RawInputLSTM.hpp @@ -0,0 +1,26 @@ +#ifndef RAWINPUTLSTM__H +#define RAWINPUTLSTM__H + +#include <torch/torch.h> +#include "Submodule.hpp" +#include "LSTM.hpp" + +class RawInputLSTMImpl : public torch::nn::Module, public Submodule +{ + private : + + LSTM lstm{nullptr}; + int leftWindow, rightWindow; + + public : + + RawInputLSTMImpl(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options); + torch::Tensor forward(torch::Tensor input); + std::size_t getOutputSize() override; + std::size_t getInputSize() override; + void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; +}; +TORCH_MODULE(RawInputLSTM); + +#endif + diff --git a/torch_modules/include/SplitTransLSTM.hpp b/torch_modules/include/SplitTransLSTM.hpp new file mode 100644 index 0000000..f90c0ed --- /dev/null +++ b/torch_modules/include/SplitTransLSTM.hpp @@ -0,0 +1,26 @@ +#ifndef SPLITTRANSLSTM__H +#define SPLITTRANSLSTM__H + +#include <torch/torch.h> +#include "Submodule.hpp" +#include "LSTM.hpp" + +class SplitTransLSTMImpl : public torch::nn::Module, public Submodule +{ + private : + + LSTM lstm{nullptr}; + int maxNbTrans; + + public : + + SplitTransLSTMImpl(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options); + torch::Tensor forward(torch::Tensor input); + std::size_t getOutputSize() override; + std::size_t getInputSize() override; + void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override; +}; +TORCH_MODULE(SplitTransLSTM); + +#endif + diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp new file mode 100644 index 0000000..437bbfa --- /dev/null +++ b/torch_modules/include/Submodule.hpp @@ -0,0 +1,22 @@ +#ifndef SUBMODULE__H +#define SUBMODULE__H + +#include "Dict.hpp" +#include "Config.hpp" + +class Submodule +{ + protected : + + std::size_t firstInputIndex{0}; + + public : + + void setFirstInputIndex(std::size_t firstInputIndex); + virtual std::size_t getOutputSize() = 0; + virtual std::size_t getInputSize() = 0; + virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const = 0; +}; + +#endif + diff --git a/torch_modules/src/CNN.cpp b/torch_modules/src/CNN.cpp index cc67d60..dbc3797 100644 --- a/torch_modules/src/CNN.cpp +++ b/torch_modules/src/CNN.cpp @@ -26,7 +26,7 @@ torch::Tensor CNNImpl::forward(torch::Tensor input) return cnnOut; } -int CNNImpl::getOutputSize() +std::size_t CNNImpl::getOutputSize() { return windowSizes.size()*nbFilters; } diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp deleted file mode 100644 index c03b347..0000000 --- a/torch_modules/src/CNNNetwork.cpp +++ /dev/null @@ -1,187 +0,0 @@ -#include "CNNNetwork.hpp" - -CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput) -{ - constexpr int embeddingsSize = 64; - constexpr int hiddenSize = 1024; - constexpr int nbFiltersContext = 512; - constexpr int nbFiltersFocused = 64; - - setBufferContext(bufferContext); - setStackContext(stackContext); - setColumns(columns); - setBufferFocused(focusedBufferIndexes); - setStackFocused(focusedStackIndexes); - - rawInputSize = leftWindowRawInput + rightWindowRawInput + 1; - if (leftWindowRawInput < 0 or rightWindowRawInput < 0) - rawInputSize = 0; - else - rawInputCNN = register_module("rawInputCNN", CNN(std::vector<int>{2,3,4}, nbFiltersFocused, embeddingsSize)); - int rawInputCNNOutputSize = rawInputSize == 0 ? 0 : rawInputCNN->getOutputSize(); - - wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize))); - embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3)); - cnnDropout = register_module("cnn_dropout", torch::nn::Dropout(0.3)); - hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3)); - contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize)); - int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize; - for (auto & col : focusedColumns) - { - std::vector<int> windows{2,3,4}; - cnns.emplace_back(register_module(fmt::format("CNN_{}", col), CNN(windows, nbFiltersFocused, embeddingsSize))); - totalCnnOutputSize += cnns.back()->getOutputSize() * (focusedBufferIndexes.size()+focusedStackIndexes.size()); - } - linear1 = register_module("linear1", torch::nn::Linear(totalCnnOutputSize, hiddenSize)); - linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs)); -} - -torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) -{ - if (input.dim() == 1) - input = input.unsqueeze(0); - - auto embeddings = embeddingsDropout(wordEmbeddings(input)); - - auto context = embeddings.narrow(1, rawInputSize, getContextSize()); - context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}); - - auto elementsEmbeddings = embeddings.narrow(1, rawInputSize+context.size(1), input.size(1)-(rawInputSize+context.size(1))); - - std::vector<torch::Tensor> cnnOutputs; - - if (rawInputSize != 0) - { - auto rawLetters = embeddings.narrow(1, 0, leftWindowRawInput+rightWindowRawInput+1); - cnnOutputs.emplace_back(rawInputCNN(rawLetters.unsqueeze(1))); - } - - auto curIndex = 0; - for (unsigned int i = 0; i < focusedColumns.size(); i++) - { - long nbElements = maxNbElements[i]; - for (unsigned int focused = 0; focused < bufferFocused.size()+stackFocused.size(); focused++) - { - auto cnnInput = elementsEmbeddings.narrow(1, curIndex, nbElements).unsqueeze(1); - curIndex += nbElements; - cnnOutputs.emplace_back(cnns[i](cnnInput)); - } - } - - cnnOutputs.emplace_back(contextCNN(context.unsqueeze(1))); - - auto totalInput = cnnDropout(torch::cat(cnnOutputs, 1)); - - return linear2(hiddenDropout(torch::relu(linear1(totalInput)))); -} - -std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const -{ - if (dict.size() >= maxNbEmbeddings) - util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings)); - - std::vector<long> contextIndexes = extractContextIndexes(config); - std::vector<std::vector<long>> context; - context.emplace_back(); - - if (rawInputSize > 0) - { - for (int i = 0; i < leftWindowRawInput; i++) - if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i)) - context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i)))); - else - context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); - - for (int i = 0; i <= rightWindowRawInput; i++) - if (config.hasCharacter(config.getCharacterIndex()+i)) - context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i)))); - else - context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); - } - - for (auto index : contextIndexes) - for (auto & col : columns) - if (index == -1) - for (auto & contextElement : context) - contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); - else - { - int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index)); - - for (auto & contextElement : context) - contextElement.push_back(dictIndex); - - if (is_training()) - if (col == "FORM" || col == "LEMMA") - if (dict.getNbOccs(dictIndex) <= unknownValueThreshold) - { - context.emplace_back(context.back()); - context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr); - } - } - - std::vector<long> focusedIndexes = extractFocusedIndexes(config); - - for (auto & contextElement : context) - for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++) - { - auto & col = focusedColumns[colIndex]; - - for (auto index : focusedIndexes) - { - if (index == -1) - { - for (int i = 0; i < maxNbElements[colIndex]; i++) - contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); - continue; - } - - std::vector<std::string> elements; - if (col == "FORM") - { - auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get()); - - for (int i = 0; i < maxNbElements[colIndex]; i++) - if (i < (int)asUtf8.size()) - elements.emplace_back(fmt::format("{}", asUtf8[i])); - else - elements.emplace_back(Dict::nullValueStr); - } - else if (col == "FEATS") - { - auto splited = util::split(config.getAsFeature(col, index).get(), '|'); - - for (int i = 0; i < maxNbElements[colIndex]; i++) - if (i < (int)splited.size()) - elements.emplace_back(fmt::format("FEATS({})", splited[i])); - else - elements.emplace_back(Dict::nullValueStr); - } - else if (col == "ID") - { - if (config.isTokenPredicted(index)) - elements.emplace_back("ID(TOKEN)"); - else if (config.isMultiwordPredicted(index)) - elements.emplace_back("ID(MULTIWORD)"); - else if (config.isEmptyNodePredicted(index)) - elements.emplace_back("ID(EMPTYNODE)"); - } - else - { - elements.emplace_back(config.getAsFeature(col, index)); - } - - if ((int)elements.size() != maxNbElements[colIndex]) - util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col)); - - for (auto & element : elements) - contextElement.emplace_back(dict.getIndexOrInsert(element)); - } - } - - if (!is_training() && context.size() > 1) - util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size())); - - return context; -} - diff --git a/torch_modules/src/ConcatWordsNetwork.cpp b/torch_modules/src/ConcatWordsNetwork.cpp deleted file mode 100644 index 2331d59..0000000 --- a/torch_modules/src/ConcatWordsNetwork.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include "ConcatWordsNetwork.hpp" - -ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, const std::vector<int> & bufferContext, const std::vector<int> & stackContext) -{ - constexpr int embeddingsSize = 64; - constexpr int hiddenSize = 500; - - setBufferContext(bufferContext); - setStackContext(stackContext); - setColumns({"FORM", "UPOS"}); - - wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize))); - linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, hiddenSize)); - linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs)); - dropout = register_module("dropout", torch::nn::Dropout(0.3)); -} - -torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input) -{ - if (input.dim() == 1) - input = input.unsqueeze(0); - auto wordsAsEmb = dropout(wordEmbeddings(input).view({input.size(0), -1})); - return linear2(torch::relu(linear1(wordsAsEmb))); -} - diff --git a/torch_modules/src/ContextLSTM.cpp b/torch_modules/src/ContextLSTM.cpp new file mode 100644 index 0000000..95daa69 --- /dev/null +++ b/torch_modules/src/ContextLSTM.cpp @@ -0,0 +1,62 @@ +#include "ContextLSTM.hpp" + +ContextLSTMImpl::ContextLSTMImpl(std::vector<std::string> columns, int embeddingsSize, int outEmbeddingsSize, std::vector<int> bufferContext, std::vector<int> stackContext, LSTMImpl::LSTMOptions options, int unknownValueThreshold) : columns(columns), bufferContext(bufferContext), stackContext(stackContext), unknownValueThreshold(unknownValueThreshold) +{ + lstm = register_module("lstm", LSTM(columns.size()*embeddingsSize, outEmbeddingsSize, options)); +} + +std::size_t ContextLSTMImpl::getOutputSize() +{ + return lstm->getOutputSize(bufferContext.size()+stackContext.size()); +} + +std::size_t ContextLSTMImpl::getInputSize() +{ + return columns.size()*(bufferContext.size()+stackContext.size()); +} + +void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const +{ + std::vector<long> contextIndexes; + + for (int index : bufferContext) + contextIndexes.emplace_back(config.getRelativeWordIndex(index)); + + for (int index : stackContext) + if (config.hasStack(index)) + contextIndexes.emplace_back(config.getStack(index)); + else + contextIndexes.emplace_back(-1); + + for (auto index : contextIndexes) + for (auto & col : columns) + if (index == -1) + for (auto & contextElement : context) + contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); + else + { + int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index)); + + for (auto & contextElement : context) + contextElement.push_back(dictIndex); + + if (is_training()) + for (auto & targetCol : unknownValueColumns) + if (col == targetCol) + if (dict.getNbOccs(dictIndex) <= unknownValueThreshold) + { + context.emplace_back(context.back()); + context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr); + } + } +} + +torch::Tensor ContextLSTMImpl::forward(torch::Tensor input) +{ + auto context = input.narrow(1, firstInputIndex, getInputSize()); + + context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*context.size(2)}); + + return lstm(context); +} + diff --git a/torch_modules/src/DepthLayerTreeEmbedding.cpp b/torch_modules/src/DepthLayerTreeEmbedding.cpp new file mode 100644 index 0000000..d53a04a --- /dev/null +++ b/torch_modules/src/DepthLayerTreeEmbedding.cpp @@ -0,0 +1,17 @@ +#include "DepthLayerTreeEmbedding.hpp" + +DepthLayerTreeEmbeddingImpl::DepthLayerTreeEmbeddingImpl(int maxDepth, int maxElemPerDepth) : maxDepth(maxDepth), maxElemPerDepth(maxElemPerDepth) +{ + +} + +torch::Tensor DepthLayerTreeEmbeddingImpl::forward(torch::Tensor input) +{ + +} + +int DepthLayerTreeEmbeddingImpl::getOutputSize() +{ + +} + diff --git a/torch_modules/src/FocusedColumnLSTM.cpp b/torch_modules/src/FocusedColumnLSTM.cpp new file mode 100644 index 0000000..9b5f52f --- /dev/null +++ b/torch_modules/src/FocusedColumnLSTM.cpp @@ -0,0 +1,94 @@ +#include "FocusedColumnLSTM.hpp" + +FocusedColumnLSTMImpl::FocusedColumnLSTMImpl(std::vector<int> focusedBuffer, std::vector<int> focusedStack, std::string column, int maxNbElements, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : focusedBuffer(focusedBuffer), focusedStack(focusedStack), column(column), maxNbElements(maxNbElements) +{ + lstm = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options)); +} + +torch::Tensor FocusedColumnLSTMImpl::forward(torch::Tensor input) +{ + std::vector<torch::Tensor> outputs; + for (unsigned int i = 0; i < focusedBuffer.size()+focusedStack.size(); i++) + outputs.emplace_back(lstm(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements))); + + return torch::cat(outputs, 1); +} + +std::size_t FocusedColumnLSTMImpl::getOutputSize() +{ + return (focusedBuffer.size()+focusedStack.size())*lstm->getOutputSize(maxNbElements); +} + +std::size_t FocusedColumnLSTMImpl::getInputSize() +{ + return (focusedBuffer.size()+focusedStack.size()) * maxNbElements; +} + +void FocusedColumnLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const +{ + std::vector<long> focusedIndexes; + + for (int index : focusedBuffer) + focusedIndexes.emplace_back(config.getRelativeWordIndex(index)); + + for (int index : focusedStack) + if (config.hasStack(index)) + focusedIndexes.emplace_back(config.getStack(index)); + else + focusedIndexes.emplace_back(-1); + + for (auto & contextElement : context) + { + for (auto index : focusedIndexes) + { + if (index == -1) + { + for (int i = 0; i < maxNbElements; i++) + contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); + continue; + } + + std::vector<std::string> elements; + if (column == "FORM") + { + auto asUtf8 = util::splitAsUtf8(config.getAsFeature(column, index).get()); + + for (int i = 0; i < maxNbElements; i++) + if (i < (int)asUtf8.size()) + elements.emplace_back(fmt::format("{}", asUtf8[i])); + else + elements.emplace_back(Dict::nullValueStr); + } + else if (column == "FEATS") + { + auto splited = util::split(config.getAsFeature(column, index).get(), '|'); + + for (int i = 0; i < maxNbElements; i++) + if (i < (int)splited.size()) + elements.emplace_back(fmt::format("FEATS({})", splited[i])); + else + elements.emplace_back(Dict::nullValueStr); + } + else if (column == "ID") + { + if (config.isTokenPredicted(index)) + elements.emplace_back("ID(TOKEN)"); + else if (config.isMultiwordPredicted(index)) + elements.emplace_back("ID(MULTIWORD)"); + else if (config.isEmptyNodePredicted(index)) + elements.emplace_back("ID(EMPTYNODE)"); + } + else + { + elements.emplace_back(config.getAsFeature(column, index)); + } + + if ((int)elements.size() != maxNbElements) + util::myThrow(fmt::format("elements.size ({}) != maxNbElements ({})", elements.size(), maxNbElements)); + + for (auto & element : elements) + contextElement.emplace_back(dict.getIndexOrInsert(element)); + } + } +} + diff --git a/torch_modules/src/LSTM.cpp b/torch_modules/src/LSTM.cpp index 58b102a..b8f8e7f 100644 --- a/torch_modules/src/LSTM.cpp +++ b/torch_modules/src/LSTM.cpp @@ -1,6 +1,6 @@ #include "LSTM.hpp" -LSTMImpl::LSTMImpl(int inputSize, int outputSize, std::tuple<bool,bool,int,float,bool> options) : outputAll(std::get<4>(options)) +LSTMImpl::LSTMImpl(int inputSize, int outputSize, LSTMOptions options) : outputAll(std::get<4>(options)) { auto lstmOptions = torch::nn::LSTMOptions(inputSize, outputSize) .batch_first(std::get<0>(options)) diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp index 268b504..430b226 100644 --- a/torch_modules/src/LSTMNetwork.cpp +++ b/torch_modules/src/LSTMNetwork.cpp @@ -1,6 +1,6 @@ #include "LSTMNetwork.hpp" -LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput) +LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) { constexpr int embeddingsSize = 256; constexpr int hiddenSize = 8192; @@ -8,41 +8,45 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std:: constexpr int focusedLSTMSize = 256; constexpr int rawInputLSTMSize = 32; - std::tuple<bool,bool,int,float,bool> lstmOptions{true,true,2,0.3,false}; + LSTMImpl::LSTMOptions lstmOptions{true,true,2,0.3,false}; auto lstmOptionsAll = lstmOptions; std::get<4>(lstmOptionsAll) = true; - setBufferContext(bufferContext); - setStackContext(stackContext); - setColumns(columns); - setBufferFocused(focusedBufferIndexes); - setStackFocused(focusedStackIndexes); - - rawInputSize = leftWindowRawInput + rightWindowRawInput + 1; - int rawInputLSTMOutSize = 0; - if (leftWindowRawInput < 0 or rightWindowRawInput < 0) - rawInputSize = 0; - else + int currentOutputSize = embeddingsSize; + int currentInputSize = 1; + + contextLSTM = register_module("contextLSTM", ContextLSTM(columns, embeddingsSize, contextLSTMSize, bufferContext, stackContext, lstmOptions, unknownValueThreshold)); + contextLSTM->setFirstInputIndex(currentInputSize); + currentOutputSize += contextLSTM->getOutputSize(); + currentInputSize += contextLSTM->getInputSize(); + + if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0) { - rawInputLSTM = register_module("rawInputLSTM", LSTM(embeddingsSize, rawInputLSTMSize, lstmOptionsAll)); - rawInputLSTMOutSize = rawInputLSTM->getOutputSize(rawInputSize); + hasRawInputLSTM = true; + rawInputLSTM = register_module("rawInputLSTM", RawInputLSTM(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputLSTMSize, lstmOptionsAll)); + rawInputLSTM->setFirstInputIndex(currentInputSize); + currentOutputSize += rawInputLSTM->getOutputSize(); + currentInputSize += rawInputLSTM->getInputSize(); } - wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize))); - embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3)); - hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3)); - contextLSTM = register_module("contextLSTM", LSTM(columns.size()*embeddingsSize, contextLSTMSize, lstmOptions)); - splitTransLSTM = register_module("splitTransLSTM", LSTM(embeddingsSize, embeddingsSize, lstmOptionsAll)); - - int totalLSTMOutputSize = rawInputLSTMOutSize + contextLSTM->getOutputSize(getContextSize()) + splitTransLSTM->getOutputSize(Config::maxNbAppliableSplitTransitions); + splitTransLSTM = register_module("splitTransLSTM", SplitTransLSTM(Config::maxNbAppliableSplitTransitions, embeddingsSize, embeddingsSize, lstmOptionsAll)); + splitTransLSTM->setFirstInputIndex(currentInputSize); + currentOutputSize += splitTransLSTM->getOutputSize(); + currentInputSize += splitTransLSTM->getInputSize(); for (unsigned int i = 0; i < focusedColumns.size(); i++) { - lstms.emplace_back(register_module(fmt::format("LSTM_{}", focusedColumns[i]), LSTM(embeddingsSize, focusedLSTMSize, lstmOptions))); - totalLSTMOutputSize += (bufferFocused.size()+stackFocused.size())*lstms.back()->getOutputSize(maxNbElements[i]); + focusedLstms.emplace_back(register_module(fmt::format("LSTM_{}", focusedColumns[i]), FocusedColumnLSTM(focusedBufferIndexes, focusedStackIndexes, focusedColumns[i], maxNbElements[i], embeddingsSize, focusedLSTMSize, lstmOptions))); + focusedLstms.back()->setFirstInputIndex(currentInputSize); + currentOutputSize += focusedLstms.back()->getOutputSize(); + currentInputSize += focusedLstms.back()->getInputSize(); } - linear1 = register_module("linear1", torch::nn::Linear(embeddingsSize+totalLSTMOutputSize, hiddenSize)); + wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize))); + embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3)); + hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3)); + + linear1 = register_module("linear1", torch::nn::Linear(currentOutputSize, hiddenSize)); linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs)); } @@ -53,40 +57,19 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input) auto embeddings = embeddingsDropout(wordEmbeddings(input)); - auto state = embeddings.narrow(1, 0, 1).squeeze(1); - - auto splitTrans = embeddings.narrow(1, 1, Config::maxNbAppliableSplitTransitions); + std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)}; - auto context = embeddings.narrow(1, 1+splitTrans.size(1)+rawInputSize, getContextSize()); + outputs.emplace_back(contextLSTM(embeddings)); - context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}); + if (hasRawInputLSTM) + outputs.emplace_back(rawInputLSTM(embeddings)); - auto elementsEmbeddings = embeddings.narrow(1, 1+splitTrans.size(1)+rawInputSize+context.size(1), input.size(1)-(1+splitTrans.size(1)+rawInputSize+context.size(1))); + outputs.emplace_back(splitTransLSTM(embeddings)); - std::vector<torch::Tensor> lstmOutputs; + for (auto & lstm : focusedLstms) + outputs.emplace_back(lstm(embeddings)); - lstmOutputs.emplace_back(state); - - if (rawInputSize != 0) - { - auto rawLetters = embeddings.narrow(1, splitTrans.size(1), rawInputSize); - lstmOutputs.emplace_back(rawInputLSTM(rawLetters)); - } - - lstmOutputs.emplace_back(splitTransLSTM(splitTrans)); - - auto curIndex = 0; - for (unsigned int i = 0; i < focusedColumns.size(); i++) - for (unsigned int focused = 0; focused < bufferFocused.size()+stackFocused.size(); focused++) - { - auto lstmInput = elementsEmbeddings.narrow(1, curIndex, maxNbElements[i]); - curIndex += maxNbElements[i]; - lstmOutputs.emplace_back(lstms[i](lstmInput)); - } - - lstmOutputs.emplace_back(contextLSTM(context)); - - auto totalInput = torch::cat(lstmOutputs, 1); + auto totalInput = torch::cat(outputs, 1); return linear2(hiddenDropout(torch::relu(linear1(totalInput)))); } @@ -101,13 +84,12 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, context.back().emplace_back(dict.getIndexOrInsert(config.getState())); - addAppliableSplitTransitions(context, dict, config); - - addRawInput(context, dict, config, leftWindowRawInput, rightWindowRawInput); - - addContext(context, dict, config, extractContextIndexes(config), unknownValueThreshold, {"FORM","LEMMA"}); - - addFocused(context, dict, config, extractFocusedIndexes(config), focusedColumns, maxNbElements); + contextLSTM->addToContext(context, dict, config); + if (hasRawInputLSTM) + rawInputLSTM->addToContext(context, dict, config); + splitTransLSTM->addToContext(context, dict, config); + for (auto & lstm : focusedLstms) + lstm->addToContext(context, dict, config); if (!is_training() && context.size() > 1) util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size())); diff --git a/torch_modules/src/MLP.cpp b/torch_modules/src/MLP.cpp deleted file mode 100644 index 182e880..0000000 --- a/torch_modules/src/MLP.cpp +++ /dev/null @@ -1,8 +0,0 @@ -#include "MLP.hpp" -#include <regex> - -MLPImpl::MLPImpl(const std::string & topology) -{ - -} - diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index 4a123b2..02e8a19 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -1,195 +1,4 @@ #include "NeuralNetwork.hpp" -#include "Transition.hpp" torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU); -std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config) const -{ - std::vector<long> context; - - for (int index : bufferContext) - context.emplace_back(config.getRelativeWordIndex(index)); - - for (int index : stackContext) - if (config.hasStack(index)) - context.emplace_back(config.getStack(index)); - else - context.emplace_back(-1); - - return context; -} - -std::vector<long> NeuralNetworkImpl::extractFocusedIndexes(const Config & config) const -{ - std::vector<long> context; - - for (int index : bufferFocused) - context.emplace_back(config.getRelativeWordIndex(index)); - - for (int index : stackFocused) - if (config.hasStack(index)) - context.emplace_back(config.getStack(index)); - else - context.emplace_back(-1); - - return context; -} - -std::vector<std::vector<long>> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const -{ - std::vector<long> indexes = extractContextIndexes(config); - std::vector<long> context; - - for (auto & col : columns) - for (auto index : indexes) - if (index == -1) - context.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); - else - context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, index))); - - return {context}; -} - -int NeuralNetworkImpl::getContextSize() const -{ - return columns.size()*(bufferContext.size()+stackContext.size()); -} - -void NeuralNetworkImpl::setBufferContext(const std::vector<int> & bufferContext) -{ - this->bufferContext = bufferContext; -} - -void NeuralNetworkImpl::setStackContext(const std::vector<int> & stackContext) -{ - this->stackContext = stackContext; -} - -void NeuralNetworkImpl::setBufferFocused(const std::vector<int> & bufferFocused) -{ - this->bufferFocused = bufferFocused; -} - -void NeuralNetworkImpl::setStackFocused(const std::vector<int> & stackFocused) -{ - this->stackFocused = stackFocused; -} - -void NeuralNetworkImpl::setColumns(const std::vector<std::string> & columns) -{ - this->columns = columns; -} - -void NeuralNetworkImpl::addAppliableSplitTransitions(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const -{ - auto & splitTransitions = config.getAppliableSplitTransitions(); - for (int i = 0; i < Config::maxNbAppliableSplitTransitions; i++) - if (i < (int)splitTransitions.size()) - context.back().emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName())); - else - context.back().emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); -} - -void NeuralNetworkImpl::addRawInput(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, int leftWindowRawInput, int rightWindowRawInput) const -{ - if (leftWindowRawInput < 0 or rightWindowRawInput < 0) - return; - - for (int i = 0; i < leftWindowRawInput; i++) - if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i)) - context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i)))); - else - context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); - - for (int i = 0; i <= rightWindowRawInput; i++) - if (config.hasCharacter(config.getCharacterIndex()+i)) - context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i)))); - else - context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); -} - -void NeuralNetworkImpl::addContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, const std::vector<long> & contextIndexes, int unknownValueThreshold, const std::vector<std::string> & unknownValueColumns) const -{ - for (auto index : contextIndexes) - for (auto & col : columns) - if (index == -1) - for (auto & contextElement : context) - contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); - else - { - int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index)); - - for (auto & contextElement : context) - contextElement.push_back(dictIndex); - - if (is_training()) - for (auto & targetCol : unknownValueColumns) - if (col == targetCol) - if (dict.getNbOccs(dictIndex) <= unknownValueThreshold) - { - context.emplace_back(context.back()); - context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr); - } - } -} - -void NeuralNetworkImpl::addFocused(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, const std::vector<long> & focusedIndexes, const std::vector<std::string> & focusedColumns, const std::vector<int> & maxNbElements) const -{ - for (auto & contextElement : context) - for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++) - { - auto & col = focusedColumns[colIndex]; - - for (auto index : focusedIndexes) - { - if (index == -1) - { - for (int i = 0; i < maxNbElements[colIndex]; i++) - contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); - continue; - } - - std::vector<std::string> elements; - if (col == "FORM") - { - auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get()); - - for (int i = 0; i < maxNbElements[colIndex]; i++) - if (i < (int)asUtf8.size()) - elements.emplace_back(fmt::format("{}", asUtf8[i])); - else - elements.emplace_back(Dict::nullValueStr); - } - else if (col == "FEATS") - { - auto splited = util::split(config.getAsFeature(col, index).get(), '|'); - - for (int i = 0; i < maxNbElements[colIndex]; i++) - if (i < (int)splited.size()) - elements.emplace_back(fmt::format("FEATS({})", splited[i])); - else - elements.emplace_back(Dict::nullValueStr); - } - else if (col == "ID") - { - if (config.isTokenPredicted(index)) - elements.emplace_back("ID(TOKEN)"); - else if (config.isMultiwordPredicted(index)) - elements.emplace_back("ID(MULTIWORD)"); - else if (config.isEmptyNodePredicted(index)) - elements.emplace_back("ID(EMPTYNODE)"); - } - else - { - elements.emplace_back(config.getAsFeature(col, index)); - } - - if ((int)elements.size() != maxNbElements[colIndex]) - util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col)); - - for (auto & element : elements) - contextElement.emplace_back(dict.getIndexOrInsert(element)); - } - } -} - diff --git a/torch_modules/src/RLTNetwork.cpp b/torch_modules/src/RLTNetwork.cpp deleted file mode 100644 index e4f3fc2..0000000 --- a/torch_modules/src/RLTNetwork.cpp +++ /dev/null @@ -1,190 +0,0 @@ -#include "RLTNetwork.hpp" - -RLTNetworkImpl::RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements) -{ - constexpr int embeddingsSize = 30; - constexpr int lstmOutputSize = 128; - constexpr int treeEmbeddingsSize = 256; - constexpr int hiddenSize = 500; - - //TODO gerer ces context - this->leftBorder = leftBorder; - this->rightBorder = rightBorder; - setBufferContext({}); - setStackContext({}); - setColumns({"FORM", "UPOS"}); - - wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize))); - linear1 = register_module("linear1", torch::nn::Linear(treeEmbeddingsSize*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize)); - linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs)); - vectorBiLSTM = register_module("vector_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize*columns.size(), lstmOutputSize).batch_first(true).bidirectional(true))); - treeLSTM = register_module("tree_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(treeEmbeddingsSize+2*lstmOutputSize, treeEmbeddingsSize).batch_first(true).bidirectional(false))); - S = register_parameter("S", torch::randn(treeEmbeddingsSize)); - nullTree = register_parameter("null_tree", torch::randn(treeEmbeddingsSize)); -} - -torch::Tensor RLTNetworkImpl::forward(torch::Tensor input) -{ - if (input.dim() == 1) - input = input.unsqueeze(0); - - auto focusedIndexes = input.narrow(1, 0, focusedBufferIndexes.size()+focusedStackIndexes.size()); - auto computeOrder = input.narrow(1, focusedIndexes.size(1), getContextSize()/columns.size()); - auto childsFlat = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1), maxNbChilds*(getContextSize()/columns.size())); - auto childs = torch::reshape(childsFlat, {childsFlat.size(0), computeOrder.size(1), maxNbChilds}); - auto wordIndexes = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1)+childsFlat.size(1), getContextSize()); - auto baseEmbeddings = wordEmbeddings(wordIndexes); - auto concatBaseEmbeddings = torch::reshape(baseEmbeddings, {baseEmbeddings.size(0), (int)baseEmbeddings.size(1)/(int)columns.size(), (int)baseEmbeddings.size(2)*(int)columns.size()}); - auto vectorRepresentations = vectorBiLSTM(concatBaseEmbeddings).output; - - std::vector<std::map<int, torch::Tensor>> treeRepresentations; - for (unsigned int batch = 0; batch < computeOrder.size(0); batch++) - { - treeRepresentations.emplace_back(); - for (unsigned int i = 0; i < computeOrder[batch].size(0); i++) - { - int index = computeOrder[batch][i].item<int>(); - if (index == -1) - break; - std::vector<torch::Tensor> inputVector; - inputVector.emplace_back(torch::cat({vectorRepresentations[batch][index], S}, 0)); - for (unsigned int childIndex = 0; childIndex < maxNbChilds; childIndex++) - { - int child = childs[batch][index][childIndex].item<int>(); - if (child == -1) - break; - inputVector.emplace_back(torch::cat({vectorRepresentations[batch][index], treeRepresentations[batch].count(child) ? treeRepresentations[batch][child] : nullTree}, 0)); - } - auto lstmInput = torch::stack(inputVector, 0).unsqueeze(0); - auto lstmOut = treeLSTM(lstmInput).output.permute({1,0,2})[-1].squeeze(); - treeRepresentations[batch][index] = lstmOut; - } - } - - std::vector<torch::Tensor> focusedTrees; - std::vector<torch::Tensor> representations; - for (unsigned int batch = 0; batch < focusedIndexes.size(0); batch++) - { - focusedTrees.clear(); - for (unsigned int i = 0; i < focusedIndexes[batch].size(0); i++) - { - int index = focusedIndexes[batch][i].item<int>(); - if (index == -1) - focusedTrees.emplace_back(nullTree); - else - focusedTrees.emplace_back(treeRepresentations[batch].count(index) ? treeRepresentations[batch][index] : nullTree); - } - representations.emplace_back(torch::cat(focusedTrees, 0).unsqueeze(0)); - } - - auto representation = torch::cat(representations, 0); - return linear2(torch::relu(linear1(representation))); -} - -std::vector<std::vector<long>> RLTNetworkImpl::extractContext(Config & config, Dict & dict) const -{ - std::vector<long> contextIndexes; - std::stack<int> leftContext; - for (int index = config.getWordIndex()-1; config.has(0,index,0) && (int)leftContext.size() < leftBorder; --index) - if (config.isToken(index)) - leftContext.push(index); - - while ((int)contextIndexes.size() < leftBorder-(int)leftContext.size()) - contextIndexes.emplace_back(-1); - while (!leftContext.empty()) - { - contextIndexes.emplace_back(leftContext.top()); - leftContext.pop(); - } - - for (int index = config.getWordIndex(); config.has(0,index,0) && (int)contextIndexes.size() < leftBorder+rightBorder+1; ++index) - if (config.isToken(index)) - contextIndexes.emplace_back(index); - - while ((int)contextIndexes.size() < leftBorder+rightBorder+1) - contextIndexes.emplace_back(-1); - - std::map<long, long> indexInContext; - for (auto & l : contextIndexes) - indexInContext.emplace(std::make_pair(l, indexInContext.size())); - - std::vector<long> headOf; - for (auto & l : contextIndexes) - { - if (l == -1) - headOf.push_back(-1); - else - { - auto & head = config.getAsFeature(Config::headColName, l); - if (util::isEmpty(head) or head == "_") - headOf.push_back(-1); - else if (indexInContext.count(std::stoi(head))) - headOf.push_back(std::stoi(head)); - else - headOf.push_back(-1); - } - } - - std::vector<std::vector<long>> childs(headOf.size()); - for (unsigned int i = 0; i < headOf.size(); i++) - if (headOf[i] != -1) - childs[indexInContext[headOf[i]]].push_back(contextIndexes[i]); - - std::vector<long> treeComputationOrder; - std::vector<bool> treeIsComputed(contextIndexes.size(), false); - - std::function<void(long)> depthFirst; - depthFirst = [&config, &depthFirst, &indexInContext, &treeComputationOrder, &treeIsComputed, &childs](long root) - { - if (!indexInContext.count(root)) - return; - - if (treeIsComputed[indexInContext[root]]) - return; - - for (auto child : childs[indexInContext[root]]) - depthFirst(child); - - treeIsComputed[indexInContext[root]] = true; - treeComputationOrder.push_back(indexInContext[root]); - }; - - for (auto & l : focusedBufferIndexes) - if (contextIndexes[leftBorder+l] != -1) - depthFirst(contextIndexes[leftBorder+l]); - - for (auto & l : focusedStackIndexes) - if (config.hasStack(l)) - depthFirst(config.getStack(l)); - - std::vector<long> context; - - for (auto & c : focusedBufferIndexes) - context.push_back(leftBorder+c); - for (auto & c : focusedStackIndexes) - if (config.hasStack(c) && indexInContext.count(config.getStack(c))) - context.push_back(indexInContext[config.getStack(c)]); - else - context.push_back(-1); - for (auto & c : treeComputationOrder) - context.push_back(c); - while (context.size() < contextIndexes.size()+focusedBufferIndexes.size()+focusedStackIndexes.size()) - context.push_back(-1); - for (auto & c : childs) - { - for (unsigned int i = 0; i < maxNbChilds; i++) - if (i < c.size()) - context.push_back(indexInContext[c[i]]); - else - context.push_back(-1); - } - for (auto & l : contextIndexes) - for (auto & col : columns) - if (l == -1) - context.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); - else - context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, l))); - - return {context}; -} - diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp index 8e973e4..9730dd2 100644 --- a/torch_modules/src/RandomNetwork.cpp +++ b/torch_modules/src/RandomNetwork.cpp @@ -2,11 +2,6 @@ RandomNetworkImpl::RandomNetworkImpl(long outputSize) : outputSize(outputSize) { - setBufferContext({0}); - setStackContext({}); - setBufferFocused({}); - setStackFocused({}); - setColumns({"FORM"}); } torch::Tensor RandomNetworkImpl::forward(torch::Tensor input) @@ -17,3 +12,8 @@ torch::Tensor RandomNetworkImpl::forward(torch::Tensor input) return torch::randn({input.size(0), outputSize}, torch::TensorOptions().device(device).requires_grad(true)); } +std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &, Dict &) const +{ + return std::vector<std::vector<long>>(); +} + diff --git a/torch_modules/src/RawInputLSTM.cpp b/torch_modules/src/RawInputLSTM.cpp new file mode 100644 index 0000000..ebcfbfd --- /dev/null +++ b/torch_modules/src/RawInputLSTM.cpp @@ -0,0 +1,40 @@ +#include "RawInputLSTM.hpp" + +RawInputLSTMImpl::RawInputLSTMImpl(int leftWindow, int rightWindow, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : leftWindow(leftWindow), rightWindow(rightWindow) +{ + lstm = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options)); +} + +torch::Tensor RawInputLSTMImpl::forward(torch::Tensor input) +{ + return lstm(input.narrow(1, firstInputIndex, getInputSize())); +} + +std::size_t RawInputLSTMImpl::getOutputSize() +{ + return lstm->getOutputSize(leftWindow + rightWindow + 1); +} + +std::size_t RawInputLSTMImpl::getInputSize() +{ + return leftWindow + rightWindow + 1; +} + +void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const +{ + if (leftWindow < 0 or rightWindow < 0) + return; + + for (int i = 0; i < leftWindow; i++) + if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i)) + context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i)))); + else + context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); + + for (int i = 0; i <= rightWindow; i++) + if (config.hasCharacter(config.getCharacterIndex()+i)) + context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i)))); + else + context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr)); +} + diff --git a/torch_modules/src/SplitTransLSTM.cpp b/torch_modules/src/SplitTransLSTM.cpp new file mode 100644 index 0000000..283358c --- /dev/null +++ b/torch_modules/src/SplitTransLSTM.cpp @@ -0,0 +1,33 @@ +#include "SplitTransLSTM.hpp" +#include "Transition.hpp" + +SplitTransLSTMImpl::SplitTransLSTMImpl(int maxNbTrans, int embeddingsSize, int outEmbeddingsSize, LSTMImpl::LSTMOptions options) : maxNbTrans(maxNbTrans) +{ + lstm = register_module("lstm", LSTM(embeddingsSize, outEmbeddingsSize, options)); +} + +torch::Tensor SplitTransLSTMImpl::forward(torch::Tensor input) +{ + return lstm(input.narrow(1, firstInputIndex, getInputSize())); +} + +std::size_t SplitTransLSTMImpl::getOutputSize() +{ + return lstm->getOutputSize(maxNbTrans); +} + +std::size_t SplitTransLSTMImpl::getInputSize() +{ + return maxNbTrans; +} + +void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const +{ + auto & splitTransitions = config.getAppliableSplitTransitions(); + for (int i = 0; i < maxNbTrans; i++) + if (i < (int)splitTransitions.size()) + context.back().emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName())); + else + context.back().emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); +} + diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp new file mode 100644 index 0000000..2af75a3 --- /dev/null +++ b/torch_modules/src/Submodule.cpp @@ -0,0 +1,7 @@ +#include "Submodule.hpp" + +void Submodule::setFirstInputIndex(std::size_t firstInputIndex) +{ + this->firstInputIndex = firstInputIndex; +} + -- GitLab