From b50c6ff365185c40ee4524b2a6a5530cee7cb4a0 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 4 Mar 2020 13:37:34 +0100 Subject: [PATCH] Made a CNN module --- torch_modules/include/CNN.hpp | 26 +++++++++++++++++ torch_modules/include/CNNNetwork.hpp | 6 ++-- torch_modules/src/CNN.cpp | 34 ++++++++++++++++++++++ torch_modules/src/CNNNetwork.cpp | 43 ++++++---------------------- 4 files changed, 72 insertions(+), 37 deletions(-) create mode 100644 torch_modules/include/CNN.hpp create mode 100644 torch_modules/src/CNN.cpp diff --git a/torch_modules/include/CNN.hpp b/torch_modules/include/CNN.hpp new file mode 100644 index 0000000..e08a869 --- /dev/null +++ b/torch_modules/include/CNN.hpp @@ -0,0 +1,26 @@ +#ifndef CNN__H +#define CNN__H + +#include <torch/torch.h> +#include "fmt/core.h" + +class CNNImpl : public torch::nn::Module +{ + private : + + std::vector<long> windowSizes; + std::vector<torch::nn::Conv2d> CNNs; + int nbFilters; + int elementSize; + + public : + + CNNImpl(std::vector<long> windowSizes, int nbFilters, int elementSize); + torch::Tensor forward(torch::Tensor input); + int getOutputSize(); + +}; +TORCH_MODULE(CNN); + +#endif + diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp index b6b5cef..0893ff9 100644 --- a/torch_modules/include/CNNNetwork.hpp +++ b/torch_modules/include/CNNNetwork.hpp @@ -2,12 +2,12 @@ #define CNNNETWORK__H #include "NeuralNetwork.hpp" +#include "CNN.hpp" class CNNNetworkImpl : public NeuralNetworkImpl { private : - static inline std::vector<long> windowSizes{2,3,4}; static constexpr unsigned int maxNbLetters = 10; private : @@ -19,8 +19,8 @@ class CNNNetworkImpl : public NeuralNetworkImpl torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear2{nullptr}; - std::vector<torch::nn::Conv2d> CNNs; - std::vector<torch::nn::Conv2d> lettersCNNs; + CNN contextCNN{nullptr}; + CNN lettersCNN{nullptr}; public : diff --git a/torch_modules/src/CNN.cpp b/torch_modules/src/CNN.cpp new file mode 100644 index 0000000..f033403 --- /dev/null +++ b/torch_modules/src/CNN.cpp @@ -0,0 +1,34 @@ +#include "CNN.hpp" +#include "CNN.hpp" + +CNNImpl::CNNImpl(std::vector<long> windowSizes, int nbFilters, int elementSize) + : windowSizes(windowSizes), nbFilters(nbFilters), elementSize(elementSize) +{ + for (auto & windowSize : windowSizes) + { + std::string moduleName = fmt::format("cnn_window_{}", windowSize); + CNNs.emplace_back(register_module(moduleName, torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFilters, torch::ExpandingArray<2>({windowSize,elementSize})).padding({windowSize-1, 0})))); + } +} + +torch::Tensor CNNImpl::forward(torch::Tensor input) +{ + std::vector<torch::Tensor> windows; + for (unsigned int i = 0; i < CNNs.size(); i++) + { + auto convOut = torch::relu(CNNs[i](input).squeeze(-1)); + auto pooled = torch::max_pool1d(convOut, convOut.size(2)); + windows.emplace_back(pooled); + } + + auto cnnOut = torch::cat(windows, 2); + cnnOut = cnnOut.view({cnnOut.size(0), -1}); + + return cnnOut; +} + +int CNNImpl::getOutputSize() +{ + return windowSizes.size()*nbFilters; +} + diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp index 50f9c0d..86781c9 100644 --- a/torch_modules/src/CNNNetwork.cpp +++ b/torch_modules/src/CNNNetwork.cpp @@ -13,13 +13,10 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i setColumns(columns); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); - linear1 = register_module("linear1", torch::nn::Linear(nbFilters*windowSizes.size()+nbFiltersLetters*windowSizes.size()*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize)); + contextCNN = register_module("contextCNN", CNN(std::vector<long>{2,3,4}, nbFilters, 2*embeddingsSize)); + lettersCNN = register_module("lettersCNN", CNN(std::vector<long>{2,3,4,5}, nbFiltersLetters, embeddingsSize)); + linear1 = register_module("linear1", torch::nn::Linear(contextCNN->getOutputSize()+lettersCNN->getOutputSize()*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize)); linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs)); - for (auto & windowSize : windowSizes) - { - CNNs.emplace_back(register_module(fmt::format("cnn_context_{}", windowSize), torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFilters, torch::ExpandingArray<2>({windowSize,2*embeddingsSize})).padding({windowSize-1, 0})))); - lettersCNNs.emplace_back(register_module(fmt::format("cnn_letters_{}", windowSize), torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFiltersLetters, torch::ExpandingArray<2>({windowSize,embeddingsSize})).padding({windowSize-1, 0})))); - } } torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) @@ -34,38 +31,16 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) auto lettersEmbeddings = wordEmbeddings(wordLetters).view({wordLetters.size(0), wordLetters.size(1)/maxNbLetters, maxNbLetters, wordEmbeddings->options.embedding_dim()}).unsqueeze(1); auto permuted = lettersEmbeddings.permute({2,0,1,3,4}); - std::vector<torch::Tensor> windows; + std::vector<torch::Tensor> cnnOuts; for (unsigned int word = 0; word < focusedBufferIndexes.size(); word++) - for (unsigned int i = 0; i < lettersCNNs.size(); i++) - { - auto input = permuted[word]; - auto convOut = torch::relu(lettersCNNs[i](input).squeeze(-1)); - auto pooled = torch::max_pool1d(convOut, convOut.size(2)); - windows.emplace_back(pooled); - } + cnnOuts.emplace_back(lettersCNN(permuted[word])); for (unsigned int word = 0; word < focusedStackIndexes.size(); word++) - for (unsigned int i = 0; i < lettersCNNs.size(); i++) - { - auto input = permuted[focusedBufferIndexes.size()+word]; - auto convOut = torch::relu(lettersCNNs[i](input).squeeze(-1)); - auto pooled = torch::max_pool1d(convOut, convOut.size(2)); - windows.emplace_back(pooled); - } - auto lettersCnnOut = torch::cat(windows, 2); - lettersCnnOut = lettersCnnOut.view({lettersCnnOut.size(0), -1}); - - windows.clear(); - for (unsigned int i = 0; i < CNNs.size(); i++) - { - auto convOut = torch::relu(CNNs[i](embeddings).squeeze(-1)); - auto pooled = torch::max_pool1d(convOut, convOut.size(2)); - windows.emplace_back(pooled); - } + cnnOuts.emplace_back(lettersCNN(permuted[word])); + auto lettersCnnOut = torch::cat(cnnOuts, 1); - auto cnnOut = torch::cat(windows, 2); - cnnOut = cnnOut.view({cnnOut.size(0), -1}); + auto contextCnnOut = contextCNN(embeddings); - auto totalInput = torch::cat({cnnOut, lettersCnnOut}, 1); + auto totalInput = torch::cat({contextCnnOut, lettersCnnOut}, 1); return linear2(torch::relu(linear1(totalInput))); } -- GitLab