Skip to content
Snippets Groups Projects
Concat.cpp 283 B
#include "Concat.hpp"

/// Builds a Concat module and records `inputSize` for later use by
/// getOutputSize().
///
/// @param inputSize Size of one sequence element (presumably its feature /
///                  embedding dimension — confirm against callers).
ConcatImpl::ConcatImpl(int inputSize)
  : inputSize(inputSize)
{}

/// Flattens `input` into a 2-D tensor: dimension 0 is preserved and every
/// remaining dimension is merged into a single trailing one.
///
/// @param input Tensor to flatten (presumably laid out as
///              (batch, sequence, features) — confirm against callers).
/// @return Tensor of shape (input.size(0), product of the other dimensions).
torch::Tensor ConcatImpl::forward(torch::Tensor input)
{
  // reshape() instead of view(): view() throws when the tensor's strides do
  // not allow the flattening without a copy (e.g. after transpose/permute).
  // reshape() returns the very same view whenever one is possible and only
  // falls back to a copy otherwise, so existing callers are unaffected.
  return input.reshape({input.size(0), -1});
}

/// Computes the flattened width for a sequence of the given length, i.e.
/// `sequenceLength` multiplied by the `inputSize` stored at construction.
///
/// @param sequenceLength Number of elements in the sequence.
/// @return sequenceLength * inputSize.
int ConcatImpl::getOutputSize(int sequenceLength)
{
  const int flattenedWidth = sequenceLength * inputSize;

  return flattenedWidth;
}