Skip to content
Snippets Groups Projects
Commit 2d940e47 authored by Franck Dary's avatar Franck Dary
Browse files

Added cnn

parent 567e4969
No related branches found
No related tags found
No related merge requests found
...@@ -2,22 +2,21 @@ ...@@ -2,22 +2,21 @@
#define CNN__H #define CNN__H
#include <torch/torch.h> #include <torch/torch.h>
#include "MyModule.hpp"
class CNNImpl : public torch::nn::Module class CNNImpl : public MyModule
{ {
private : private :
std::vector<int> windowSizes;
std::vector<torch::nn::Conv2d> CNNs; std::vector<torch::nn::Conv2d> CNNs;
int nbFilters; std::vector<int> windowSizes{2, 3};
int elementSize; int outputSize;
public : public :
CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize); CNNImpl(int inputSize, int outputSize, ModuleOptions options);
torch::Tensor forward(torch::Tensor input); torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize(); int getOutputSize(int sequenceLength);
}; };
TORCH_MODULE(CNN); TORCH_MODULE(CNN);
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "MyModule.hpp" #include "MyModule.hpp"
#include "LSTM.hpp" #include "LSTM.hpp"
#include "GRU.hpp" #include "GRU.hpp"
#include "CNN.hpp"
#include "Concat.hpp" #include "Concat.hpp"
class HistoryModuleImpl : public Submodule class HistoryModuleImpl : public Submodule
......
#include "CNN.hpp" #include "CNN.hpp"
#include "fmt/core.h" #include "fmt/core.h"
CNNImpl::CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize) CNNImpl::CNNImpl(int inputSize, int outputSize, ModuleOptions options) : outputSize(outputSize)
: windowSizes(windowSizes), nbFilters(nbFilters), elementSize(elementSize)
{ {
for (auto & windowSize : windowSizes) for (auto & windowSize : windowSizes)
{ {
std::string moduleName = fmt::format("cnn_window_{}", windowSize); std::string moduleName = fmt::format("cnn_window_{}", windowSize);
CNNs.emplace_back(register_module(moduleName, torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFilters, torch::ExpandingArray<2>({windowSize,elementSize})).padding({windowSize-1, 0})))); auto kernel = torch::ExpandingArray<2>({windowSize, inputSize});
auto opts = torch::nn::Conv2dOptions(1, outputSize, kernel).padding({windowSize-1, 0});
CNNs.emplace_back(register_module(moduleName, torch::nn::Conv2d(opts)));
} }
} }
torch::Tensor CNNImpl::forward(torch::Tensor input) torch::Tensor CNNImpl::forward(torch::Tensor input)
{ {
std::vector<torch::Tensor> windows; std::vector<torch::Tensor> windows;
input = input.unsqueeze(1);
for (unsigned int i = 0; i < CNNs.size(); i++) for (unsigned int i = 0; i < CNNs.size(); i++)
{ {
auto convOut = torch::relu(CNNs[i](input).squeeze(-1)); auto convOut = CNNs[i](input).squeeze(-1);
auto pooled = torch::max_pool1d(convOut, convOut.size(2)); auto pooled = torch::max_pool1d(convOut, convOut.size(-1));
windows.emplace_back(pooled); windows.emplace_back(pooled);
} }
auto cnnOut = torch::cat(windows, 2); return torch::cat(windows, -1).view({input.size(0), -1});
cnnOut = cnnOut.view({cnnOut.size(0), -1});
return cnnOut;
} }
std::size_t CNNImpl::getOutputSize() int CNNImpl::getOutputSize(int)
{ {
return windowSizes.size()*nbFilters; return outputSize*windowSizes.size();
} }
...@@ -26,6 +26,8 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin ...@@ -26,6 +26,8 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin
myModule = register_module("myModule", LSTM(inSize, outSize, options)); myModule = register_module("myModule", LSTM(inSize, outSize, options));
else if (subModuleType == "GRU") else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(inSize, outSize, options)); myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "CNN")
myModule = register_module("myModule", CNN(inSize, outSize, options));
else if (subModuleType == "Concat") else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize)); myModule = register_module("myModule", Concat(inSize));
else else
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment