From a4af59b669aa37dc09b5e5999dcb1760ce14ff25 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 20 Mar 2020 16:00:48 +0100
Subject: [PATCH] Removed permute from LSTM

---
 reading_machine/src/Classifier.cpp |  4 ++--
 torch_modules/src/LSTMNetwork.cpp  | 33 +++++++++++++++---------------
 2 files changed, 18 insertions(+), 19 deletions(-)

diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 03a4f5d..7c96c7f 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -49,7 +49,7 @@ void Classifier::initNeuralNetwork(const std::string & topology)
       }
     },
     {
-      std::regex("CNN\\((\\d+),(\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
+      std::regex("CNN\\(([+\\-]?\\d+),(\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
       "CNN(unknownValueThreshold,leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
       [this,topology](auto sm)
       {
@@ -71,7 +71,7 @@ void Classifier::initNeuralNetwork(const std::string & topology)
       }
     },
     {
-      std::regex("LSTM\\((\\d+),(\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
+      std::regex("LSTM\\(([+\\-]?\\d+),(\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
       "LSTM(unknownValueThreshold,leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
       [this,topology](auto sm)
       {
diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp
index 734770e..84a36f2 100644
--- a/torch_modules/src/LSTMNetwork.cpp
+++ b/torch_modules/src/LSTMNetwork.cpp
@@ -6,6 +6,7 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, int l
   constexpr int hiddenSize = 1024;
   constexpr int contextLSTMSize = 512;
   constexpr int focusedLSTMSize = 64;
+  constexpr int rawInputLSTMSize = 16;
 
   setLeftBorder(leftBorder);
   setRightBorder(rightBorder);
@@ -16,21 +17,23 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, int l
   if (leftWindowRawInput < 0 or rightWindowRawInput < 0)
     rawInputSize = 0;
   else
-    rawInputLSTM = register_module("rawInputLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, focusedLSTMSize).batch_first(false).bidirectional(true)));
+    rawInputLSTM = register_module("rawInputLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, rawInputLSTMSize).batch_first(true).bidirectional(true)));
 
-  int rawInputLSTMOutputSize = rawInputSize == 0 ? 0 : (rawInputLSTM->options.hidden_size() * (rawInputLSTM->options.bidirectional() ? 4 : 1));
+  int rawInputLSTMOutputSize = 0;
+  if (rawInputSize > 0)
+    rawInputLSTMOutputSize = (rawInputSize * rawInputLSTM->options.hidden_size() * (rawInputLSTM->options.bidirectional() ? 2 : 1));
 
   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
   embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
   lstmDropout = register_module("lstm_dropout", torch::nn::Dropout(0.3));
   hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));
-  contextLSTM = register_module("contextLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(columns.size()*embeddingsSize, contextLSTMSize).batch_first(false).bidirectional(true)));
+  contextLSTM = register_module("contextLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(columns.size()*embeddingsSize, contextLSTMSize).batch_first(true).bidirectional(true)));
 
   int totalLSTMOutputSize = contextLSTM->options.hidden_size() * (contextLSTM->options.bidirectional() ? 4 : 1) + rawInputLSTMOutputSize;
 
   for (auto & col : focusedColumns)
   {
-    lstms.emplace_back(register_module(fmt::format("LSTM_{}", col), torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, focusedLSTMSize).batch_first(false).bidirectional(true))));
+    lstms.emplace_back(register_module(fmt::format("LSTM_{}", col), torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, focusedLSTMSize).batch_first(true).bidirectional(true))));
     totalLSTMOutputSize += lstms.back()->options.hidden_size() * (lstms.back()->options.bidirectional() ? 4 : 1) * (focusedBufferIndexes.size()+focusedStackIndexes.size());
   }
 
@@ -46,22 +49,18 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
   auto embeddings = embeddingsDropout(wordEmbeddings(input));
 
   auto context = embeddings.narrow(1, rawInputSize, columns.size()*(1+leftBorder+rightBorder));
+
   context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()});
 
   auto elementsEmbeddings = embeddings.narrow(1, rawInputSize+context.size(1), input.size(1)-(rawInputSize+context.size(1)));
 
-  context = context.permute({1,0,2});
-
   std::vector<torch::Tensor> lstmOutputs;
 
   if (rawInputSize != 0)
   {
-    auto rawLetters = embeddings.narrow(1, 0, leftWindowRawInput+rightWindowRawInput+1).permute({1,0});
+    auto rawLetters = embeddings.narrow(1, 0, rawInputSize);
     auto lstmOut = rawInputLSTM(rawLetters).output;
-    if (rawInputLSTM->options.bidirectional())
-      lstmOutputs.emplace_back(torch::cat({lstmOut[0],lstmOut[-1]}, 1));
-    else
-      lstmOutputs.emplace_back(lstmOut[-1]);
+    lstmOutputs.emplace_back(lstmOut.reshape({lstmOut.size(0), -1}));
   }
 
   auto curIndex = 0;
@@ -70,22 +69,22 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
     long nbElements = maxNbElements[i];
     for (unsigned int focused = 0; focused < focusedBufferIndexes.size()+focusedStackIndexes.size(); focused++)
     {
-      auto lstmInput = elementsEmbeddings.narrow(1, curIndex, nbElements).permute({1,0,2});
+      auto lstmInput = elementsEmbeddings.narrow(1, curIndex, nbElements);
       curIndex += nbElements;
       auto lstmOut = lstms[i](lstmInput).output;
 
       if (lstms[i]->options.bidirectional())
-        lstmOutputs.emplace_back(torch::cat({lstmOut[0],lstmOut[-1]}, 1));
+        lstmOutputs.emplace_back(torch::cat({lstmOut.narrow(1, 0, 1).squeeze(1),lstmOut.narrow(1, lstmOut.size(1)-1, 1).squeeze(1)}, 1));
       else
-        lstmOutputs.emplace_back(lstmOut[-1]);
+        lstmOutputs.emplace_back(lstmOut.narrow(1, lstmOut.size(1)-1, 1).squeeze(1));
     }
   }
 
   auto lstmOut = contextLSTM(context).output;
   if (contextLSTM->options.bidirectional())
-    lstmOutputs.emplace_back(torch::cat({lstmOut[0],lstmOut[-1]}, 1));
-  else
-    lstmOutputs.emplace_back(lstmOut[-1]);
+    lstmOutputs.emplace_back(torch::cat({lstmOut.narrow(1, 0, 1).squeeze(1),lstmOut.narrow(1, lstmOut.size(1)-1, 1).squeeze(1)}, 1));
+  else
+    lstmOutputs.emplace_back(lstmOut.narrow(1, lstmOut.size(1)-1, 1).squeeze(1));
 
   auto totalInput = lstmDropout(torch::cat(lstmOutputs, 1));
 
-- 
GitLab