Skip to content
Snippets Groups Projects
ConcatWordsNetwork.cpp 992 B
#include "ConcatWordsNetwork.hpp"

/// @brief Builds the embedding + two-layer MLP network.
/// @param nbOutputs     Output width of the final linear layer.
/// @param bufferContext Forwarded to setBufferContext().
/// @param stackContext  Forwarded to setStackContext().
///
/// Submodules are registered in a fixed order (embeddings, linear1, linear2,
/// dropout); keep that order so parameter enumeration stays stable.
ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, const std::vector<int> & bufferContext, const std::vector<int> & stackContext)
{
  // Fixed hyper-parameters of this architecture.
  constexpr int kNbEmbeddingRows = 50000;
  constexpr int kEmbeddingDim = 64;
  constexpr int kHiddenDim = 500;
  constexpr double kDropoutProb = 0.3;

  // Context configuration must happen before getContextSize() is queried below.
  setBufferContext(bufferContext);
  setStackContext(stackContext);
  setColumns({"FORM", "UPOS"});

  // Width of the flattened embedding vector fed into the first linear layer.
  const auto concatenatedDim = getContextSize() * kEmbeddingDim;

  auto embeddingOpts = torch::nn::EmbeddingOptions(kNbEmbeddingRows, kEmbeddingDim);
  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(embeddingOpts));
  linear1 = register_module("linear1", torch::nn::Linear(concatenatedDim, kHiddenDim));
  linear2 = register_module("linear2", torch::nn::Linear(kHiddenDim, nbOutputs));
  dropout = register_module("dropout", torch::nn::Dropout(kDropoutProb));
}

/// @brief Embeds the input indices, concatenates them per row, and applies the MLP.
/// @param input Index tensor. A 1-D tensor is treated as a single example.
///              NOTE(review): assumes otherwise (batch, context) of integer
///              indices — confirm against callers.
/// @return Output of linear2, shaped (batch, nbOutputs).
torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
{
  // Promote an unbatched example to a batch of one.
  if (input.dim() == 1)
    input = input.unsqueeze(0);

  const auto batchSize = input.size(0);

  // Look up embeddings, then flatten everything after the batch dimension.
  auto embedded = wordEmbeddings(input);
  auto flattened = embedded.view({batchSize, -1});

  // Dropout is applied only to the concatenated embeddings, not to the hidden layer.
  auto regularized = dropout(flattened);
  auto hidden = torch::relu(linear1(regularized));

  return linear2(hidden);
}