Skip to content
Snippets Groups Projects
RandomNetwork.cpp 468 B
Newer Older
Franck Dary's avatar
Franck Dary committed
#include "RandomNetwork.hpp"

// Builds a network whose output width is fixed at construction time.
//
// outputSize: number of columns in every tensor returned by forward().
//
// The setter calls below configure state that lives outside this file
// (NOTE(review): presumably inherited from the base network class; confirm
// their exact meaning there). forward() never reads the values of its input,
// only its batch size, so this configuration is presumably kept minimal.
RandomNetworkImpl::RandomNetworkImpl(long outputSize)
  : outputSize(outputSize)
{
  // Buffer context receives the single value 0; stack context is empty.
  this->setBufferContext({0});
  this->setStackContext({});

  // Neither buffer nor stack has any focused entries.
  this->setBufferFocused({});
  this->setStackFocused({});

  // Only the "FORM" column is requested.
  this->setColumns({"FORM"});
}

// Produces a tensor of standard-normal samples (torch::randn) with shape
// {batchSize, outputSize}; the values stored in `input` are never read.
//
// input: a 1-D tensor is treated as a single example (batchSize = 1);
//        otherwise batchSize is input.size(0).
// returns: a freshly created tensor on `device` (NOTE(review): presumably
//          the device member provided by the base class) with requires_grad
//          enabled. It does not depend on any module parameter, so no
//          parameter receives a gradient through it. requires_grad is
//          presumably set so that backward() on a loss built from this
//          output does not fail — TODO confirm.
torch::Tensor RandomNetworkImpl::forward(torch::Tensor input)
{
  // An unbatched 1-D input counts as a batch containing one example.
  const auto batchSize = (input.dim() == 1) ? 1 : input.size(0);

  const auto options = torch::TensorOptions().device(device).requires_grad(true);

  return torch::randn({batchSize, outputSize}, options);
}