-
Franck Dary authored
ConcatWordsNetwork.cpp 1.04 KiB
#include "ConcatWordsNetwork.hpp"
/// Configures the context borders / stack size and registers the three
/// sub-modules: an embedding table followed by two linear layers.
///
/// @param nbOutputs        Output width of the final linear layer.
/// @param leftBorder       Forwarded to setLeftBorder().
/// @param rightBorder      Forwarded to setRightBorder().
/// @param nbStackElements  Forwarded to setNbStackElements().
ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
{
  constexpr int embeddingsSize = 100;  // width of one embedding vector
  constexpr int vocabularySize = 50000; // number of rows in the embedding table
  constexpr int hiddenSize = 500;       // output width of linear1 / input width of linear2

  // The context configuration is set before getContextSize() is queried below.
  // NOTE(review): getContextSize() is defined elsewhere; presumably it depends
  // on these three values — confirm.
  setLeftBorder(leftBorder);
  setRightBorder(rightBorder);
  setNbStackElements(nbStackElements);

  wordEmbeddings = register_module(
    "word_embeddings",
    torch::nn::Embedding(torch::nn::EmbeddingOptions(vocabularySize, embeddingsSize)));

  // linear1 consumes one embedding per context element, laid out side by side.
  const auto concatenatedSize = getContextSize()*embeddingsSize;
  linear1 = register_module("linear1", torch::nn::Linear(concatenatedSize, hiddenSize));
  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
}
/// Forward pass: embeds the input, merges the trailing dimensions of the
/// embedded tensor, then applies linear1 -> ReLU -> linear2.
///
/// @param input  Tensor fed to the embedding layer.
///               NOTE(review): presumably word indices shaped {batch, sequence}
///               or just {sequence} — confirm against callers.
/// @return The output of linear2.
torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
{
  // Presumably {batch, sequence, embeddings} when a batch is given, and
  // {sequence, embeddings} otherwise — TODO confirm.
  const auto embedded = wordEmbeddings(input);

  torch::Tensor concatenated;
  if (embedded.dim() == 3)
  {
    // Three dimensions: keep the first, merge the last two -> {d0, d1*d2}.
    const auto rows = embedded.size(0);
    const auto cols = embedded.size(1)*embedded.size(2);
    concatenated = torch::reshape(embedded, {rows, cols});
  }
  else
  {
    // Otherwise: merge the first two dimensions into a single one -> {d0*d1}.
    const auto length = embedded.size(0)*embedded.size(1);
    concatenated = torch::reshape(embedded, {length});
  }

  return linear2(torch::relu(linear1(concatenated)));
}