Skip to content
Snippets Groups Projects
Concat.cpp 283 B
Newer Older
  • Learn to ignore specific revisions
  • Franck Dary's avatar
    Franck Dary committed
    #include "Concat.hpp"
    
    /// Builds the module and records the per-element input size.
    ///
    /// @param inputSize Value copied into the member of the same name; it is
    ///                  read back by getOutputSize().
    ConcatImpl::ConcatImpl(int inputSize)
      : inputSize(inputSize) // initializer target is the member, the argument is the parameter
    {
    }
    
    /// Flattens every dimension after the first into a single dimension.
    ///
    /// @param input Tensor with at least one dimension; dimension 0 is preserved.
    ///              NOTE(review): presumably shaped (batch, sequenceLength, inputSize)
    ///              given getOutputSize() — confirm against callers.
    /// @return A 2-D tensor of shape (input.size(0), product of the remaining dimensions).
    torch::Tensor ConcatImpl::forward(torch::Tensor input)
    {
      // reshape() yields a view whenever the strides permit (same result as view()),
      // and falls back to a copy for non-contiguous inputs, where view() would throw.
      return input.reshape({input.size(0), -1});
    }
    
    /// Returns the product of the given sequence length and the stored inputSize.
    ///
    /// NOTE(review): presumably this is the width of the second dimension produced
    /// by forward() for sequences of that length — confirm against callers.
    ///
    /// @param sequenceLength Factor multiplied by the inputSize member.
    /// @return sequenceLength * inputSize.
    int ConcatImpl::getOutputSize(int sequenceLength)
    {
      const int flattenedSize = inputSize * sequenceLength;

      return flattenedSize;
    }