Skip to content
Snippets Groups Projects
Select Git revision
  • 0c86cb53315b168fc479310660cce1b8d7072b5b
  • master default protected
  • loss
  • producer
4 results

RandomNetwork.cpp

Blame
  • RandomNetwork.cpp 943 B
    #include "RandomNetwork.hpp"

    #include <utility>
    
    // Builds a random network.
    //   name              : forwarded to setName().
    //   nbOutputsPerState : for each state name, the number of output columns
    //                       forward() produces for that state.
    // The map is taken by value and moved into the member so that it is not
    // copied a second time.
    RandomNetworkImpl::RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState) : nbOutputsPerState(std::move(nbOutputsPerState))
    {
      setName(name);
    }
    
    // Returns a tensor of random values (torch::randn) instead of a real prediction.
    //   input : only its dimensionality and first dimension are used; a 1-D
    //           input is treated as a batch of one.
    //   state : key into nbOutputsPerState giving the number of output columns.
    // The result has shape {input.size(0), nbOutputsPerState[state]}, is created
    // on `device` and has requires_grad set. A state absent from the map yields
    // zero columns, as before, but is no longer inserted into the map.
    torch::Tensor RandomNetworkImpl::forward(torch::Tensor input, const std::string & state)
    {
      // Promote a single example to a batch of size one.
      if (input.dim() == 1)
        input = input.unsqueeze(0);

      // Use find() rather than operator[]: the latter would silently add a
      // zero entry to the member map for every unknown state.
      const auto found = nbOutputsPerState.find(state);
      const long nbOutputs = found == nbOutputsPerState.end() ? 0L : static_cast<long>(found->second);

      return torch::randn({input.size(0), nbOutputs}, torch::TensorOptions().device(device).requires_grad(true));
    }
    
    // Ignores the configuration and always returns one context made of the
    // single value 0.
    std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &)
    {
      const std::vector<long> placeholderContext{0};

      std::vector<std::vector<long>> contexts;
      contexts.push_back(placeholderContext);

      return contexts;
    }
    
    // No-op: this implementation registers nothing.
    void RandomNetworkImpl::registerEmbeddings()
    {
      // Intentionally empty.
    }
    
    // No-op: the path is ignored and nothing is written.
    void RandomNetworkImpl::saveDicts([[maybe_unused]] std::filesystem::path path)
    {
      // Intentionally empty.
    }
    
    // No-op: the path is ignored and nothing is read.
    void RandomNetworkImpl::loadDicts([[maybe_unused]] std::filesystem::path path)
    {
      // Intentionally empty.
    }
    
    void RandomNetworkImpl::setDictsState(Dict::State)
    {
    }
    
    // No-op: the flag is ignored.
    void RandomNetworkImpl::setCountOcc([[maybe_unused]] bool countOcc)
    {
      // Intentionally empty.
    }
    
    // No-op: the threshold is ignored and nothing is removed.
    void RandomNetworkImpl::removeRareDictElements([[maybe_unused]] float rarityThreshold)
    {
      // Intentionally empty.
    }