Skip to content
Snippets Groups Projects
OneWordNetwork.cpp 706 B
#include "OneWordNetwork.hpp"

// Builds a network that embeds the features of a single buffer position and
// maps the concatenated embeddings to `nbOutputs` scores with one linear layer.
//
// @param nbOutputs     number of output features produced by the linear layer.
// @param focusedIndex  the only buffer position placed in the context
//                      (the stack context is left empty).
OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
{
  // Width of every embedding vector.
  constexpr int embeddingDim = 64;
  // Number of rows in the embedding table (fixed here, not derived from data).
  constexpr int vocabularySize = 50000;

  // Context definition. These must be set before getContextSize() is queried
  // below, since the linear layer's input width is computed from it.
  setBufferContext({focusedIndex});
  setStackContext({});
  setColumns({"FORM", "UPOS"});

  const auto embeddingOptions = torch::nn::EmbeddingOptions(vocabularySize, embeddingDim);
  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(embeddingOptions));

  // One embedding per context element, concatenated into a single feature row.
  const auto linearInputSize = getContextSize() * embeddingDim;
  linear = register_module("linear", torch::nn::Linear(linearInputSize, nbOutputs));
}

// Scores a batch of context index vectors.
//
// A 1-D input is treated as a single example and given a leading batch
// dimension of size 1. Each example's embeddings are flattened into one row
// before the linear projection.
// NOTE(review): assumes `input` holds integer embedding indices laid out as
// (contextSize) or (batch, contextSize) — confirm against callers.
//
// @param input  index tensor, 1-D (single example) or batched.
// @return       output of the linear layer, one row per batch element.
torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
{
  // Promote an unbatched example to a batch of one.
  auto batched = input.dim() == 1 ? input.unsqueeze(0) : input;

  const auto batchSize = batched.size(0);
  // Collapse every per-example embedding into a single flat feature row.
  auto flattened = wordEmbeddings(batched).view({batchSize, -1});

  return linear(flattened);
}