From a0af9039e92afff2f21137e53dcbfe1474584960 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 16 Mar 2020 21:52:25 +0100
Subject: [PATCH] Added LSTMNetwork

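LSTMNetworkImpl encodes its input with bidirectional LSTMs: one over the word
context window, one over the raw character window (when enabled) and one per
focused column. The final LSTM states are concatenated and passed through a
two-layer MLP. An illustrative topology line accepted by the new Classifier
rule (values are made up, not taken from an actual machine file):

  LSTM(50,3,3,2,{FORM,UPOS},{0,1,2},{0,1},{FORM,FEATS,ID},{10,8,1},0,10)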
---
 reading_machine/src/Classifier.cpp    |  23 +++
 torch_modules/include/LSTMNetwork.hpp |  38 +++++
 torch_modules/src/LSTMNetwork.cpp     | 221 ++++++++++++++++++++++++++
 3 files changed, 282 insertions(+)
 create mode 100644 torch_modules/include/LSTMNetwork.hpp
 create mode 100644 torch_modules/src/LSTMNetwork.cpp

diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index ae638c1..03a4f5d 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -4,6 +4,7 @@
 #include "ConcatWordsNetwork.hpp"
 #include "RLTNetwork.hpp"
 #include "CNNNetwork.hpp"
+#include "LSTMNetwork.hpp"
 
 Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile)
 {
@@ -69,6 +70,28 @@ void Classifier::initNeuralNetwork(const std::string & topology)
         this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), std::stoi(sm[4]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm[10]), std::stoi(sm[11])));
       }
     },
+    {
+      std::regex("LSTM\\((\\d+),(\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
+      "LSTM(unknownValueThreshold,leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
+      [this,topology](auto sm)
+      {
+        std::vector<int> focusedBuffer, focusedStack, maxNbElements;
+        std::vector<std::string> focusedColumns, columns;
+        for (auto s : util::split(std::string(sm[5]), ','))
+          columns.emplace_back(s);
+        for (auto s : util::split(std::string(sm[6]), ','))
+          focusedBuffer.push_back(std::stoi(std::string(s)));
+        for (auto s : util::split(std::string(sm[7]), ','))
+          focusedStack.push_back(std::stoi(std::string(s)));
+        for (auto s : util::split(std::string(sm[8]), ','))
+          focusedColumns.emplace_back(s);
+        for (auto s : util::split(std::string(sm[9]), ','))
+          maxNbElements.push_back(std::stoi(std::string(s)));
+        if (focusedColumns.size() != maxNbElements.size())
+          util::myThrow("focusedColumns.size() != maxNbElements.size()");
+        this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), std::stoi(sm[4]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm[10]), std::stoi(sm[11])));
+      }
+    },
     {
       std::regex("RLT\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
       "RLT(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.",
diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp
new file mode 100644
index 0000000..5fe8de8
--- /dev/null
+++ b/torch_modules/include/LSTMNetwork.hpp
@@ -0,0 +1,38 @@
+#ifndef LSTMNETWORK__H
+#define LSTMNETWORK__H
+
+#include "NeuralNetwork.hpp"
+
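+// Neural network that encodes the parser configuration with bidirectional
+// LSTMs (context window, raw character window, focused columns) followed by
+// a two-layer MLP over the concatenated final LSTM states.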
+class LSTMNetworkImpl : public NeuralNetworkImpl
+{
+  private :
+
+  static constexpr int maxNbEmbeddings = 50000;
+
+  int unknownValueThreshold;
+  std::vector<int> focusedBufferIndexes;
+  std::vector<int> focusedStackIndexes;
+  std::vector<std::string> focusedColumns;
+  std::vector<int> maxNbElements;
+  int leftWindowRawInput;
+  int rightWindowRawInput;
+  int rawInputSize;
+
+  torch::nn::Embedding wordEmbeddings{nullptr};
+  torch::nn::Dropout embeddingsDropout{nullptr};
+  torch::nn::Dropout lstmDropout{nullptr};
+  torch::nn::Dropout hiddenDropout{nullptr};
+  torch::nn::Linear linear1{nullptr};
+  torch::nn::Linear linear2{nullptr};
+  torch::nn::LSTM contextLSTM{nullptr};
+  torch::nn::LSTM rawInputLSTM{nullptr};
+  std::vector<torch::nn::LSTM> lstms;
+
+  public :
+
+  LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput);
+  torch::Tensor forward(torch::Tensor input) override;
+  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
+};
+
+#endif
diff --git a/torch_modules/src/LSTMNetwork.cpp b/torch_modules/src/LSTMNetwork.cpp
new file mode 100644
index 0000000..84c9e79
--- /dev/null
+++ b/torch_modules/src/LSTMNetwork.cpp
@@ -0,0 +1,221 @@
+#include "LSTMNetwork.hpp"
+
+LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
+{
+  constexpr int embeddingsSize = 64;
+  constexpr int hiddenSize = 1024;
+  constexpr int contextLSTMSize = 512;
+  constexpr int focusedLSTMSize = 64;
+
+  setLeftBorder(leftBorder);
+  setRightBorder(rightBorder);
+  setNbStackElements(nbStackElements);
+  setColumns(columns);
+
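+  // The raw character window spans leftWindowRawInput + rightWindowRawInput + 1
+  // positions; a negative border on either side disables the raw input LSTM.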
+  rawInputSize = leftWindowRawInput + rightWindowRawInput + 1;
+  if (leftWindowRawInput < 0 or rightWindowRawInput < 0)
+    rawInputSize = 0;
+  else
+    rawInputLSTM = register_module("rawInputLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, focusedLSTMSize).batch_first(false).bidirectional(true)));
+
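+  // Each bidirectional LSTM yields 2*hidden_size features per time step;
+  // forward() concatenates the first and last steps, hence the factor 4.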
+  int rawInputLSTMOutputSize = rawInputSize == 0 ? 0 : (rawInputLSTM->options.hidden_size() * (rawInputLSTM->options.bidirectional() ? 4 : 1));
+
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
+  embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
+  lstmDropout = register_module("lstm_dropout", torch::nn::Dropout(0.3));
+  hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));
+  contextLSTM = register_module("contextLSTM", torch::nn::LSTM(torch::nn::LSTMOptions(columns.size()*embeddingsSize, contextLSTMSize).batch_first(false).bidirectional(true)));
+
+  int totalLSTMOutputSize = contextLSTM->options.hidden_size() * (contextLSTM->options.bidirectional() ? 4 : 1) + rawInputLSTMOutputSize;
+
+  for (auto & col : focusedColumns)
+  {
+    lstms.emplace_back(register_module(fmt::format("LSTM_{}", col), torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, focusedLSTMSize).batch_first(false).bidirectional(true))));
+    totalLSTMOutputSize += lstms.back()->options.hidden_size() * (lstms.back()->options.bidirectional() ? 4 : 1) * (focusedBufferIndexes.size()+focusedStackIndexes.size());
+  }
+
+  linear1 = register_module("linear1", torch::nn::Linear(totalLSTMOutputSize, hiddenSize));
+  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
+}
+
+torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
+{
+  if (input.dim() == 1)
+    input = input.unsqueeze(0);
+
+  auto embeddings = embeddingsDropout(wordEmbeddings(input));
+
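+  // Input layout along dim 1 (produced by extractContext):
+  // [raw letters | context columns | focused elements], hence the narrow() calls below.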
+  auto context = embeddings.narrow(1, rawInputSize, columns.size()*(1+leftBorder+rightBorder));
+
+  auto elementsEmbeddings = embeddings.narrow(1, rawInputSize+context.size(1), input.size(1)-(rawInputSize+context.size(1)));
+
+  context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()});
+  context = context.permute({1,0,2});
+
+  std::vector<torch::Tensor> lstmOutputs;
+
+  if (rawInputSize != 0)
+  {
+    auto rawLetters = embeddings.narrow(1, 0, rawInputSize).permute({1,0,2});
+    auto lstmOut = rawInputLSTM(rawLetters).output;
+    if (rawInputLSTM->options.bidirectional())
+      lstmOutputs.emplace_back(torch::cat({lstmOut[0],lstmOut[-1]}, 1));
+    else
+      lstmOutputs.emplace_back(lstmOut[-1]);
+  }
+
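+  // Each focused column (e.g. the letters of FORM or the FEATS components) is
+  // encoded by its own bidirectional LSTM, one slice of maxNbElements values
+  // per focused buffer/stack element.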
+  auto curIndex = 0;
+  for (unsigned int i = 0; i < focusedColumns.size(); i++)
+  {
+    long nbElements = maxNbElements[i];
+    for (unsigned int focused = 0; focused < focusedBufferIndexes.size()+focusedStackIndexes.size(); focused++)
+    {
+      auto lstmInput = elementsEmbeddings.narrow(1, curIndex, nbElements).permute({1,0,2});
+      curIndex += nbElements;
+      auto lstmOut = lstms[i](lstmInput).output;
+
+      if (lstms[i]->options.bidirectional())
+        lstmOutputs.emplace_back(torch::cat({lstmOut[0],lstmOut[-1]}, 1));
+      else
+        lstmOutputs.emplace_back(lstmOut[-1]);
+    }
+  }
+
+  auto lstmOut = contextLSTM(context).output;
+  if (contextLSTM->options.bidirectional())
+    lstmOutputs.emplace_back(torch::cat({lstmOut[0],lstmOut[-1]}, 1));
+  else
+    lstmOutputs.emplace_back(lstmOut[-1]);
+
+  auto totalInput = lstmDropout(torch::cat(lstmOutputs, 1));
+
+  return linear2(hiddenDropout(torch::relu(linear1(totalInput))));
+}
+
+std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config, Dict & dict) const
+{
+  if (dict.size() >= maxNbEmbeddings)
+    util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));
+
+  std::vector<long> contextIndexes = extractContextIndexes(config);
+  std::vector<std::vector<long>> context;
+  context.emplace_back();
+
+  if (rawInputSize > 0)
+  {
+    for (int i = 0; i < leftWindowRawInput; i++)
+      if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
+        context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
+      else
+        context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+
+    for (int i = 0; i <= rightWindowRawInput; i++)
+      if (config.hasCharacter(config.getCharacterIndex()+i))
+        context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i))));
+      else
+        context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+  }
+
+  for (auto index : contextIndexes)
+    for (auto & col : columns)
+      if (index == -1)
+        for (auto & contextElement : context)
+          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+      else
+      {
+        int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
+
+        for (auto & contextElement : context)
+          contextElement.push_back(dictIndex);
+
+        if (is_training())
+          if (col == "FORM" || col == "LEMMA")
+            if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)
+            {
+              context.emplace_back(context.back());
+              context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr);
+            }
+      }
+
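+  // Focused columns: for every focused buffer/stack element, push exactly
+  // maxNbElements[colIndex] values (letters for FORM, '|'-separated parts for
+  // FEATS, a single marker for ID, the raw value otherwise), padding with the
+  // null value.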
+  for (auto & contextElement : context)
+    for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
+    {
+      auto & col = focusedColumns[colIndex];
+
+      std::vector<int> focusedIndexes;
+      for (auto relIndex : focusedBufferIndexes)
+      {
+        int index = relIndex + leftBorder;
+        if (index < 0 || index >= (int)contextIndexes.size())
+          focusedIndexes.push_back(-1);
+        else
+          focusedIndexes.push_back(contextIndexes[index]);
+      }
+      for (auto index : focusedStackIndexes)
+      {
+        if (!config.hasStack(index))
+          focusedIndexes.push_back(-1);
+        else if (!config.has(col, config.getStack(index), 0))
+          focusedIndexes.push_back(-1);
+        else
+          focusedIndexes.push_back(config.getStack(index));
+      }
+
+      for (auto index : focusedIndexes)
+      {
+        if (index == -1)
+        {
+          for (int i = 0; i < maxNbElements[colIndex]; i++)
+            contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+          continue;
+        }
+
+        std::vector<std::string> elements;
+        if (col == "FORM")
+        {
+          auto asUtf8 = util::splitAsUtf8(config.getAsFeature(col, index).get());
+
+          for (int i = 0; i < maxNbElements[colIndex]; i++)
+            if (i < (int)asUtf8.size())
+              elements.emplace_back(fmt::format("Letter({})", asUtf8[i]));
+            else
+              elements.emplace_back(Dict::nullValueStr);
+        }
+        else if (col == "FEATS")
+        {
+          auto splited = util::split(config.getAsFeature(col, index).get(), '|');
+
+          for (int i = 0; i < maxNbElements[colIndex]; i++)
+            if (i < (int)splited.size())
+              elements.emplace_back(fmt::format("FEATS({})", splited[i]));
+            else
+              elements.emplace_back(Dict::nullValueStr);
+        }
+        else if (col == "ID")
+        {
+          if (config.isTokenPredicted(index))
+            elements.emplace_back("ID(TOKEN)");
+          else if (config.isMultiwordPredicted(index))
+            elements.emplace_back("ID(MULTIWORD)");
+          else if (config.isEmptyNodePredicted(index))
+            elements.emplace_back("ID(EMPTYNODE)");
+        }
+        else
+        {
+          elements.emplace_back(config.getAsFeature(col, index));
+        }
+
+        if ((int)elements.size() != maxNbElements[colIndex])
+          util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col));
+
+        for (auto & element : elements)
+          contextElement.emplace_back(dict.getIndexOrInsert(element));
+      }
+    }
+
+  if (!is_training() && context.size() > 1)
+    util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));
+
+  return context;
+}
+
-- 
GitLab