"...git@gitlab.lis-lab.fr:dev/scikit-multimodallearn.git" did not exist on "96d4fef67f15062b7a1a307f9e5fe57509f66853"
Newer
Older
// Builds a network whose output width depends on the current state.
// nbOutputsPerState maps each state name to the number of outputs forward()
// produces while that state is active. The argument is taken by value and
// moved into the member, so callers passing a temporary pay no copy.
RandomNetworkImpl::RandomNetworkImpl(std::map<std::string,std::size_t> nbOutputsPerState) : nbOutputsPerState(std::move(nbOutputsPerState))
{
}
// Returns a tensor of standard-normal random values sized for the current state.
// A 1-D input is treated as a single sample and given a leading batch dimension.
// The result has shape (input.size(0), nbOutputsPerState.at(getState())) and is
// created on `device` with requires_grad enabled (presumably so callers can run
// a backward pass through it; TODO confirm).
// Throws std::out_of_range if the current state has no registered output size.
torch::Tensor RandomNetworkImpl::forward(torch::Tensor input)
{
  if (input.dim() == 1)
    input = input.unsqueeze(0);

  // at() instead of operator[]: an unknown state must not silently insert a
  // 0-sized entry into the map and yield an empty output tensor.
  auto const nbOutputs = static_cast<int64_t>(nbOutputsPerState.at(getState()));

  return torch::randn({input.size(0), nbOutputs}, torch::TensorOptions().device(device).requires_grad(true));
}
std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &, Dict &) const
{