Commit 2d940e47 authored by Franck Dary's avatar Franck Dary
Browse files

Added cnn

parent 567e4969
...@@ -2,22 +2,21 @@ ...@@ -2,22 +2,21 @@
#define CNN__H
#include <torch/torch.h>
#include "MyModule.hpp"

// Multi-window convolutional sequence encoder: one Conv2d per window
// size, each max-pooled over the temporal axis and concatenated
// (the classic CNN-for-text feature extractor).
class CNNImpl : public MyModule
{
  private :

  // One registered convolution per entry in windowSizes.
  std::vector<torch::nn::Conv2d> CNNs;
  // Convolution window heights, in sequence steps.
  // NOTE(review): hard-coded {2, 3} rather than derived from the
  // constructor's ModuleOptions — confirm this is intended.
  std::vector<int> windowSizes{2, 3};
  // Number of filters produced by each convolution.
  int outputSize;

  public :

  // inputSize  : embedding width of each sequence element (kernel width).
  // outputSize : number of filters per window size.
  // options    : presumably shared MyModule options — TODO confirm usage.
  CNNImpl(int inputSize, int outputSize, ModuleOptions options);
  torch::Tensor forward(torch::Tensor input);
  // Result does not depend on sequenceLength: forward() max-pools the
  // whole temporal axis away.
  int getOutputSize(int sequenceLength);
};
TORCH_MODULE(CNN);
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "MyModule.hpp" #include "MyModule.hpp"
#include "LSTM.hpp" #include "LSTM.hpp"
#include "GRU.hpp" #include "GRU.hpp"
#include "CNN.hpp"
#include "Concat.hpp" #include "Concat.hpp"
class HistoryModuleImpl : public Submodule class HistoryModuleImpl : public Submodule
......
#include "CNN.hpp" #include "CNN.hpp"
#include "fmt/core.h" #include "fmt/core.h"
// Builds and registers one Conv2d per configured window size. Each
// kernel spans `windowSize` sequence steps by the full embedding width
// (`inputSize`); vertical padding of windowSize-1 lets sequences
// shorter than the window still produce output columns.
// NOTE(review): `options` is not read here — confirm that is intended.
CNNImpl::CNNImpl(int inputSize, int outputSize, ModuleOptions options) : outputSize(outputSize)
{
  for (auto & windowSize : windowSizes)
  {
    auto name = fmt::format("cnn_window_{}", windowSize);
    auto kernelShape = torch::ExpandingArray<2>({windowSize, inputSize});
    auto convOptions = torch::nn::Conv2dOptions(1, outputSize, kernelShape).padding({windowSize-1, 0});
    CNNs.emplace_back(register_module(name, torch::nn::Conv2d(convOptions)));
  }
}
// Applies every convolution to the input, globally max-pools each
// feature map over its temporal axis, then concatenates and flattens
// the pooled features to (batch, nbWindows * outputSize).
// NOTE(review): assumes input is (batch, sequence, embedding) before
// the channel dimension is inserted — confirm against callers.
torch::Tensor CNNImpl::forward(torch::Tensor input)
{
  // Insert the single input channel expected by Conv2d.
  input = input.unsqueeze(1);

  std::vector<torch::Tensor> pooledFeatures;
  pooledFeatures.reserve(CNNs.size());
  for (auto & cnn : CNNs)
  {
    // Kernel width equals the embedding width, so the last dimension
    // collapses to size 1 and can be squeezed away.
    auto featureMap = cnn(input).squeeze(-1);
    // Global max-pool over the remaining temporal dimension.
    pooledFeatures.emplace_back(torch::max_pool1d(featureMap, featureMap.size(-1)));
  }

  return torch::cat(pooledFeatures, -1).view({input.size(0), -1});
}
std::size_t CNNImpl::getOutputSize() int CNNImpl::getOutputSize(int)
{ {
return windowSizes.size()*nbFilters; return outputSize*windowSizes.size();
} }
...@@ -26,6 +26,8 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin ...@@ -26,6 +26,8 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin
myModule = register_module("myModule", LSTM(inSize, outSize, options)); myModule = register_module("myModule", LSTM(inSize, outSize, options));
else if (subModuleType == "GRU") else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(inSize, outSize, options)); myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "CNN")
myModule = register_module("myModule", CNN(inSize, outSize, options));
else if (subModuleType == "Concat") else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize)); myModule = register_module("myModule", Concat(inSize));
else else
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment