Skip to content
Snippets Groups Projects
OneWordNetwork.cpp 1.04 KiB
Newer Older
#include "OneWordNetwork.hpp"
/// @brief Builds a network that classifies a single word of the context
///        window from its embedding alone.
///
/// @param nbOutputs     Output size of the final linear layer.
/// @param focusedIndex  Relative position of the word to classify.
///                      - Negative: leftBorder = -focusedIndex and the word is
///                        read at sequence index 0.
///                      - Positive: rightBorder = focusedIndex and the word is
///                        read at sequence index focusedIndex.
///                      - Zero: both borders are 0 and the index is 0.
///                      NOTE(review): presumably the base class sizes the
///                      extracted context from these borders — confirm.
OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
{
  constexpr int embeddingsSize = 30;
  // Number of rows in the embedding table. Any input index >= this value
  // makes the embedding lookup fail.
  constexpr int nbEmbeddings = 200000;

  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(nbEmbeddings, embeddingsSize)));
  linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs));

  // Only one side of the window is non-empty, depending on the sign of focusedIndex.
  const int leftBorder = focusedIndex < 0 ? -focusedIndex : 0;
  const int rightBorder = focusedIndex > 0 ? focusedIndex : 0;

  // Position of the focused word inside the sequence seen by forward():
  // 0 when it is the left-most element, otherwise focusedIndex (== rightBorder).
  this->focusedIndex = rightBorder;

  setLeftBorder(leftBorder);
  setRightBorder(rightBorder);
  setNbStackElements(0);
  setColumns({"FORM", "UPOS"});
}
/// @brief Embeds the input indices and classifies the word at position
///        `focusedIndex` of the sequence.
///
/// @param input  Integer index tensor passed to the embedding layer.
///               Assumed to be {batch, sequence} (or {sequence} when
///               unbatched) — TODO confirm against callers.
/// @return       Output of the linear layer applied to the focused word's
///               embedding: {batch, nbOutputs} when batched, else {nbOutputs}.
torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
{
  // wordsAsEmb adds a trailing embedding dimension to input, e.g.
  // {batch, sequence} -> {batch, sequence, embeddings}.
  auto wordsAsEmb = wordEmbeddings(input);

  // The sequence dimension is 1 for a batched (3-D) result and 0 otherwise.
  // Selecting along it directly is equivalent to permuting to
  // {sequence, batch, embeddings} and indexing the first dimension,
  // without building the permuted view.
  const int sequenceDim = wordsAsEmb.dim() == 3 ? 1 : 0;
  auto focusedEmb = wordsAsEmb.select(sequenceDim, focusedIndex);

  return linear(focusedEmb);
}