Skip to content
Snippets Groups Projects
RandomNetwork.cpp 468 B
Newer Older
  • Learn to ignore specific revisions
  • Franck Dary's avatar
    Franck Dary committed
    #include "RandomNetwork.hpp"
    
    /// Builds a "network" whose forward() output is pure random noise: it only
    /// uses the batch size of its input, never the input values.
    ///
    /// NOTE(review): the set*Context / set*Focused / setColumns calls presumably
    /// configure which parts of the parser state are extracted as input features
    /// (inherited from the base network class — confirm there). Here they select
    /// the smallest configuration: buffer position 0, nothing from the stack,
    /// no focused elements, and only the "FORM" column.
    ///
    /// @param outputSize number of output columns of the tensor returned by forward().
    RandomNetworkImpl::RandomNetworkImpl(long outputSize) : outputSize(outputSize)
    {
      // Minimal context: a single buffer index (0) and no stack elements.
      setBufferContext({0});
      setStackContext({});
      // No focused buffer/stack elements.
      setBufferFocused({});
      setStackFocused({});
      // Only the "FORM" column is requested.
      setColumns({"FORM"});
    }
    
    /// Produces a batch of random scores, ignoring the contents of @p input.
    ///
    /// @param input a single 1-D example (treated as a batch of one) or a
    ///              batched tensor whose first dimension is the batch size.
    /// @return a (batchSize, outputSize) tensor drawn from a standard normal
    ///         distribution, placed on `device`. The tensor has requires_grad set,
    ///         so a loss computed from it can still call backward().
    torch::Tensor RandomNetworkImpl::forward(torch::Tensor input)
    {
      // An unbatched (1-D) example counts as a batch of exactly one row.
      auto const batchSize = (input.dim() == 1) ? 1 : input.size(0);

      auto const options = torch::TensorOptions().device(device).requires_grad(true);

      return torch::randn({batchSize, outputSize}, options);
    }