Skip to content
Snippets Groups Projects
TestNetwork.cpp 1.22 KiB
Newer Older
#include "TestNetwork.hpp"

Franck Dary's avatar
Franck Dary committed
TestNetworkImpl::TestNetworkImpl(int nbOutputs, int focusedIndex)
Franck Dary's avatar
Franck Dary committed
  constexpr int embeddingsSize = 30;

  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, embeddingsSize).sparse(true)));
  auto params = wordEmbeddings->parameters();
  _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());

Franck Dary's avatar
Franck Dary committed
  linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs));
  params = linear->parameters();
  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());

Franck Dary's avatar
Franck Dary committed
  this->focusedIndex = focusedIndex;
std::vector<torch::Tensor> & TestNetworkImpl::denseParameters()
{
  return _denseParameters;
}

std::vector<torch::Tensor> & TestNetworkImpl::sparseParameters()
{
  return _sparseParameters;
}

Franck Dary's avatar
Franck Dary committed
torch::Tensor TestNetworkImpl::forward(torch::Tensor input)
Franck Dary's avatar
Franck Dary committed
  // input dim = {batch, sequence, embeddings}
  auto wordsAsEmb = wordEmbeddings(input);
Franck Dary's avatar
Franck Dary committed
  auto reshaped = wordsAsEmb;
Franck Dary's avatar
Franck Dary committed
  // reshaped dim = {sequence, batch, embeddings}
Franck Dary's avatar
Franck Dary committed
  if (reshaped.dim() == 3)
    reshaped = wordsAsEmb.permute({1,0,2});
Franck Dary's avatar
Franck Dary committed
  auto res = torch::softmax(linear(reshaped[focusedIndex]), reshaped.dim() == 3 ? 1 : 0);
Franck Dary's avatar
Franck Dary committed
  return res;