From 70896a64b530e3ffcacd39f1f94baf0da0153b31 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 17 Apr 2020 10:32:30 +0200
Subject: [PATCH] Added total input dropout to LSTMNetwork

---
 reading_machine/src/Classifier.cpp    | 11 +++++++++--
 torch_modules/include/LSTMNetwork.hpp |  3 ++-
 torch_modules/src/LSTMNetwork.cpp     |  5 +++--
 3 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 957e21f..89c7bb3 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -96,7 +96,7 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
   int rawInputLeftWindow, rawInputRightWindow;
   int embeddingsSize, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, treeEmbeddingSize;
   bool bilstm;
-  float lstmDropout, embeddingsDropout;
+  float lstmDropout, embeddingsDropout, totalInputDropout;
 
   if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Unknown value threshold :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&unknownValueThreshold](auto sm)
         {
@@ -240,6 +240,13 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
         }))
     util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(LSTM dropout :) value"));
 
+  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Total input dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&totalInputDropout](auto sm)
+        {
+          totalInputDropout = std::stof(sm.str(1));
+          curIndex++;
+        }))
+    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Total input dropout :) value"));
+
   if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Embeddings dropout :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&embeddingsDropout](auto sm)
         {
           embeddingsDropout = std::stof(sm.str(1));
@@ -285,7 +292,7 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
         }))
     util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Tree embedding size :) value"));
 
-  this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout));
+  this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout, totalInputDropout));
 }
 
 void Classifier::loadOptimizer(std::filesystem::path path)
diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp
index 4ff3b17..550fae1 100644
--- a/torch_modules/include/LSTMNetwork.hpp
+++ b/torch_modules/include/LSTMNetwork.hpp
@@ -15,6 +15,7 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
 
   torch::nn::Embedding wordEmbeddings{nullptr};
   torch::nn::Dropout embeddingsDropout{nullptr};
+  torch::nn::Dropout inputDropout{nullptr};
 
   MLP mlp{nullptr};
   ContextLSTM contextLSTM{nullptr};
@@ -28,7 +29,7 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
 
   public :
 
-  LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue);
+  LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout);
   torch::Tensor forward(torch::Tensor input) override;
   std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
 };
diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp
index ae8ed04..98cbb08 100644
--- a/torch_modules/src/LSTMNetwork.cpp
+++ b/torch_modules/src/LSTMNetwork.cpp
@@ -1,6 +1,6 @@
 #include "LSTMNetwork.hpp"
 
-LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue)
+LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput, int embeddingsSize, std::vector<std::pair<int, float>> mlpParams, int contextLSTMSize, int focusedLSTMSize, int rawInputLSTMSize, int splitTransLSTMSize, int numLayers, bool bilstm, float lstmDropout, std::vector<std::string> treeEmbeddingColumns, std::vector<int> treeEmbeddingBuffer, std::vector<int> treeEmbeddingStack, std::vector<int> treeEmbeddingNbElems, int treeEmbeddingSize, float embeddingsDropoutValue, float totalInputDropout)
 {
   LSTMImpl::LSTMOptions lstmOptions{true,bilstm,numLayers,lstmDropout,false};
   auto lstmOptionsAll = lstmOptions;
@@ -47,6 +47,7 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::
 
   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
   embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(embeddingsDropoutValue));
+  inputDropout = register_module("input_dropout", torch::nn::Dropout(totalInputDropout));
 
   mlp = register_module("mlp", MLP(currentOutputSize, nbOutputs, mlpParams));
 }
@@ -73,7 +74,7 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
   for (auto & lstm : focusedLstms)
     outputs.emplace_back(lstm(embeddings));
 
-  auto totalInput = torch::cat(outputs, 1);
+  auto totalInput = inputDropout(torch::cat(outputs, 1));
 
   return mlp(totalInput);
 }
-- 
GitLab