Skip to content
Snippets Groups Projects
CNN.cpp 996 B
Newer Older
  • Learn to ignore specific revisions
  • Franck Dary's avatar
    Franck Dary committed
    #include "CNN.hpp"
    #include "CNN.hpp"
    
    
    /// Builds one Conv2d branch per entry of `windowSizes` and registers each
    /// one as a submodule named "cnn_window_<windowSize>".
    ///
    /// Every branch has 1 input channel and `nbFilters` output channels. Its
    /// kernel is `windowSize` x `elementSize`. Padding is `windowSize-1` on the
    /// first spatial dimension and 0 on the second.
    ///
    /// NOTE(review): a repeated value in `windowSizes` would produce the same
    /// submodule name twice. libtorch's register_module is expected to reject
    /// that — confirm.
    ///
    /// @param windowSizes  kernel heights, one convolution branch per value
    /// @param nbFilters    number of output channels of every branch
    /// @param elementSize  kernel width (presumably the per-position feature size
    ///                     of the input — TODO confirm against callers)
    CNNImpl::CNNImpl(std::vector<int> windowSizes, int nbFilters, int elementSize)
      : windowSizes(windowSizes), nbFilters(nbFilters), elementSize(elementSize)
    {
      // Inside this body, `windowSizes` names the constructor parameter; the
      // member was copied from it above, so both hold the same values.
      for (const int windowSize : windowSizes)
      {
        std::string moduleName = fmt::format("cnn_window_{}", windowSize);
        CNNs.emplace_back(register_module(moduleName, torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFilters, torch::ExpandingArray<2>({windowSize,elementSize})).padding({windowSize-1, 0}))));
      }
    }
    
    /// Runs every convolution branch on `input` and merges the max-pooled
    /// outputs into one flat tensor per batch element.
    ///
    /// Each branch does the following:
    ///   1. Applies the convolution.
    ///   2. Squeezes the trailing axis (this only removes it when its size is 1).
    ///   3. Applies ReLU.
    ///   4. Runs a 1-D max pool whose kernel spans all of axis 2.
    /// The pooled tensors are then concatenated on axis 2 and reshaped to
    /// (size(0), -1).
    ///
    /// @param input  presumably (batch, 1, seqLen, elementSize) to match the
    ///               Conv2d configuration from the constructor — TODO confirm
    /// @return       2-D tensor (size(0), -1); under the assumed input layout,
    ///               the second dimension is nbFilters * number of branches
    torch::Tensor CNNImpl::forward(torch::Tensor input)
    {
      std::vector<torch::Tensor> pooledPerWindow;

      for (auto & conv : CNNs)
      {
        torch::Tensor activated = torch::relu(conv(input).squeeze(-1));
        // A kernel as long as the whole of axis 2 pools it down to a single max.
        pooledPerWindow.emplace_back(torch::max_pool1d(activated, activated.size(2)));
      }

      torch::Tensor merged = torch::cat(pooledPerWindow, 2);
      return merged.view({merged.size(0), -1});
    }
    
    /// Number of output features per batch element: one value per filter for
    /// each configured window size. This presumably matches the second
    /// dimension of forward()'s result — confirm against callers.
    ///
    /// @return windowSizes.size() * nbFilters
    int CNNImpl::getOutputSize()
    {
      // The product is still computed in std::size_t, as before. Only the
      // narrowing to the int return type is now written out explicitly.
      return static_cast<int>(windowSizes.size() * nbFilters);
    }