#include "RandomNetwork.hpp"

#include <cstdint>
#include <utility>

/// @brief Build a network whose output is random and independent of its input values.
/// @param name Name passed to setName().
/// @param nbOutputsPerState Number of output columns to produce for each state name
///        (looked up with getState() in forward()).
RandomNetworkImpl::RandomNetworkImpl(std::string name, std::map<std::string, std::size_t> nbOutputsPerState)
  : nbOutputsPerState(std::move(nbOutputsPerState)) // by-value sink parameter: move, don't copy
{
  setName(name);
}

/// @brief Return a tensor of standard-normal random values, one row per input example.
///
/// Only the batch size of @p input is used; its contents are ignored.
/// @param input Batch of inputs; a 1-D tensor is treated as a single example.
/// @return Tensor of shape [batch, nbOutputsPerState[getState()]] on `device`,
///         with requires_grad set.
/// @throws std::out_of_range if the current state has no entry in nbOutputsPerState.
torch::Tensor RandomNetworkImpl::forward(torch::Tensor input)
{
  // Give a 1-D input a leading batch dimension so size(0) is the batch size.
  if (input.dim() == 1)
    input = input.unsqueeze(0);

  // Use .at() rather than operator[]: operator[] would silently insert a 0 entry for an
  // unknown state and yield a zero-width tensor instead of reporting the error.
  const auto nbOutputs = static_cast<std::int64_t>(nbOutputsPerState.at(getState()));

  // NOTE(review): requires_grad(true) presumably lets callers run backward() on this
  // output even though nothing here is learned — confirm with the training loop.
  return torch::randn({input.size(0), nbOutputs},
                      torch::TensorOptions().device(device).requires_grad(true));
}

/// @brief Return a single, constant context {{0}} regardless of the configuration.
std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &)
{
  return std::vector<std::vector<long>>{{0}};
}

/// @brief No-op: this network registers no embeddings.
void RandomNetworkImpl::registerEmbeddings()
{
}

/// @brief No-op: nothing is written.
void RandomNetworkImpl::saveDicts(std::filesystem::path)
{
}

/// @brief No-op: nothing is read.
void RandomNetworkImpl::loadDicts(std::filesystem::path)
{
}

/// @brief No-op: the requested state is ignored.
void RandomNetworkImpl::setDictsState(Dict::State)
{
}

/// @brief No-op: the flag is ignored.
void RandomNetworkImpl::setCountOcc(bool)
{
}

/// @brief No-op: the threshold is ignored.
void RandomNetworkImpl::removeRareDictElements(float)
{
}