Skip to content
Snippets Groups Projects
CNN.cpp 997 B
Newer Older
Franck Dary's avatar
Franck Dary committed
#include "CNN.hpp"
#include "CNN.hpp"

/// Builds one 2D convolution per requested window size.
///
/// Each convolution has 1 input channel and `nbFilters` output channels, a
/// kernel of shape (windowSize, elementSize), and a padding of
/// (windowSize - 1, 0). Every convolution is registered as a submodule under
/// the name "cnn_window_<windowSize>".
///
/// @param windowSizes Kernel heights, one convolution is created per entry.
/// @param nbFilters   Number of output channels of every convolution.
/// @param elementSize Kernel width of every convolution.
CNNImpl::CNNImpl(std::vector<long> windowSizes, int nbFilters, int elementSize)
  : windowSizes(windowSizes), nbFilters(nbFilters), elementSize(elementSize)
{
  // Inside this body `windowSizes` names the constructor parameter (it shadows
  // the member of the same name); both hold the same values at this point.
  for (const long height : windowSizes)
  {
    const std::string moduleName = fmt::format("cnn_window_{}", height);

    const torch::ExpandingArray<2> kernelShape({height, elementSize});
    const auto options = torch::nn::Conv2dOptions(1, nbFilters, kernelShape).padding({height - 1, 0});

    CNNs.emplace_back(register_module(moduleName, torch::nn::Conv2d(options)));
  }
}

/// Runs every registered convolution on `input` and concatenates the
/// max-pooled results.
///
/// For each convolution: apply it, squeeze the last dimension, apply ReLU,
/// then max-pool along dimension 2 with a kernel covering that whole
/// dimension. The pooled tensors are concatenated along dimension 2 and the
/// result is reshaped to (size of dimension 0, -1).
///
/// NOTE(review): presumably `input` is (batch, 1, sequence, elementSize) so
/// that the squeeze removes the collapsed width axis -- confirm with callers.
///
/// @param input Tensor fed unchanged to every convolution.
/// @return The concatenated features viewed as a 2-dimensional tensor.
torch::Tensor CNNImpl::forward(torch::Tensor input)
{
  std::vector<torch::Tensor> pooledPerWindow;
  pooledPerWindow.reserve(CNNs.size());

  for (auto & conv : CNNs)
  {
    auto activated = torch::relu(conv(input).squeeze(-1));
    // Pool with a kernel as long as dimension 2, leaving one value per filter.
    const auto poolLength = activated.size(2);
    pooledPerWindow.push_back(torch::max_pool1d(activated, poolLength));
  }

  auto merged = torch::cat(pooledPerWindow, 2);

  return merged.view({merged.size(0), -1});
}

/// Returns the number of window sizes multiplied by the number of filters.
///
/// @return windowSizes.size() * nbFilters, converted to int.
int CNNImpl::getOutputSize()
{
  // The product is computed in the unsigned size type, then converted to int,
  // exactly as the implicit conversion on return would do.
  const auto featureCount = windowSizes.size() * nbFilters;
  return static_cast<int>(featureCount);
}