#include "ConcatWordsNetwork.hpp"

/// @brief Builds a feed-forward network over the concatenated embeddings of
///        the words in the current context window.
///
/// Architecture (in registration order):
///   - "word_embeddings": a 50000-entry embedding table, 100 values per entry;
///   - "linear1": (contextSize * 100) -> 500;
///   - "linear2": 500 -> nbOutputs.
///
/// @param nbOutputs        Number of output features produced by forward().
/// @param leftBorder       Forwarded to setLeftBorder().
/// @param rightBorder      Forwarded to setRightBorder().
/// @param nbStackElements  Forwarded to setNbStackElements().
///
/// NOTE(review): the three context parameters presumably decide how many words
/// getContextSize() reports; they must be set before linear1 is sized, which is
/// why the setters run first. Confirm their exact meaning in the header.
ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
{
  constexpr int embeddingsSize = 100;
  constexpr int vocabularySize = 50000;
  constexpr int hiddenSize = 500;

  // Context shape first: getContextSize() below depends on these values.
  setLeftBorder(leftBorder);
  setRightBorder(rightBorder);
  setNbStackElements(nbStackElements);

  // Every context position contributes one embedding vector to linear1's input.
  const auto concatenatedSize = getContextSize() * embeddingsSize;

  // Registration order is significant for parameter ordering / serialization.
  wordEmbeddings = register_module("word_embeddings",
                                   torch::nn::Embedding(torch::nn::EmbeddingOptions(vocabularySize, embeddingsSize)));
  linear1 = register_module("linear1", torch::nn::Linear(concatenatedSize, hiddenSize));
  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
}

/// @brief Embeds every word index, concatenates the embeddings, and scores the
///        result with a ReLU-activated two-layer perceptron.
///
/// @param input Word indices. When batched this is {batch, sequence}, so the
///              embedding output is {batch, sequence, embeddingsSize}; any other
///              rank is treated as a single unbatched sequence.
///              NOTE(review): the batched/unbatched convention is inferred from
///              the rank check below — confirm against callers.
/// @return Output of linear2: {batch, nbOutputs} for batched input,
///         {nbOutputs} for an unbatched sequence.
torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
{
  // Embedding appends a trailing dimension holding each word's vector.
  torch::Tensor embedded = wordEmbeddings(input);

  // Flatten the sequence and embedding axes into one feature axis.
  torch::Tensor concatenated;
  if (embedded.dim() == 3)
  {
    // Batched: {batch, sequence, emb} -> {batch, sequence * emb}.
    concatenated = torch::reshape(embedded, {embedded.size(0), embedded.size(1) * embedded.size(2)});
  }
  else
  {
    // Unbatched: {sequence, emb} -> {sequence * emb}.
    concatenated = torch::reshape(embedded, {embedded.size(0) * embedded.size(1)});
  }

  torch::Tensor hidden = torch::relu(linear1(concatenated));
  return linear2(hidden);
}