From 2261c98b655864e2f475ce65e6cd2542f5d1aeef Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 15 Apr 2020 20:02:36 +0200
Subject: [PATCH] Fix bug in RawInputLSTM where only the last context was
 extended. Apply dropout after ReLU in MLP

---
 torch_modules/include/LSTMNetwork.hpp |  1 -
 torch_modules/src/MLP.cpp             |  2 +-
 torch_modules/src/RawInputLSTM.cpp    | 25 ++++++++++++++-----------
 3 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp
index f0b58dc..7a6576c 100644
--- a/torch_modules/include/LSTMNetwork.hpp
+++ b/torch_modules/include/LSTMNetwork.hpp
@@ -15,7 +15,6 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
 
   torch::nn::Embedding wordEmbeddings{nullptr};
   torch::nn::Dropout embeddingsDropout{nullptr};
-  torch::nn::Dropout lstmDropout{nullptr};
 
   MLP mlp{nullptr};
   ContextLSTM contextLSTM{nullptr};
diff --git a/torch_modules/src/MLP.cpp b/torch_modules/src/MLP.cpp
index 52a4c23..b886245 100644
--- a/torch_modules/src/MLP.cpp
+++ b/torch_modules/src/MLP.cpp
@@ -19,7 +19,7 @@ torch::Tensor MLPImpl::forward(torch::Tensor input)
 {
   torch::Tensor output = input;
   for (unsigned int i = 0; i < layers.size()-1; i++)
-    output = torch::relu(dropouts[i](layers[i](output)));
+    output = dropouts[i](torch::relu(layers[i](output)));
 
   return layers.back()(output);
 }
diff --git a/torch_modules/src/RawInputLSTM.cpp b/torch_modules/src/RawInputLSTM.cpp
index ebcfbfd..2aa8cfd 100644
--- a/torch_modules/src/RawInputLSTM.cpp
+++ b/torch_modules/src/RawInputLSTM.cpp
@@ -25,16 +25,19 @@ void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Di
   if (leftWindow < 0 or rightWindow < 0)
     return;
 
-  for (int i = 0; i < leftWindow; i++)
-    if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i))
-      context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i))));
-    else
-      context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
-
-  for (int i = 0; i <= rightWindow; i++)
-    if (config.hasCharacter(config.getCharacterIndex()+i))
-      context.back().push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
-    else
-      context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+  for (auto & contextElement : context)
+  {
+    for (int i = 0; i < leftWindow; i++)
+      if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i))
+        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i))));
+      else
+        contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+
+    for (int i = 0; i <= rightWindow; i++)
+      if (config.hasCharacter(config.getCharacterIndex()+i))
+        contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i))));
+      else
+        contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+  }
 }
 
-- 
GitLab