#include "TestNetwork.hpp"

/// Builds the network: a word-embedding table followed by a linear layer.
///
/// @param nbOutputs     number of output classes (width of the linear layer).
/// @param focusedIndex  sequence position whose embedding is fed to the
///                      linear layer in forward().
TestNetworkImpl::TestNetworkImpl(int nbOutputs, int focusedIndex)
{
  this->focusedIndex = focusedIndex;

  // Fixed table capacity (number of distinct word indices) and vector width.
  constexpr int vocabularySize = 200000;
  constexpr int embeddingDim = 30;

  // Registration order (embeddings first, then linear) is kept as-is because
  // it determines the order of named parameters / serialized state.
  auto embeddingLayer = torch::nn::Embedding(vocabularySize, embeddingDim);
  wordEmbeddings = register_module("word_embeddings", embeddingLayer);

  auto outputLayer = torch::nn::Linear(embeddingDim, nbOutputs);
  linear = register_module("linear", outputLayer);
}

/// Classifies a batch of word sequences using only the word at `focusedIndex`.
///
/// @param input  integer tensor of word indices with shape {batch, sequence}
///               (NOTE(review): batch-first layout taken from the original
///               comment — confirm against the caller). Indices must lie in
///               [0, 200000), the size of the embedding table.
/// @return       {batch, nbOutputs} tensor of probabilities (softmax over dim 1).
///               NOTE(review): if training uses a loss that applies its own
///               (log-)softmax, returning raw logits may be intended instead.
torch::Tensor TestNetworkImpl::forward(torch::Tensor input)
{
  // {batch, sequence} indices -> {batch, sequence, embeddings} vectors.
  auto wordsAsEmb = wordEmbeddings(input);

  // Keep only the focused sequence position: {batch, embeddings}.
  // select() on dim 1 is equivalent to permute({1,0,2}) followed by indexing
  // dim 0, but without the intermediate permuted view. Negative indices count
  // from the end of the sequence; out-of-range indices raise an error.
  auto focused = wordsAsEmb.select(1, focusedIndex);

  // {batch, nbOutputs} scores -> per-row probability distribution.
  return torch::softmax(linear(focused), 1);
}