From b50c6ff365185c40ee4524b2a6a5530cee7cb4a0 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 4 Mar 2020 13:37:34 +0100
Subject: [PATCH] Made a CNN module

---
 torch_modules/include/CNN.hpp        | 26 +++++++++++++++++
 torch_modules/include/CNNNetwork.hpp |  6 ++--
 torch_modules/src/CNN.cpp            | 34 ++++++++++++++++++++++
 torch_modules/src/CNNNetwork.cpp     | 43 ++++++----------------------
 4 files changed, 72 insertions(+), 37 deletions(-)
 create mode 100644 torch_modules/include/CNN.hpp
 create mode 100644 torch_modules/src/CNN.cpp

diff --git a/torch_modules/include/CNN.hpp b/torch_modules/include/CNN.hpp
new file mode 100644
index 0000000..e08a869
--- /dev/null
+++ b/torch_modules/include/CNN.hpp
@@ -0,0 +1,26 @@
+#ifndef CNN__H
+#define CNN__H
+
+#include <torch/torch.h>
+#include "fmt/core.h"
+
+class CNNImpl : public torch::nn::Module
+{
+  private :
+
+  std::vector<long> windowSizes;
+  std::vector<torch::nn::Conv2d> CNNs;
+  int nbFilters;
+  int elementSize;
+
+  public :
+
+  CNNImpl(std::vector<long> windowSizes, int nbFilters, int elementSize);
+  torch::Tensor forward(torch::Tensor input);
+  int getOutputSize();
+
+};
+TORCH_MODULE(CNN);
+
+#endif
+
diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp
index b6b5cef..0893ff9 100644
--- a/torch_modules/include/CNNNetwork.hpp
+++ b/torch_modules/include/CNNNetwork.hpp
@@ -2,12 +2,12 @@
 #define CNNNETWORK__H
 
 #include "NeuralNetwork.hpp"
+#include "CNN.hpp"
 
 class CNNNetworkImpl : public NeuralNetworkImpl
 {
   private :
 
-  static inline std::vector<long> windowSizes{2,3,4};
   static constexpr unsigned int maxNbLetters = 10;
 
   private :
@@ -19,8 +19,8 @@ class CNNNetworkImpl : public NeuralNetworkImpl
   torch::nn::Embedding wordEmbeddings{nullptr};
   torch::nn::Linear linear1{nullptr};
   torch::nn::Linear linear2{nullptr};
-  std::vector<torch::nn::Conv2d> CNNs;
-  std::vector<torch::nn::Conv2d> lettersCNNs;
+  CNN contextCNN{nullptr};
+  CNN lettersCNN{nullptr};
 
   public :
 
diff --git a/torch_modules/src/CNN.cpp b/torch_modules/src/CNN.cpp
new file mode 100644
index 0000000..f033403
--- /dev/null
+++ b/torch_modules/src/CNN.cpp
@@ -0,0 +1,34 @@
+#include "CNN.hpp"
+#include "CNN.hpp"
+
+CNNImpl::CNNImpl(std::vector<long> windowSizes, int nbFilters, int elementSize)
+  : windowSizes(windowSizes), nbFilters(nbFilters), elementSize(elementSize)
+{
+  for (auto & windowSize : windowSizes)
+  {
+    std::string moduleName = fmt::format("cnn_window_{}", windowSize);
+    CNNs.emplace_back(register_module(moduleName, torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFilters, torch::ExpandingArray<2>({windowSize,elementSize})).padding({windowSize-1, 0}))));
+  }
+}
+
+torch::Tensor CNNImpl::forward(torch::Tensor input)
+{
+  std::vector<torch::Tensor> windows;
+  for (unsigned int i = 0; i < CNNs.size(); i++)
+  {
+    auto convOut = torch::relu(CNNs[i](input).squeeze(-1));
+    auto pooled = torch::max_pool1d(convOut, convOut.size(2));
+    windows.emplace_back(pooled);
+  }
+
+  auto cnnOut = torch::cat(windows, 2);
+  cnnOut = cnnOut.view({cnnOut.size(0), -1});
+
+  return cnnOut;
+}
+
+int CNNImpl::getOutputSize()
+{
+  return windowSizes.size()*nbFilters;
+}
+
diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp
index 50f9c0d..86781c9 100644
--- a/torch_modules/src/CNNNetwork.cpp
+++ b/torch_modules/src/CNNNetwork.cpp
@@ -13,13 +13,10 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
   setColumns(columns);
 
   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
