Newer
Older
// Constructs the test network: registers a sparse word-embedding module and a
// linear module, and records their parameters in two separate lists
// (_sparseParameters for the embeddings, _denseParameters for the linear layer).
//
// nbOutputs    : output width of the linear layer.
// focusedIndex : not used in the visible lines of this constructor.
//
// NOTE(review): the braces and any member-initializer list of this constructor
// are not visible in this excerpt, so how focusedIndex is stored cannot be
// confirmed from here.
TestNetworkImpl::TestNetworkImpl(int nbOutputs, int focusedIndex)
// Embedding table with 200000 rows of width embeddingsSize; sparse(true) asks
// for sparse gradients on the embedding weights.
// NOTE(review): 200000 is presumably the vocabulary size -- confirm.
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, embeddingsSize).sparse(true)));
// Keep the embedding parameters in their own list, apart from the dense ones.
auto params = wordEmbeddings->parameters();
_sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
// Linear layer mapping embeddingsSize features to nbOutputs values.
linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs));
// `params` is reused here for the linear layer's parameters.
params = linear->parameters();
_denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
std::vector<torch::Tensor> & TestNetworkImpl::denseParameters()
{
return _denseParameters;
}
std::vector<torch::Tensor> & TestNetworkImpl::sparseParameters()
{
return _sparseParameters;
}
// Forward pass: looks up the embeddings for `input`, selects the slice at
// focusedIndex along the first axis, applies the linear layer and then a softmax.
//
// NOTE(review): this excerpt is incomplete -- the function braces, the
// declaration of `reshaped` and the return statement are not visible here.
torch::Tensor TestNetworkImpl::forward(torch::Tensor input)
// `input` is passed to the embedding lookup, so it presumably holds word indices
// shaped {batch, sequence} (or {sequence} when unbatched); the lookup result
// would then be {batch, sequence, embeddings} -- TODO confirm against callers.
auto wordsAsEmb = wordEmbeddings(input);
// In the 3-D case, swap the first two axes so that indexing the first axis
// below selects by focusedIndex rather than by the original first axis.
if (reshaped.dim() == 3)
reshaped = wordsAsEmb.permute({1,0,2});
// Index the first axis with focusedIndex, apply the linear layer, then softmax
// over dim 1 in the 3-D case and dim 0 otherwise -- presumably the axis of the
// linear layer's outputs in both cases; TODO confirm.
auto res = torch::softmax(linear(reshaped[focusedIndex]), reshaped.dim() == 3 ? 1 : 0);