Newer
Older
TestNetworkImpl::TestNetworkImpl(int nbOutputs, int focusedIndex)
{
  // Sequence position whose embedding forward() feeds to the linear layer.
  this->focusedIndex = focusedIndex;

  // Number of rows in the word-embedding table; input word ids are expected
  // to fall in [0, vocabularySize). Hard-coded here.
  constexpr int vocabularySize = 200000;

  // Lookup table: word id -> vector of size embeddingsSize.
  wordEmbeddings = register_module(
      "word_embeddings",
      torch::nn::Embedding(vocabularySize, embeddingsSize));

  // Projection from one embedding vector to nbOutputs scores.
  linear = register_module(
      "linear",
      torch::nn::Linear(embeddingsSize, nbOutputs));
}
torch::Tensor TestNetworkImpl::forward(torch::Tensor input)
// input dim = {batch, sequence, embeddings}
auto wordsAsEmb = wordEmbeddings(input);
// reshaped dim = {sequence, batch, embeddings}
auto reshaped = wordsAsEmb.permute({1,0,2});
auto res = torch::softmax(linear(reshaped[focusedIndex]), 1);