-
Franck Dary authored
Removed pretrainedEmbeddings as a global parameter, instead submodules can now have their own pretrained w2v
Franck Dary authored: Removed pretrainedEmbeddings as a global parameter; instead submodules can now have their own pretrained w2v
RandomNetwork.cpp 921 B
#include "RandomNetwork.hpp"

#include <utility>
/// @brief Build a network that produces random scores, one output set per state.
/// @param name Module name, forwarded to setName().
/// @param nbOutputsPerState Number of output scores to produce for each state name.
RandomNetworkImpl::RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState) : nbOutputsPerState(std::move(nbOutputsPerState))
{
  // Both parameters are taken by value; move instead of copying a second time.
  setName(std::move(name));
}
/// @brief Ignore the input's content and return random scores for the current state.
/// @param input Batch of contexts; a 1-d tensor is treated as a batch of one.
/// @return Tensor of shape (batch, nbOutputsPerState[state]) of normally-distributed
///         values on this module's device, with requires_grad set.
torch::Tensor RandomNetworkImpl::forward(torch::Tensor input)
{
  // Promote a single example to a batch of one so size(0) is the batch dimension.
  if (input.dim() == 1)
    input = input.unsqueeze(0);
  // NOTE(review): operator[] default-inserts 0 for an unknown state, yielding a
  // zero-width result rather than an error — confirm this is intended.
  // static_cast replaces the original C-style cast; value is unchanged.
  return torch::randn({input.size(0), static_cast<long>(nbOutputsPerState[getState()])}, torch::TensorOptions().device(device).requires_grad(true));
}
/// @brief Produce the context fed to forward(). The configuration is ignored:
///        a single dummy context {{0}} is enough, since forward() only reads
///        the batch size from its input.
std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &)
{
  std::vector<std::vector<long>> dummyContext;
  dummyContext.emplace_back(1, 0);
  return dummyContext;
}
// No-op: the random network holds no embeddings to register.
void RandomNetworkImpl::registerEmbeddings()
{
}
// No-op: the random network uses no dictionaries, so there is nothing to save.
void RandomNetworkImpl::saveDicts(std::filesystem::path)
{
}
// No-op: the random network uses no dictionaries, so there is nothing to load.
void RandomNetworkImpl::loadDicts(std::filesystem::path)
{
}
// No-op: no dictionaries exist whose state could change.
void RandomNetworkImpl::setDictsState(Dict::State)
{
}
// No-op: occurrence counting only applies to dictionaries, which this network lacks.
void RandomNetworkImpl::setCountOcc(bool)
{
}
// No-op: there are no dictionary elements to prune.
void RandomNetworkImpl::removeRareDictElements(float)
{
}