From 2d940e4739900e1b31b0fd0bd3b8bae577d4b7f8 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Mon, 29 Jun 2020 16:56:30 +0200 Subject: [PATCH] Added cnn --- torch_modules/include/CNN.hpp | 13 ++++++------- torch_modules/include/HistoryModule.hpp | 1 + torch_modules/src/CNN.cpp | 22 +++++++++++----------- torch_modules/src/HistoryModule.cpp | 2 ++ 4 files changed, 20 insertions(+), 18 deletions(-) diff --git a/torch_modules/include/CNN.hpp b/torch_modules/include/CNN.hpp index 2c8431c..776bafa 100644 --- a/torch_modules/include/CNN.hpp +++ b/torch_modules/include/CNN.hpp @@ -2,22 +2,21 @@ #define CNN__H #include <torch/torch.h> +#include "MyModule.hpp" -class CNNImpl : public torch::nn::Module +class CNNImpl : public MyModule { private : - std::vector<int> windowSizes; std::vector<torch::nn::Conv2d> CNNs; - int nbFilters; - int elementSize; + std::vector<int> windowSizes{2, 3}; + int outputSize; public : - CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize); + CNNImpl(int inputSize, int outputSize, ModuleOptions options); torch::Tensor forward(torch::Tensor input); - std::size_t getOutputSize(); - + int getOutputSize(int sequenceLength); }; TORCH_MODULE(CNN); diff --git a/torch_modules/include/HistoryModule.hpp b/torch_modules/include/HistoryModule.hpp index 4a0a2bb..0489114 100644 --- a/torch_modules/include/HistoryModule.hpp +++ b/torch_modules/include/HistoryModule.hpp @@ -6,6 +6,7 @@ #include "MyModule.hpp" #include "LSTM.hpp" #include "GRU.hpp" +#include "CNN.hpp" #include "Concat.hpp" class HistoryModuleImpl : public Submodule diff --git a/torch_modules/src/CNN.cpp b/torch_modules/src/CNN.cpp index 35f357e..2aaffec 100644 --- a/torch_modules/src/CNN.cpp +++ b/torch_modules/src/CNN.cpp @@ -1,34 +1,34 @@ #include "CNN.hpp" #include "fmt/core.h" -CNNImpl::CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize) - : windowSizes(windowSizes), nbFilters(nbFilters), elementSize(elementSize) +CNNImpl::CNNImpl(int inputSize, int outputSize, ModuleOptions options) : outputSize(outputSize) { for (auto & windowSize : windowSizes) { std::string moduleName = fmt::format("cnn_window_{}", windowSize); - CNNs.emplace_back(register_module(moduleName, torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFilters, torch::ExpandingArray<2>({windowSize,elementSize})).padding({windowSize-1, 0})))); + auto kernel = torch::ExpandingArray<2>({windowSize, inputSize}); + auto opts = torch::nn::Conv2dOptions(1, outputSize, kernel).padding({windowSize-1, 0}); + CNNs.emplace_back(register_module(moduleName, torch::nn::Conv2d(opts))); } } torch::Tensor CNNImpl::forward(torch::Tensor input) { std::vector<torch::Tensor> windows; + input = input.unsqueeze(1); + for (unsigned int i = 0; i < CNNs.size(); i++) { - auto convOut = torch::relu(CNNs[i](input).squeeze(-1)); - auto pooled = torch::max_pool1d(convOut, convOut.size(2)); + auto convOut = CNNs[i](input).squeeze(-1); + auto pooled = torch::max_pool1d(convOut, convOut.size(-1)); windows.emplace_back(pooled); } - auto cnnOut = torch::cat(windows, 2); - cnnOut = cnnOut.view({cnnOut.size(0), -1}); - - return cnnOut; + return torch::cat(windows, -1).view({input.size(0), -1}); } -std::size_t CNNImpl::getOutputSize() +int CNNImpl::getOutputSize(int) { - return windowSizes.size()*nbFilters; + return outputSize*windowSizes.size(); } diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp index 3e09b0a..eb5c28c 100644 --- a/torch_modules/src/HistoryModule.cpp +++ b/torch_modules/src/HistoryModule.cpp @@ -26,6 +26,8 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin myModule = register_module("myModule", LSTM(inSize, outSize, options)); else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(inSize, outSize, options)); + else if (subModuleType == "CNN") + myModule = register_module("myModule", CNN(inSize, outSize, options)); else if (subModuleType == "Concat") myModule = register_module("myModule", Concat(inSize)); else -- GitLab