Newer
Older
#include "CNN.hpp"
#include "CNN.hpp"
/// Builds one 2-D convolution per entry of `windowSizes`.
///
/// Each convolution has 1 input channel and `nbFilters` output channels.
/// Its kernel is `windowSize x elementSize`, and it zero-pads the first
/// spatial dimension by `windowSize - 1` (no padding on the second one).
/// Each convolution is registered as the submodule "cnn_window_<windowSize>".
/// NOTE(review): repeated window sizes would reuse the same submodule name;
/// register_module presumably rejects that — confirm.
///
/// @param windowSizes  kernel heights, one convolution per entry
/// @param nbFilters    number of output channels of every convolution
/// @param elementSize  kernel width (second spatial dimension)
CNNImpl::CNNImpl(std::vector<long> windowSizes, int nbFilters, int elementSize)
: windowSizes(windowSizes), nbFilters(nbFilters), elementSize(elementSize)
{
  for (long width : windowSizes)
  {
    auto options = torch::nn::Conv2dOptions(1, nbFilters, torch::ExpandingArray<2>({width, elementSize}));
    options.padding({width - 1, 0});

    const std::string submoduleName = fmt::format("cnn_window_{}", width);
    auto registered = register_module(submoduleName, torch::nn::Conv2d(options));
    CNNs.emplace_back(registered);
  }
}
/// Applies every window-size convolution to `input`, then keeps the maximum
/// activation of each filter over the whole remaining length.
///
/// Each convolution output has its last dimension squeezed, goes through
/// ReLU, and is max-pooled with a kernel spanning its full third dimension.
/// The per-window results are concatenated along dim 2 and flattened from
/// dim 1 onward.
///
/// NOTE(review): assumes `input` is (batch, 1, seq, elementSize), so that the
/// conv output's last dimension is 1 and squeeze(-1) drops it — confirm with callers.
///
/// @param input  batch of inputs for the 1-input-channel convolutions
/// @return       tensor of shape (batch, nbFilters * number_of_windows)
torch::Tensor CNNImpl::forward(torch::Tensor input)
{
  std::vector<torch::Tensor> windows;
  windows.reserve(CNNs.size());
  for (auto & cnn : CNNs)
  {
    auto convOut = torch::relu(cnn(input).squeeze(-1));
    // Kernel covers the whole length: one max value per filter.
    windows.push_back(torch::max_pool1d(convOut, convOut.size(2)));
  }
  auto cnnOut = torch::cat(windows, 2);
  // flatten(1) matches view({size(0), -1}) for non-empty batches, but unlike
  // view it does not fail when the batch dimension is 0.
  return cnnOut.flatten(1);
}
/// Size of the feature vector produced by forward() for each batch element.
///
/// There is one convolution per window size and each contributes `nbFilters`
/// pooled values, so this is `windowSizes.size() * nbFilters`.
int CNNImpl::getOutputSize()
{
  // Explicit cast: size() is size_t; avoid an implicit unsigned->int narrowing.
  return static_cast<int>(windowSizes.size()) * nbFilters;
}