#include "RTLSTMNetwork.hpp"

/// @brief Build the network: word embeddings -> LSTM -> two-layer MLP head.
/// @param nbOutputs        Width of the final linear layer (number of output scores).
/// @param leftBorder       Forwarded to setLeftBorder() (semantics defined by the base class).
/// @param rightBorder      Forwarded to setRightBorder() (semantics defined by the base class).
/// @param nbStackElements  Forwarded to setNbStackElements() (semantics defined by the base class).
RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
{
  constexpr int kVocabularySize = 50000;
  constexpr int kEmbeddingDim = 30;
  constexpr int kLstmHiddenDim = 500;
  constexpr int kMlpHiddenDim = 500;
  constexpr double kDropoutRate = 0.3;

  // Context configuration handled by the base class.
  setLeftBorder(leftBorder);
  setRightBorder(rightBorder);
  setNbStackElements(nbStackElements);
  setColumns({"FORM", "UPOS"});

  const auto embeddingOptions = torch::nn::EmbeddingOptions(kVocabularySize, kEmbeddingDim);
  // batch_first: the LSTM consumes and produces {batch, sequence, features}.
  const auto lstmOptions = torch::nn::LSTMOptions(kEmbeddingDim, kLstmHiddenDim).batch_first(true);

  // Registration order is kept deliberately: it fixes parameter ordering and the
  // order in which weights are randomly initialized.
  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(embeddingOptions));
  linear1 = register_module("linear1", torch::nn::Linear(kLstmHiddenDim, kMlpHiddenDim));
  linear2 = register_module("linear2", torch::nn::Linear(kMlpHiddenDim, nbOutputs));
  dropout = register_module("dropout", torch::nn::Dropout(kDropoutRate));
  lstm = register_module("lstm", torch::nn::LSTM(lstmOptions));
}

/// @brief Score a sequence of token indices.
/// @param input Token indices; presumably {batch, sequence}, or a single unbatched
///              {sequence} (TODO: confirm with caller).
/// @return Scores of shape {batch, nbOutputs}, computed from the LSTM state at the
///         last time step.
torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input)
{
  // Embedding appends a feature axis: {..., sequence} -> {..., sequence, kEmbeddingDim}.
  auto embedded = wordEmbeddings(input);

  // An unbatched 1-D index sequence becomes 2-D here; add a leading batch axis of size 1.
  if (embedded.dim() == 2)
    embedded = embedded.unsqueeze(0);

  // The LSTM is batch_first, so this is {batch, sequence, lstm hidden}.
  auto sequenceOutput = lstm(embedded).output;

  // Hidden state at the final time step: {batch, lstm hidden}.
  auto lastStep = sequenceOutput.select(1, -1);

  // NOTE(review): `dropout` is registered in the constructor but never applied here;
  // confirm whether that is intentional.
  return linear2(torch::relu(linear1(lastStep)));
}