From ba5742bdfc479346305904d637e6eaa9bf42bde5 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sun, 8 Mar 2020 19:16:48 +0100
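Subject: [PATCH] Added rawInput window parameters to CNNNetwork

The CNN(...) topology string now takes two extra trailing integer
arguments, leftBorderRawInput and rightBorderRawInput, which are passed
to CNNNetworkImpl in place of the previously hard-coded raw input window
(5 characters on each side).

If either value is negative, the raw input is disabled: rawInputSize is
set to 0, the rawInputCNN module is not registered, and neither forward()
nor extractContext() reads or emits any raw-letter features.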

---
 reading_machine/src/Classifier.cpp   |  6 ++--
 torch_modules/include/CNNNetwork.hpp |  7 +++--
 torch_modules/src/CNNNetwork.cpp     | 47 +++++++++++++++++-----------
 3 files changed, 36 insertions(+), 24 deletions(-)

diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index e5eee2d..3e0cc65 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -48,8 +48,8 @@ void Classifier::initNeuralNetwork(const std::string & topology)
       }
     },
     {
-      std::regex("CNN\\((\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\)"),
-      "CNN(leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements}) : CNN to capture context.",
+      std::regex("CNN\\((\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
+      "CNN(leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
       [this,topology](auto sm)
       {
         std::vector<int> focusedBuffer, focusedStack, maxNbElements;
@@ -66,7 +66,7 @@ void Classifier::initNeuralNetwork(const std::string & topology)
           maxNbElements.push_back(std::stoi(std::string(s)));
         if (focusedColumns.size() != maxNbElements.size())
           util::myThrow("focusedColumns.size() != maxNbElements.size()");
-        this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements));
+        this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm[9]), std::stoi(sm[10])));
       }
     },
     {
diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp
index 22f4b0c..f193ebc 100644
--- a/torch_modules/include/CNNNetwork.hpp
+++ b/torch_modules/include/CNNNetwork.hpp
@@ -12,8 +12,9 @@ class CNNNetworkImpl : public NeuralNetworkImpl
   std::vector<int> focusedStackIndexes;
   std::vector<std::string> focusedColumns;
   std::vector<int> maxNbElements;
-  int leftWindowRawInput{5};
-  int rightWindowRawInput{5};
+  int leftWindowRawInput;
+  int rightWindowRawInput;
+  int rawInputSize;
 
   torch::nn::Embedding wordEmbeddings{nullptr};
   torch::nn::Linear linear1{nullptr};
@@ -24,7 +25,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl
 
   public :
 
-  CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements);
+  CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput);
   torch::Tensor forward(torch::Tensor input) override;
   std::vector<long> extractContext(Config & config, Dict & dict) const override;
 };
diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp
index 5ecccae..2f0cc9f 100644
--- a/torch_modules/src/CNNNetwork.cpp
+++ b/torch_modules/src/CNNNetwork.cpp
@@ -1,6 +1,6 @@
 #include "CNNNetwork.hpp"
 
-CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements) : focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements)
+CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
 {
   constexpr int embeddingsSize = 64;
   constexpr int hiddenSize = 512;
@@ -12,10 +12,16 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
   setNbStackElements(nbStackElements);
   setColumns(columns);
 
+  rawInputSize =  leftWindowRawInput + rightWindowRawInput + 1;
+  if (leftWindowRawInput < 0 or rightWindowRawInput < 0)
+    rawInputSize = 0;
+  else
+    rawInputCNN = register_module("rawInputCNN", CNN(std::vector<int>{2,3,4}, nbFiltersFocused, embeddingsSize));
+  int rawInputCNNOutputSize = rawInputSize == 0 ? 0 : rawInputCNN->getOutputSize();
+
   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
   contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize));
-  rawInputCNN = register_module("rawInputCNN", CNN(std::vector<int>{2,3,4}, nbFiltersFocused, embeddingsSize));
-  int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNN->getOutputSize();
+  int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNNOutputSize;
   for (auto & col : focusedColumns)
   {
     std::vector<int> windows{2,3,4};
@@ -33,16 +39,18 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
 
   auto embeddings = wordEmbeddings(input);
 
-  auto rawLetters = embeddings.narrow(1, 0, leftWindowRawInput+rightWindowRawInput+1);
-
-  auto context = embeddings.narrow(1, rawLetters.size(1), columns.size()*(1+leftBorder+rightBorder));
+  auto context = embeddings.narrow(1, rawInputSize, columns.size()*(1+leftBorder+rightBorder));
   context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()});
 
-  auto elementsEmbeddings = embeddings.narrow(1, rawLetters.size(1)+context.size(1), input.size(1)-(rawLetters.size(1)+context.size(1)));
+  auto elementsEmbeddings = embeddings.narrow(1, rawInputSize+context.size(1), input.size(1)-(rawInputSize+context.size(1)));
 
   std::vector<torch::Tensor> cnnOutputs;
 
-  cnnOutputs.emplace_back(rawInputCNN(rawLetters.unsqueeze(1)));
+  if (rawInputSize != 0)
+  {
+    auto rawLetters = embeddings.narrow(1, 0, leftWindowRawInput+rightWindowRawInput+1);
+    cnnOutputs.emplace_back(rawInputCNN(rawLetters.unsqueeze(1)));
+  }
 
   auto curIndex = 0;
   for (unsigned int i = 0; i < focusedColumns.size(); i++)
@@ -68,18 +76,21 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c
   std::vector<long> contextIndexes = extractContextIndexes(config);
   std::vector<long> context;
 
-  for (int i = 0; i < leftWindowRawInput; i++)
-    if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
-      context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
-    else
-      context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+  if (rawInputSize > 0)
+  {
+    for (int i = 0; i < leftWindowRawInput; i++)
+      if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
+        context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
+      else
+        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
 
-  for (int i = 0; i <= rightWindowRawInput; i++)
-    if (config.hasCharacter(config.getCharacterIndex()+i))
+    for (int i = 0; i <= rightWindowRawInput; i++)
+      if (config.hasCharacter(config.getCharacterIndex()+i))
 
-      context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i))));
-    else
-      context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i))));
+      else
+        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+  }
 
   for (auto index : contextIndexes)
     for (auto & col : columns)
-- 
GitLab