From 9db9288677d9fd53b32b5dc19fa4d34c9ae017dd Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sat, 7 Mar 2020 22:35:06 +0100
Subject: [PATCH] Added rawInput to CNNNetwork

---
 torch_modules/include/CNNNetwork.hpp |  3 +++
 torch_modules/src/CNNNetwork.cpp     | 31 +++++++++++++++++++++++-----
 2 files changed, 29 insertions(+), 5 deletions(-)

diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp
index ebbfefd..22f4b0c 100644
--- a/torch_modules/include/CNNNetwork.hpp
+++ b/torch_modules/include/CNNNetwork.hpp
@@ -12,11 +12,14 @@ class CNNNetworkImpl : public NeuralNetworkImpl
   std::vector<int> focusedStackIndexes;
   std::vector<std::string> focusedColumns;
   std::vector<int> maxNbElements;
+  int leftWindowRawInput{5};
+  int rightWindowRawInput{5};
 
   torch::nn::Embedding wordEmbeddings{nullptr};
   torch::nn::Linear linear1{nullptr};
   torch::nn::Linear linear2{nullptr};
   CNN contextCNN{nullptr};
+  CNN rawInputCNN{nullptr};
   std::vector<CNN> cnns;
 
   public :
diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp
index 2981b74..7d2a4fc 100644
--- a/torch_modules/src/CNNNetwork.cpp
+++ b/torch_modules/src/CNNNetwork.cpp
@@ -14,7 +14,8 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
 
   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
   contextCNN = register_module("contextCNN", CNN(std::vector<int>{2,3,4}, nbFiltersContext, columns.size()*embeddingsSize));
-  int totalCnnOutputSize = contextCNN->getOutputSize();
+  rawInputCNN = register_module("rawInputCNN", CNN(std::vector<int>{2,3,4}, nbFiltersFocused, embeddingsSize));
+  int totalCnnOutputSize = contextCNN->getOutputSize()+rawInputCNN->getOutputSize();
   for (auto & col : focusedColumns)
   {
     std::vector<int> windows{2,3,4};
@@ -30,11 +31,19 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
   if (input.dim() == 1)
     input = input.unsqueeze(0);
 
-  auto wordIndexes = input.narrow(1, 0, columns.size()*(1+leftBorder+rightBorder));
+  auto embeddings = wordEmbeddings(input);
 
-  auto elementsEmbeddings = wordEmbeddings(input.narrow(1, wordIndexes.size(1), input.size(1)-wordIndexes.size(1)));
+  // Dim 1 of `embeddings` is sliced in order: raw letters, then context, then the remaining elements (extractContext pushes the letters first).
+  auto rawLetters = embeddings.narrow(1, 0, leftWindowRawInput+rightWindowRawInput+1);
+  auto context = embeddings.narrow(1, rawLetters.size(1), columns.size()*(1+leftBorder+rightBorder)); // start offset is measured along dim 1, not dim 0 (batch)
+  auto elementsStart = rawLetters.size(1)+context.size(1); // taken before the view below divides context.size(1) by columns.size()
+  context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()});
+  auto elementsEmbeddings = embeddings.narrow(1, elementsStart, input.size(1)-elementsStart);
 
   std::vector<torch::Tensor> cnnOutputs;
+
+  cnnOutputs.emplace_back(rawInputCNN(rawLetters.unsqueeze(1)));
+
   auto curIndex = 0;
   for (unsigned int i = 0; i < focusedColumns.size(); i++)
   {
@@ -47,8 +56,7 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
     }
   }
 
-  auto embeddings = wordEmbeddings(wordIndexes).view({wordIndexes.size(0), wordIndexes.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
-  cnnOutputs.emplace_back(contextCNN(embeddings));
+  cnnOutputs.emplace_back(contextCNN(context.unsqueeze(1)));
 
   auto totalInput = torch::cat(cnnOutputs, 1);
 
@@ -60,6 +68,19 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c
   std::vector<long> contextIndexes = extractContextIndexes(config);
   std::vector<long> context;
 
+  for (int i = 0; i < leftWindowRawInput; i++)
+    if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
+      context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
+    else
+      context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+
+  for (int i = 0; i <= rightWindowRawInput; i++)
+    if (config.hasCharacter(config.getCharacterIndex()+i))
+
+      context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i))));
+    else
+      context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+
   for (auto index : contextIndexes)
     for (auto & col : columns)
       if (index == -1)
-- 
GitLab