// RandomNetwork.cpp
// Committed by Franck Dary.
#include "RandomNetwork.hpp"

#include <cstdint>
#include <utility>

/// Builds a network that outputs random scores instead of computing them.
///
/// @param nbOutputsPerState Number of output scores to produce for each state
///        name. Taken by value (sink parameter) and moved into the member, so
///        callers passing a temporary pay for no copy at all.
RandomNetworkImpl::RandomNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState) : nbOutputsPerState(std::move(nbOutputsPerState))
{
}

/// Produces random scores in place of a real model's output.
///
/// @param input Batch of input features. A 1-D tensor is treated as a single
///        example and given a leading batch dimension. Only input.size(0)
///        (the batch size) is used; the feature values themselves are ignored.
/// @return A tensor of shape {batch, nbOutputsPerState.at(getState())} filled
///         with standard-normal samples, created on `device` with
///         requires_grad set.
/// @throws std::out_of_range if the current state has no registered output size.
torch::Tensor RandomNetworkImpl::forward(torch::Tensor input)
{
  if (input.dim() == 1)
    input = input.unsqueeze(0);

  // .at() instead of operator[]: an unknown state must fail loudly rather than
  // silently inserting a 0 entry into the map and returning a zero-width tensor.
  const auto nbOutputs = static_cast<std::int64_t>(nbOutputsPerState.at(getState()));

  // requires_grad(true): presumably so that a loss computed from this output can
  // still call backward() although the network has no trainable parameters.
  // NOTE(review): confirm against the training loop.
  return torch::randn({input.size(0), nbOutputs}, torch::TensorOptions().device(device).requires_grad(true));
}

/// Builds the feature context for the current configuration.
///
/// This network never inspects its input, so both the configuration and the
/// dictionary are ignored. The result is always one context made of a single
/// feature whose value is 0.
std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &, Dict &) const
{
  std::vector<std::vector<long>> context(1, std::vector<long>(1, 0));
  return context;
}