-  linear1 = register_module("linear1", torch::nn::Linear(nbFilters*windowSizes.size()+nbFiltersLetters*windowSizes.size()*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));
+  contextCNN = register_module("contextCNN", CNN(std::vector<long>{2,3,4}, nbFilters, 2*embeddingsSize));
+  lettersCNN = register_module("lettersCNN", CNN(std::vector<long>{2,3,4,5}, nbFiltersLetters, embeddingsSize));
+  linear1 = register_module("linear1", torch::nn::Linear(contextCNN->getOutputSize()+lettersCNN->getOutputSize()*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));
   linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
-  for (auto & windowSize : windowSizes)
-  {
-    CNNs.emplace_back(register_module(fmt::format("cnn_context_{}", windowSize), torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFilters, torch::ExpandingArray<2>({windowSize,2*embeddingsSize})).padding({windowSize-1, 0}))));
-    lettersCNNs.emplace_back(register_module(fmt::format("cnn_letters_{}", windowSize), torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFiltersLetters, torch::ExpandingArray<2>({windowSize,embeddingsSize})).padding({windowSize-1, 0}))));
-  }
 }
 
 torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
@@ -34,38 +31,16 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
   auto lettersEmbeddings = wordEmbeddings(wordLetters).view({wordLetters.size(0), wordLetters.size(1)/maxNbLetters, maxNbLetters, wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
 
   auto permuted = lettersEmbeddings.permute({2,0,1,3,4});
-  std::vector<torch::Tensor> windows;
+  std::vector<torch::Tensor> cnnOuts;
   for (unsigned int word = 0; word < focusedBufferIndexes.size(); word++)
-    for (unsigned int i = 0; i < lettersCNNs.size(); i++)
-    {
-      auto input = permuted[word];
-      auto convOut = torch::relu(lettersCNNs[i](input).squeeze(-1));
-      auto pooled = torch::max_pool1d(convOut, convOut.size(2));
-      windows.emplace_back(pooled);
-    }
+    cnnOuts.emplace_back(lettersCNN(permuted[word]));
   for (unsigned int word = 0; word < focusedStackIndexes.size(); word++)
-    for (unsigned int i = 0; i < lettersCNNs.size(); i++)
-    {
-      auto input = permuted[focusedBufferIndexes.size()+word];
-      auto convOut = torch::relu(lettersCNNs[i](input).squeeze(-1));
-      auto pooled = torch::max_pool1d(convOut, convOut.size(2));
-      windows.emplace_back(pooled);
-    }
-  auto lettersCnnOut = torch::cat(windows, 2);
-  lettersCnnOut = lettersCnnOut.view({lettersCnnOut.size(0), -1});
-
-  windows.clear();
-  for (unsigned int i = 0; i < CNNs.size(); i++)
-  {
-    auto convOut = torch::relu(CNNs[i](embeddings).squeeze(-1));
-    auto pooled = torch::max_pool1d(convOut, convOut.size(2));
-    windows.emplace_back(pooled);
-  }
+    cnnOuts.emplace_back(lettersCNN(permuted[word]));
+  auto lettersCnnOut = torch::cat(cnnOuts, 1);
 
-  auto cnnOut = torch::cat(windows, 2);
-  cnnOut = cnnOut.view({cnnOut.size(0), -1});
+  auto contextCnnOut = contextCNN(embeddings);
 
-  auto totalInput = torch::cat({cnnOut, lettersCnnOut}, 1);
+  auto totalInput = torch::cat({contextCnnOut, lettersCnnOut}, 1);
 
   return linear2(torch::relu(linear1(totalInput)));
 }
-- 
GitLab