Commit 2d940e47 authored by Franck Dary

Added cnn

parent 567e4969
@@ -2,22 +2,21 @@
 #define CNN__H
 #include <torch/torch.h>
 #include "MyModule.hpp"
-class CNNImpl : public torch::nn::Module
+class CNNImpl : public MyModule
 {
   private :
-  std::vector<int> windowSizes;
   std::vector<torch::nn::Conv2d> CNNs;
-  int nbFilters;
-  int elementSize;
+  std::vector<int> windowSizes{2, 3};
+  int outputSize;
   public :
-  CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize);
+  CNNImpl(int inputSize, int outputSize, ModuleOptions options);
   torch::Tensor forward(torch::Tensor input);
-  std::size_t getOutputSize();
+  int getOutputSize(int sequenceLength);
 };
 TORCH_MODULE(CNN);
@@ -6,6 +6,7 @@
 #include "MyModule.hpp"
 #include "LSTM.hpp"
 #include "GRU.hpp"
+#include "CNN.hpp"
 #include "Concat.hpp"
 class HistoryModuleImpl : public Submodule
#include "CNN.hpp"
#include "fmt/core.h"
CNNImpl::CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize)
: windowSizes(windowSizes), nbFilters(nbFilters), elementSize(elementSize)
CNNImpl::CNNImpl(int inputSize, int outputSize, ModuleOptions options) : outputSize(outputSize)
{
for (auto & windowSize : windowSizes)
{
std::string moduleName = fmt::format("cnn_window_{}", windowSize);
CNNs.emplace_back(register_module(moduleName, torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFilters, torch::ExpandingArray<2>({windowSize,elementSize})).padding({windowSize-1, 0}))));
auto kernel = torch::ExpandingArray<2>({windowSize, inputSize});
auto opts = torch::nn::Conv2dOptions(1, outputSize, kernel).padding({windowSize-1, 0});
CNNs.emplace_back(register_module(moduleName, torch::nn::Conv2d(opts)));
}
}
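A note on the padding above, not part of the commit: padding the time axis by windowSize-1 makes each convolution "wide", so its output length along the sequence is seqLen + 2*(windowSize-1) - windowSize + 1 = seqLen + windowSize - 1, and even a length-1 input still yields poolable positions. A minimal standalone check, with made-up sizes:

#include <torch/torch.h>
#include <cassert>

int main()
{
  // Illustrative sizes only (not taken from the commit).
  int seqLen = 1, inputSize = 16, outputSize = 10, windowSize = 3;
  auto conv = torch::nn::Conv2d(torch::nn::Conv2dOptions(1, outputSize,
      torch::ExpandingArray<2>({windowSize, inputSize})).padding({windowSize-1, 0}));
  auto out = conv(torch::randn({1, 1, seqLen, inputSize}));
  assert(out.size(2) == seqLen + windowSize - 1); // wide convolution: 1 + 3 - 1 == 3
  assert(out.size(3) == 1);                       // kernel spans the whole element dimension
}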
 torch::Tensor CNNImpl::forward(torch::Tensor input)
 {
   std::vector<torch::Tensor> windows;
   input = input.unsqueeze(1);
   for (unsigned int i = 0; i < CNNs.size(); i++)
   {
-    auto convOut = torch::relu(CNNs[i](input).squeeze(-1));
-    auto pooled = torch::max_pool1d(convOut, convOut.size(2));
+    auto convOut = CNNs[i](input).squeeze(-1);
+    auto pooled = torch::max_pool1d(convOut, convOut.size(-1));
     windows.emplace_back(pooled);
   }
-  auto cnnOut = torch::cat(windows, 2);
-  cnnOut = cnnOut.view({cnnOut.size(0), -1});
-  return cnnOut;
+  return torch::cat(windows, -1).view({input.size(0), -1});
 }
-std::size_t CNNImpl::getOutputSize()
+int CNNImpl::getOutputSize(int)
 {
-  return windowSizes.size()*nbFilters;
+  return outputSize*windowSizes.size();
 }
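For reference, here is a minimal standalone sketch of the shape flow this forward pass implements; the batch size, sequence length and feature sizes are made-up example values rather than anything from the commit, and only libtorch and fmt (already used by this file) are assumed:

#include <torch/torch.h>
#include "fmt/core.h"
#include <vector>

int main()
{
  // Example dimensions (arbitrary): 4 sequences of length 7, elements of size 16,
  // 10 filters per window, and the default window sizes {2, 3} from the header.
  int batch = 4, seqLen = 7, inputSize = 16, outputSize = 10;
  std::vector<int> windowSizes{2, 3};

  auto input = torch::randn({batch, seqLen, inputSize}).unsqueeze(1); // (batch, 1, seqLen, inputSize)

  std::vector<torch::Tensor> windows;
  for (auto & windowSize : windowSizes)
  {
    auto opts = torch::nn::Conv2dOptions(1, outputSize, torch::ExpandingArray<2>({windowSize, inputSize}))
                  .padding({windowSize-1, 0});
    torch::nn::Conv2d conv(opts);
    // The kernel spans the whole element dimension, so the trailing dim is 1 and can be squeezed:
    // (batch, outputSize, seqLen+windowSize-1, 1) -> (batch, outputSize, seqLen+windowSize-1)
    auto convOut = conv(input).squeeze(-1);
    // Max-over-time pooling collapses the sequence dimension: (batch, outputSize, 1)
    auto pooled = torch::max_pool1d(convOut, convOut.size(-1));
    windows.emplace_back(pooled);
  }
  // Concatenate the per-window results and flatten: (batch, outputSize * windowSizes.size())
  auto out = torch::cat(windows, -1).view({batch, -1});
  fmt::print("{} x {}\n", out.size(0), out.size(1)); // prints "4 x 20", matching getOutputSize()
}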
@@ -26,6 +26,8 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin
     myModule = register_module("myModule", LSTM(inSize, outSize, options));
   else if (subModuleType == "GRU")
     myModule = register_module("myModule", GRU(inSize, outSize, options));
+  else if (subModuleType == "CNN")
+    myModule = register_module("myModule", CNN(inSize, outSize, options));
   else if (subModuleType == "Concat")
     myModule = register_module("myModule", Concat(inSize));
   else