#include "CNN.hpp"

/// @brief Builds one 2-D convolution per requested window size.
///
/// Each convolution has a single input channel, `nbFilters` output channels,
/// a kernel of shape (windowSize, elementSize) and a padding of
/// (windowSize - 1, 0). Every convolution is registered as a sub-module under
/// the name "cnn_window_<windowSize>".
///
/// @param windowSizes Kernel heights; one convolution is created per entry.
/// @param nbFilters   Number of output channels of every convolution.
/// @param elementSize Kernel width of every convolution.
CNNImpl::CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize)
  : windowSizes(windowSizes), nbFilters(nbFilters), elementSize(elementSize)
{
  // NOTE: inside the constructor body `windowSizes` names the parameter, which
  // holds the same values as the member initialised from it above.
  for (const int windowSize : windowSizes)
  {
    const std::string moduleName = fmt::format("cnn_window_{}", windowSize);
    auto options = torch::nn::Conv2dOptions(1, nbFilters, torch::ExpandingArray<2>({windowSize,elementSize})).padding({windowSize-1, 0});
    CNNs.emplace_back(register_module(moduleName, torch::nn::Conv2d(options)));
  }
}

/// @brief Runs every convolution on `input` and concatenates the pooled results.
///
/// For each registered convolution: apply it, drop the last dimension, apply
/// ReLU, then max-pool over the whole of dimension 2. The pooled tensors are
/// concatenated along dimension 2 and flattened to (input batch size, -1).
///
/// @param input Tensor fed to every convolution.
///              NOTE(review): presumably shaped (batch, 1, length, elementSize)
///              so that the last dimension collapses to 1 — confirm with callers.
/// @return A 2-D tensor whose first dimension is the first dimension of the
///         concatenated result.
torch::Tensor CNNImpl::forward(torch::Tensor input)
{
  std::vector<torch::Tensor> windows;
  windows.reserve(CNNs.size());

  for (auto & cnn : CNNs)
  {
    auto convOut = torch::relu(cnn(input).squeeze(-1));
    // Kernel size equals the full extent of dimension 2: one value per filter.
    windows.emplace_back(torch::max_pool1d(convOut, convOut.size(2)));
  }

  auto cnnOut = torch::cat(windows, 2);

  return cnnOut.view({cnnOut.size(0), -1});
}

/// @brief Number of window sizes multiplied by the number of filters.
/// @return windowSizes.size() * nbFilters, as an int.
int CNNImpl::getOutputSize()
{
  // Explicit conversion: size() is unsigned and wider than the int return type.
  return static_cast<int>(windowSizes.size()) * nbFilters;
}