Skip to content
Snippets Groups Projects
Commit 2d940e47 authored by Franck Dary's avatar Franck Dary
Browse files

Added cnn

parent 567e4969
No related branches found
No related tags found
No related merge requests found
......@@ -2,22 +2,21 @@
#define CNN__H
#include <torch/torch.h>
#include "MyModule.hpp"
class CNNImpl : public torch::nn::Module
class CNNImpl : public MyModule
{
private :
std::vector<int> windowSizes;
std::vector<torch::nn::Conv2d> CNNs;
int nbFilters;
int elementSize;
std::vector<int> windowSizes{2, 3};
int outputSize;
public :
CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize);
CNNImpl(int inputSize, int outputSize, ModuleOptions options);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize();
int getOutputSize(int sequenceLength);
};
TORCH_MODULE(CNN);
......
......@@ -6,6 +6,7 @@
#include "MyModule.hpp"
#include "LSTM.hpp"
#include "GRU.hpp"
#include "CNN.hpp"
#include "Concat.hpp"
class HistoryModuleImpl : public Submodule
......
#include "CNN.hpp"
#include "fmt/core.h"
CNNImpl::CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize)
: windowSizes(windowSizes), nbFilters(nbFilters), elementSize(elementSize)
CNNImpl::CNNImpl(int inputSize, int outputSize, ModuleOptions options) : outputSize(outputSize)
{
for (auto & windowSize : windowSizes)
{
std::string moduleName = fmt::format("cnn_window_{}", windowSize);
CNNs.emplace_back(register_module(moduleName, torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFilters, torch::ExpandingArray<2>({windowSize,elementSize})).padding({windowSize-1, 0}))));
auto kernel = torch::ExpandingArray<2>({windowSize, inputSize});
auto opts = torch::nn::Conv2dOptions(1, outputSize, kernel).padding({windowSize-1, 0});
CNNs.emplace_back(register_module(moduleName, torch::nn::Conv2d(opts)));
}
}
torch::Tensor CNNImpl::forward(torch::Tensor input)
{
std::vector<torch::Tensor> windows;
input = input.unsqueeze(1);
for (unsigned int i = 0; i < CNNs.size(); i++)
{
auto convOut = torch::relu(CNNs[i](input).squeeze(-1));
auto pooled = torch::max_pool1d(convOut, convOut.size(2));
auto convOut = CNNs[i](input).squeeze(-1);
auto pooled = torch::max_pool1d(convOut, convOut.size(-1));
windows.emplace_back(pooled);
}
auto cnnOut = torch::cat(windows, 2);
cnnOut = cnnOut.view({cnnOut.size(0), -1});
return cnnOut;
return torch::cat(windows, -1).view({input.size(0), -1});
}
std::size_t CNNImpl::getOutputSize()
int CNNImpl::getOutputSize(int)
{
return windowSizes.size()*nbFilters;
return outputSize*windowSizes.size();
}
......@@ -26,6 +26,8 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin
myModule = register_module("myModule", LSTM(inSize, outSize, options));
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "CNN")
myModule = register_module("myModule", CNN(inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
else
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment