Newer
Older
#include "RTLSTMNetwork.hpp"
/// Builds the network: a word-embedding table feeding a batch-first LSTM,
/// followed by two linear layers producing `nbOutputs` scores.
/// @param nbOutputs        output size of the final linear layer (linear2).
/// @param leftBorder       forwarded unchanged to setLeftBorder().
/// @param rightBorder      forwarded unchanged to setRightBorder().
/// @param nbStackElements  forwarded unchanged to setNbStackElements().
RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
{
  // Fixed layer sizes.
  constexpr int vocabularySize = 50000; // number of rows in the embedding table
  constexpr int embeddingsSize = 30;
  constexpr int lstmOutputSize = 500;
  constexpr int hiddenSize = 500;
  constexpr double dropoutProbability = 0.3;

  setLeftBorder(leftBorder);
  setRightBorder(rightBorder);
  setNbStackElements(nbStackElements);

  // Option objects are plain configuration; building them early does not touch the RNG.
  auto embeddingOptions = torch::nn::EmbeddingOptions(vocabularySize, embeddingsSize);
  auto lstmOptions = torch::nn::LSTMOptions(embeddingsSize, lstmOutputSize).batch_first(true);

  // Keep this exact construction/registration order: it fixes the sequence of random
  // weight initialisations and the ordering of parameters()/named_modules().
  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(embeddingOptions));
  linear1 = register_module("linear1", torch::nn::Linear(lstmOutputSize, hiddenSize));
  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
  dropout = register_module("dropout", torch::nn::Dropout(dropoutProbability));
  lstm = register_module("lstm", torch::nn::LSTM(lstmOptions));
}
/// Scores a sequence of word indices.
/// @param input  integer word indices fed to the embedding table. A tensor that
///               becomes 2-D after embedding (i.e. a 1-D index sequence) is treated
///               as a batch of one; otherwise presumably {batch, sequence} — TODO
///               confirm against callers.
/// @return       scores of shape {batch, nbOutputs}, computed from the LSTM output
///               at the last time step.
torch::Tensor RTLSTMNetworkImpl::forward(torch::Tensor input)
{
  // Embedding appends a trailing {embeddingsSize} dimension to the index tensor.
  auto embedded = wordEmbeddings(input);

  // Unbatched case: {sequence, embeddings} -> {1, sequence, embeddings}.
  if (embedded.dim() == 2)
    embedded = embedded.unsqueeze(0);

  // The LSTM is batch_first, so its output is {batch, sequence, lstmOutputSize}.
  auto sequenceOutput = lstm(embedded).output;

  // Last time step of every batch element -> {batch, lstmOutputSize}.
  auto lastStep = sequenceOutput.select(1, -1);

  auto hidden = torch::relu(linear1(lastStep));
  // NOTE(review): the `dropout` module registered in the constructor is not applied
  // anywhere in this forward pass — confirm whether that is intended.
  return linear2(hidden);
}