Skip to content
Snippets Groups Projects
Select Git revision
  • 81fdf35497fbce9dd113c8ad40798f5222f24231
  • master default protected
  • loss
  • producer
4 results

Concat.cpp

Blame
  • Concat.cpp 283 B
    #include "Concat.hpp"
    
    /// Builds the module, remembering `inputSize`; getOutputSize() multiplies
    /// the sequence length by this value.
    ///
    /// @param inputSize value stored in the `inputSize` member.
    ConcatImpl::ConcatImpl(int inputSize)
      : inputSize(inputSize)  // member initialised from the same-named parameter
    {}
    
    /// Flattens every dimension after the first into a single one, returning a
    /// 2-D tensor of shape (input.size(0), product of the remaining sizes).
    ///
    /// Uses reshape() instead of view(): view() throws when `input` is not
    /// contiguous (e.g. the result of a transpose/permute), whereas reshape()
    /// returns a view whenever one is possible and a contiguous copy otherwise.
    /// For contiguous inputs the result is identical to the previous view().
    ///
    /// @param input tensor with at least one dimension; presumably
    ///        (batch, sequenceLength, inputSize), so the output width matches
    ///        getOutputSize(sequenceLength) — TODO confirm against the caller.
    /// @return the flattened tensor.
    torch::Tensor ConcatImpl::forward(torch::Tensor input)
    {
      return input.reshape({input.size(0), -1});
    }
    
    /// Returns sequenceLength * inputSize.
    ///
    /// NOTE(review): presumably this is the per-row width of forward()'s
    /// output when its input is (batch, sequenceLength, inputSize) — confirm.
    ///
    /// @param sequenceLength number of positions being concatenated.
    /// @return the product of `sequenceLength` and the stored `inputSize`.
    int ConcatImpl::getOutputSize(int sequenceLength)
    {
      const int positions = sequenceLength;
      return inputSize * positions;
    }