Skip to content
Snippets Groups Projects
Commit b50c6ff3 authored by Franck Dary's avatar Franck Dary
Browse files

Made a CNN module

parent 69e871aa
No related branches found
No related tags found
No related merge requests found
#ifndef CNN__H
#define CNN__H
#include <torch/torch.h>
#include "fmt/core.h"
// Multi-window 1D text-convolution module: one Conv2d per window size, each
// producing nbFilters feature maps over a sequence of elements of width
// elementSize. forward() max-pools each map to a single value and
// concatenates, so the output width is windowSizes.size() * nbFilters.
class CNNImpl : public torch::nn::Module
{
private :
// Convolution window heights (number of sequence elements each filter spans).
std::vector<long> windowSizes;
// One registered Conv2d per entry in windowSizes, built in the constructor.
std::vector<torch::nn::Conv2d> CNNs;
// Number of output channels (filters) per convolution.
int nbFilters;
// Width of one sequence element (e.g. the embedding dimension).
int elementSize;
public :
// Builds and registers one Conv2d per window size.
CNNImpl(std::vector<long> windowSizes, int nbFilters, int elementSize);
// Input: a 4D tensor (batch, 1 channel, sequence, elementSize) — presumably;
// confirm against callers. Output: (batch, getOutputSize()).
torch::Tensor forward(torch::Tensor input);
// Flattened feature count produced by forward() per batch element.
int getOutputSize();
};
TORCH_MODULE(CNN);
#endif
...@@ -2,12 +2,12 @@ ...@@ -2,12 +2,12 @@
#define CNNNETWORK__H #define CNNNETWORK__H
#include "NeuralNetwork.hpp" #include "NeuralNetwork.hpp"
#include "CNN.hpp"
class CNNNetworkImpl : public NeuralNetworkImpl class CNNNetworkImpl : public NeuralNetworkImpl
{ {
private : private :
static inline std::vector<long> windowSizes{2,3,4};
static constexpr unsigned int maxNbLetters = 10; static constexpr unsigned int maxNbLetters = 10;
private : private :
...@@ -19,8 +19,8 @@ class CNNNetworkImpl : public NeuralNetworkImpl ...@@ -19,8 +19,8 @@ class CNNNetworkImpl : public NeuralNetworkImpl
torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear1{nullptr};
torch::nn::Linear linear2{nullptr}; torch::nn::Linear linear2{nullptr};
std::vector<torch::nn::Conv2d> CNNs; CNN contextCNN{nullptr};
std::vector<torch::nn::Conv2d> lettersCNNs; CNN lettersCNN{nullptr};
public : public :
......
#include "CNN.hpp"
#include "CNN.hpp"
// Builds one Conv2d per window size and registers it under a unique name so
// its parameters are tracked by the module.
// windowSizes : convolution window heights (taken by value, moved into the member)
// nbFilters   : output channels per convolution
// elementSize : width of one input element (the convolution spans it fully)
CNNImpl::CNNImpl(std::vector<long> windowSizes, int nbFilters, int elementSize)
  : windowSizes(std::move(windowSizes)), nbFilters(nbFilters), elementSize(elementSize)
{
  // Iterate the member (the parameter was moved from above).
  for (const auto & windowSize : this->windowSizes)
  {
    std::string moduleName = fmt::format("cnn_window_{}", windowSize);
    // Kernel covers the full element width; vertical padding of windowSize-1
    // keeps the convolved sequence non-empty even for short inputs.
    CNNs.emplace_back(register_module(moduleName, torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFilters, torch::ExpandingArray<2>({windowSize,elementSize})).padding({windowSize-1, 0}))));
  }
}
// Runs every convolution over the input, ReLUs, max-pools each feature map
// down to a single value, and concatenates the results into one flat vector
// per batch element: shape (batch, windowSizes.size()*nbFilters).
torch::Tensor CNNImpl::forward(torch::Tensor input)
{
  std::vector<torch::Tensor> pooledPerWindow;
  for (auto & cnn : CNNs)
  {
    // squeeze(-1) drops the width dimension (kernel spans the full element).
    auto activated = torch::relu(cnn->forward(input).squeeze(-1));
    // Global max-pool over the whole sequence length.
    pooledPerWindow.emplace_back(torch::max_pool1d(activated, activated.size(2)));
  }
  auto concatenated = torch::cat(pooledPerWindow, 2);
  // Flatten everything but the batch dimension.
  return concatenated.view({concatenated.size(0), -1});
}
// Size of the flat vector forward() yields per batch element:
// one max-pooled scalar per filter, per window size.
int CNNImpl::getOutputSize()
{
  const auto nbWindows = windowSizes.size();
  return nbWindows * nbFilters;
}
...@@ -13,13 +13,10 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i ...@@ -13,13 +13,10 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
setColumns(columns); setColumns(columns);
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
linear1 = register_module("linear1", torch::nn::Linear(nbFilters*windowSizes.size()+nbFiltersLetters*windowSizes.size()*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize)); contextCNN = register_module("contextCNN", CNN(std::vector<long>{2,3,4}, nbFilters, 2*embeddingsSize));
lettersCNN = register_module("lettersCNN", CNN(std::vector<long>{2,3,4,5}, nbFiltersLetters, embeddingsSize));
linear1 = register_module("linear1", torch::nn::Linear(contextCNN->getOutputSize()+lettersCNN->getOutputSize()*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));
linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs)); linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
for (auto & windowSize : windowSizes)
{
CNNs.emplace_back(register_module(fmt::format("cnn_context_{}", windowSize), torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFilters, torch::ExpandingArray<2>({windowSize,2*embeddingsSize})).padding({windowSize-1, 0}))));
lettersCNNs.emplace_back(register_module(fmt::format("cnn_letters_{}", windowSize), torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFiltersLetters, torch::ExpandingArray<2>({windowSize,embeddingsSize})).padding({windowSize-1, 0}))));
}
} }
torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
...@@ -34,38 +31,16 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) ...@@ -34,38 +31,16 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
auto lettersEmbeddings = wordEmbeddings(wordLetters).view({wordLetters.size(0), wordLetters.size(1)/maxNbLetters, maxNbLetters, wordEmbeddings->options.embedding_dim()}).unsqueeze(1); auto lettersEmbeddings = wordEmbeddings(wordLetters).view({wordLetters.size(0), wordLetters.size(1)/maxNbLetters, maxNbLetters, wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
auto permuted = lettersEmbeddings.permute({2,0,1,3,4}); auto permuted = lettersEmbeddings.permute({2,0,1,3,4});
std::vector<torch::Tensor> windows; std::vector<torch::Tensor> cnnOuts;
for (unsigned int word = 0; word < focusedBufferIndexes.size(); word++) for (unsigned int word = 0; word < focusedBufferIndexes.size(); word++)
for (unsigned int i = 0; i < lettersCNNs.size(); i++) cnnOuts.emplace_back(lettersCNN(permuted[word]));
{
auto input = permuted[word];
auto convOut = torch::relu(lettersCNNs[i](input).squeeze(-1));
auto pooled = torch::max_pool1d(convOut, convOut.size(2));
windows.emplace_back(pooled);
}
for (unsigned int word = 0; word < focusedStackIndexes.size(); word++) for (unsigned int word = 0; word < focusedStackIndexes.size(); word++)
for (unsigned int i = 0; i < lettersCNNs.size(); i++) cnnOuts.emplace_back(lettersCNN(permuted[focusedBufferIndexes.size()+word]));
{ auto lettersCnnOut = torch::cat(cnnOuts, 1);
auto input = permuted[focusedBufferIndexes.size()+word];
auto convOut = torch::relu(lettersCNNs[i](input).squeeze(-1));
auto pooled = torch::max_pool1d(convOut, convOut.size(2));
windows.emplace_back(pooled);
}
auto lettersCnnOut = torch::cat(windows, 2);
lettersCnnOut = lettersCnnOut.view({lettersCnnOut.size(0), -1});
windows.clear();
for (unsigned int i = 0; i < CNNs.size(); i++)
{
auto convOut = torch::relu(CNNs[i](embeddings).squeeze(-1));
auto pooled = torch::max_pool1d(convOut, convOut.size(2));
windows.emplace_back(pooled);
}
auto cnnOut = torch::cat(windows, 2); auto contextCnnOut = contextCNN(embeddings);
cnnOut = cnnOut.view({cnnOut.size(0), -1});
auto totalInput = torch::cat({cnnOut, lettersCnnOut}, 1); auto totalInput = torch::cat({contextCnnOut, lettersCnnOut}, 1);
return linear2(torch::relu(linear1(totalInput))); return linear2(torch::relu(linear1(totalInput)));
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment