#include "TestNetwork.hpp"

namespace
{

// Appends every tensor handle of `source` to the end of `destination`.
void appendParameters(std::vector<torch::Tensor> & destination, const std::vector<torch::Tensor> & source)
{
  destination.insert(destination.end(), source.begin(), source.end());
}

} // namespace

// Builds a word-embedding lookup table followed by a single linear layer.
//
// nbOutputs    : number of output scores produced per batch element.
// focusedIndex : position along the sequence axis whose embedding is fed to
//                the linear layer in forward().
//
// The embedding is created with sparse(true), meaning its weight gradient is a
// sparse tensor; its parameters are therefore collected in _sparseParameters,
// while the linear layer's parameters go to _denseParameters.
TestNetworkImpl::TestNetworkImpl(int nbOutputs, int focusedIndex)
{
  constexpr int vocabularySize = 200000;
  constexpr int embeddingsSize = 30;

  this->focusedIndex = focusedIndex;

  torch::nn::EmbeddingOptions embeddingOptions(vocabularySize, embeddingsSize);
  embeddingOptions.sparse(true);
  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(embeddingOptions));
  appendParameters(_sparseParameters, wordEmbeddings->parameters());

  linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs));
  appendParameters(_denseParameters, linear->parameters());
}

// Parameters registered from the linear layer (dense gradients).
std::vector<torch::Tensor> & TestNetworkImpl::denseParameters()
{
  return _denseParameters;
}

// Parameters registered from the word embedding (sparse gradients).
std::vector<torch::Tensor> & TestNetworkImpl::sparseParameters()
{
  return _sparseParameters;
}

// Scores the word found at `focusedIndex` in each sequence of the batch.
//
// input  : rank-2 tensor of word indices, presumably {batch, sequence}
//          (the embedding must produce a rank-3 result) -- confirm with caller.
// return : {batch, nbOutputs}, softmax-normalised along dim 1.
torch::Tensor TestNetworkImpl::forward(torch::Tensor input)
{
  // {batch, sequence} -> {batch, sequence, embeddings}
  torch::Tensor embedded = wordEmbeddings(input);

  // Pick the focused position on the sequence axis -> {batch, embeddings}.
  // Same view as permute({1,0,2})[focusedIndex].
  torch::Tensor focused = embedded.select(1, focusedIndex);

  return torch::softmax(linear(focused), 1);
}