From 2d940e4739900e1b31b0fd0bd3b8bae577d4b7f8 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 29 Jun 2020 16:56:30 +0200
Subject: [PATCH] Added cnn

---
 torch_modules/include/CNN.hpp           | 13 ++++++-------
 torch_modules/include/HistoryModule.hpp |  1 +
 torch_modules/src/CNN.cpp               | 22 +++++++++++-----------
 torch_modules/src/HistoryModule.cpp     |  2 ++
 4 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/torch_modules/include/CNN.hpp b/torch_modules/include/CNN.hpp
index 2c8431c..776bafa 100644
--- a/torch_modules/include/CNN.hpp
+++ b/torch_modules/include/CNN.hpp
@@ -2,22 +2,21 @@
 #define CNN__H
 
 #include <torch/torch.h>
+#include "MyModule.hpp"
 
-class CNNImpl : public torch::nn::Module
+class CNNImpl : public MyModule
 {
   private :
 
-  std::vector<int> windowSizes;
   std::vector<torch::nn::Conv2d> CNNs;
-  int nbFilters;
-  int elementSize;
+  std::vector<int> windowSizes{2, 3};
+  int outputSize;
 
   public :
 
-  CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize);
+  CNNImpl(int inputSize, int outputSize, ModuleOptions options);
   torch::Tensor forward(torch::Tensor input);
-  std::size_t getOutputSize();
-
+  int getOutputSize(int sequenceLength);
 };
 TORCH_MODULE(CNN);
 
diff --git a/torch_modules/include/HistoryModule.hpp b/torch_modules/include/HistoryModule.hpp
index 4a0a2bb..0489114 100644
--- a/torch_modules/include/HistoryModule.hpp
+++ b/torch_modules/include/HistoryModule.hpp
@@ -6,6 +6,7 @@
 #include "MyModule.hpp"
 #include "LSTM.hpp"
 #include "GRU.hpp"
+#include "CNN.hpp"
 #include "Concat.hpp"
 
 class HistoryModuleImpl : public Submodule
diff --git a/torch_modules/src/CNN.cpp b/torch_modules/src/CNN.cpp
index 35f357e..2aaffec 100644
--- a/torch_modules/src/CNN.cpp
+++ b/torch_modules/src/CNN.cpp
@@ -1,34 +1,34 @@
 #include "CNN.hpp"
 #include "fmt/core.h"
 
-CNNImpl::CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize)
-  : windowSizes(windowSizes), nbFilters(nbFilters), elementSize(elementSize)
+CNNImpl::CNNImpl(int inputSize, int outputSize, ModuleOptions options) : outputSize(outputSize)
 {
   for (auto & windowSize : windowSizes)
   {
     std::string moduleName = fmt::format("cnn_window_{}", windowSize);
-    CNNs.emplace_back(register_module(moduleName, torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFilters, torch::ExpandingArray<2>({windowSize,elementSize})).padding({windowSize-1, 0}))));
+    auto kernel = torch::ExpandingArray<2>({windowSize, inputSize});
+    auto opts = torch::nn::Conv2dOptions(1, outputSize, kernel).padding({windowSize-1, 0});
+    CNNs.emplace_back(register_module(moduleName, torch::nn::Conv2d(opts)));
   }
 }
 
 torch::Tensor CNNImpl::forward(torch::Tensor input)
 {
   std::vector<torch::Tensor> windows;
+  input = input.unsqueeze(1);
+
   for (unsigned int i = 0; i < CNNs.size(); i++)
   {
-    auto convOut = torch::relu(CNNs[i](input).squeeze(-1));
-    auto pooled = torch::max_pool1d(convOut, convOut.size(2));
+    auto convOut = CNNs[i](input).squeeze(-1);
+    auto pooled = torch::max_pool1d(convOut, convOut.size(-1));
     windows.emplace_back(pooled);
   }
 
-  auto cnnOut = torch::cat(windows, 2);
-  cnnOut = cnnOut.view({cnnOut.size(0), -1});
-
-  return cnnOut;
+  return torch::cat(windows, -1).view({input.size(0), -1});
 }
 
-std::size_t CNNImpl::getOutputSize()
+int CNNImpl::getOutputSize(int)
 {
-  return windowSizes.size()*nbFilters;
+  return outputSize*windowSizes.size();
 }
 
diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp
index 3e09b0a..eb5c28c 100644
--- a/torch_modules/src/HistoryModule.cpp
+++ b/torch_modules/src/HistoryModule.cpp
@@ -26,6 +26,8 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin
               myModule = register_module("myModule", LSTM(inSize, outSize, options));
             else if (subModuleType == "GRU")
               myModule = register_module("myModule", GRU(inSize, outSize, options));
+            else if (subModuleType == "CNN")
+              myModule = register_module("myModule", CNN(inSize, outSize, options));
             else if (subModuleType == "Concat")
               myModule = register_module("myModule", Concat(inSize));
             else
-- 
GitLab