From e4e9d906b38449c2d0825d3fb07b4b2c603fda3a Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Fri, 17 Apr 2020 11:16:33 +0200 Subject: [PATCH] Added option to chose if embeddings dropout is 2d or not --- reading_machine/src/Classifier.cpp | 11 +++++++++-- torch_modules/include/LSTMNetwork.hpp | 6 ++---- torch_modules/src/LSTMNetwork.cpp | 23 ++++++++++++++--------- 3 files changed, 25 insertions(+), 15 deletions(-) diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 89c7bb3..b3e0710 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -95,7 +95,7 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size std::vector<std::pair<int, float>> mlp; int rawInputLeftWindow, rawInputRightWindow; int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, treeEmbeddingSize; - bool bilstm; + bool bilstm, drop2d; float lstmDropout, embeddingsDropout, totalInputDropout; if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Unknown value threshold :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&unknownValueThreshold](auto sm) @@ -254,6 +254,13 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size })) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Embeddings dropout :) value")); + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Dropout 2d :|)(?:(?:\\s|\\t)*)(true|false)"), definition[curIndex], [&curIndex,&drop2d](auto sm) + { + drop2d = sm.str(1) == "true"; + curIndex++; + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Dropout 2d :) true|false")); + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Tree embedding columns :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [&curIndex,&treeEmbeddingColumns](auto sm) { treeEmbeddingColumns = util::split(sm.str(1), ' '); @@ -292,7 +299,7 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size })) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding size :) value")); - this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout)); + this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout, drop2d)); } void Classifier::loadOptimizer(std::filesystem::path path) diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp index 550fae1..e742b97 100644 --- a/torch_modules/include/LSTMNetwork.hpp +++ b/torch_modules/include/LSTMNetwork.hpp @@ -14,6 +14,7 @@ class LSTMNetworkImpl : public NeuralNetworkImpl private : torch::nn::Embedding wordEmbeddings{nullptr}; + torch::nn::Dropout2d embeddingsDropout2d{nullptr}; torch::nn::Dropout embeddingsDropout{nullptr}; torch::nn::Dropout inputDropout{nullptr}; @@ -24,12 +25,9 @@ class LSTMNetworkImpl : public NeuralNetworkImpl DepthLayerTreeEmbedding treeEmbedding{nullptr}; std::vector<FocusedColumnLSTM> focusedLstms; - bool hasRawInputLSTM{false}; - bool hasTreeEmbedding{false}; - public : - LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout); + LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d); torch::Tensor forward(torch::Tensor input) override; std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override; }; diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp index 98cbb08..cfa004e 100644 --- a/torch_modules/src/LSTMNetwork.cpp +++ b/torch_modules/src/LSTMNetwork.cpp @@ -1,6 +1,6 @@ #include "LSTMNetwork.hpp" -LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout) +LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout, bool drop2d) { LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false}; auto lstmOptionsAll = lstmOptions; @@ -16,7 +16,6 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std:: if (leftWindowRawInput >= 0 and rightWindowRawInput >= 0) { - hasRawInputLSTM = true; rawInputLSTM = register_module("rawInputLSTM", RawInputLSTM(leftWindowRawInput, rightWindowRawInput, embeddingsSize, rawInputLSTMSize, lstmOptionsAll)); rawInputLSTM->setFirstInputIndex(currentInputSize); currentOutputSize += rawInputLSTM->getOutputSize(); @@ -25,7 +24,6 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std:: if (!treeEmbeddingColumns.empty()) { - hasTreeEmbedding = true; treeEmbedding = register_module("treeEmbedding", DepthLayerTreeEmbedding(treeEmbeddingNbElems,embeddingsSize,treeEmbeddingSize,treeEmbeddingColumns,treeEmbeddingBuffer,treeEmbeddingStack,lstmOptions)); treeEmbedding->setFirstInputIndex(currentInputSize); currentOutputSize += treeEmbedding->getOutputSize(); @@ -46,7 +44,10 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std:: } wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize))); - embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue)); + if (drop2d) + embeddingsDropout2d = register_module("embeddings_dropout2d", torch::nn::Dropout2d(embeddingsDropoutValue)); + else + embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue)); inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout)); mlp = register_module("mlp", MLP(currentOutputSize, nbOutputs, mlpParams)); @@ -57,16 +58,20 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input) if (input.dim() == 1) input = input.unsqueeze(0); - auto embeddings = embeddingsDropout(wordEmbeddings(input)); + auto embeddings = wordEmbeddings(input); + if (embeddingsDropout2d.is_empty()) + embeddings = embeddingsDropout(embeddings); + else + embeddings = embeddingsDropout2d(embeddings); std::vector<torch::Tensor> outputs{embeddings.narrow(1,0,1).squeeze(1)}; outputs.emplace_back(contextLSTM(embeddings)); - if (hasRawInputLSTM) + if (!rawInputLSTM.is_empty()) outputs.emplace_back(rawInputLSTM(embeddings)); - if (hasTreeEmbedding) + if (!treeEmbedding.is_empty()) outputs.emplace_back(treeEmbedding(embeddings)); outputs.emplace_back(splitTransLSTM(embeddings)); @@ -91,10 +96,10 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, contextLSTM->addToContext(context, dict, config); - if (hasRawInputLSTM) + if (!rawInputLSTM.is_empty()) rawInputLSTM->addToContext(context, dict, config); - if (hasTreeEmbedding) + if (!treeEmbedding.is_empty()) treeEmbedding->addToContext(context, dict, config); splitTransLSTM->addToContext(context, dict, config); -- GitLab