// OneWordNetwork.cpp
// (Recovered from a GitLab file view; page navigation text removed.)
#include "OneWordNetwork.hpp"
/// Builds a network that embeds the configured context elements and maps
/// the concatenated embeddings to `nbOutputs` scores with one linear layer.
///
/// @param nbOutputs     Number of output units of the final linear layer.
/// @param focusedIndex  The single index passed to setBufferContext(); no
///                      stack context is used.
OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
{
  // Width of each embedding vector.
  constexpr int embeddingsSize = 64;
  // Number of rows in the embedding table; every input index fed to
  // forward() must be strictly below this value.
  constexpr int nbEmbeddings = 50000;

  // Context configuration. Presumably getContextSize() (used below) is
  // derived from these settings, so keep these calls before it.
  setBufferContext({focusedIndex});
  setStackContext({});
  setColumns({"FORM", "UPOS"});

  wordEmbeddings = register_module(
      "word_embeddings",
      torch::nn::Embedding(torch::nn::EmbeddingOptions(nbEmbeddings, embeddingsSize)));
  // Input width = one embedding per context element, concatenated
  // (matches the flattening done in forward()).
  linear = register_module(
      "linear",
      torch::nn::Linear(getContextSize() * embeddingsSize, nbOutputs));
}
/// Scores a batch of examples.
///
/// @param input  Tensor of embedding indices. A 1-D tensor is treated as a
///               single example and gets a leading batch dimension of size 1.
/// @return Output of the linear layer, one row per example (first dimension
///         equals the batch size of the possibly-unsqueezed input).
torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
{
  if (input.dim() == 1)
    input = input.unsqueeze(0);

  // Embed every index, then collapse all non-batch dimensions into a single
  // feature vector per example so the width matches the linear layer input.
  auto wordAsEmb = wordEmbeddings(input).view({input.size(0), -1});
  auto res = linear(wordAsEmb);

  return res;
}