Skip to content
Snippets Groups Projects
RTLSTMNetwork.cpp 1.28 KiB
Newer Older
Franck Dary's avatar
Franck Dary committed
#include "RTLSTMNetwork.hpp"

/// Builds the embedding -> LSTM -> two-layer MLP network.
/// @param nbOutputs        Size of the final linear layer's output.
/// @param leftBorder       Forwarded to setLeftBorder().
/// @param rightBorder      Forwarded to setRightBorder().
/// @param nbStackElements  Forwarded to setNbStackElements().
RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
{
  // Layer sizes.
  constexpr int embeddingsSize = 30;
  constexpr int lstmOutputSize = 500;
  constexpr int hiddenSize = 500;

  // Context / feature configuration consumed by code outside this file.
  setLeftBorder(leftBorder);
  setRightBorder(rightBorder);
  setNbStackElements(nbStackElements);
  setColumns({"FORM", "UPOS"});

  // Submodules. The registration order is kept unchanged because it fixes the
  // order of parameters()/named_parameters().
  // NOTE(review): the embedding table accepts at most 50000 distinct indices;
  // larger indices will fail at lookup time. Confirm this matches the vocabulary size.
  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
  linear1 = register_module("linear1", torch::nn::Linear(lstmOutputSize, hiddenSize));
  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
  // NOTE(review): `dropout` is registered but not applied in forward() in this
  // file. Confirm whether that is intentional.
  dropout = register_module("dropout", torch::nn::Dropout(0.3));
  // batch_first(true): the LSTM reads and returns tensors laid out as
  // {batch, sequence, features}.
  lstm = register_module("lstm", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, lstmOutputSize).batch_first(true)));
}

/// Runs the network and scores the last position of each sequence.
/// @param input Integer indices looked up in wordEmbeddings. Presumably shaped
///        {batch, sequence}, or {sequence} for a single unbatched example
///        (TODO confirm against the caller).
/// @return Tensor of shape {batch, nbOutputs}, computed from the LSTM output at
///         the last sequence position.
torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input)
{
  // Embedding lookup appends a feature axis of size embeddingsSize.
  auto wordsAsEmb = wordEmbeddings(input);
  // A 2-D result means no batch axis was given: add one of size 1.
  if (wordsAsEmb.dim() == 2)
    wordsAsEmb = torch::unsqueeze(wordsAsEmb, 0);

  // The LSTM is batch_first, so lstmOut is {batch, sequence, lstmOutputSize}.
  auto lstmOut = lstm(wordsAsEmb).output;

  // Last time step of every sequence: {batch, lstmOutputSize}.
  // Same view as lstmOut.permute({1,0,2})[-1], without the intermediate permute.
  auto lastStep = lstmOut.select(1, -1);

  // NOTE(review): the `dropout` module registered in the constructor is not
  // applied here. Confirm whether that is intentional.
  return linear2(torch::relu(linear1(lastStep)));
}