Franck Dary authored
Removed the distinction between dense and sparse parameters because it was hurting performance and the speed advantage was not significant
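For context, the removed distinction presumably looked something like the sketch below. This is a hypothetical reconstruction, not code from this repository: because the embedding table is built with .sparse(true), its gradients are sparse tensors, which Adam cannot consume, so the parameters would have been split across two optimizers. buildOptimizers, the name filter, and the learning rates are all illustrative assumptions.

#include <torch/torch.h>
#include "ConcatWordsNetwork.hpp"

// Hypothetical sketch of the removed dense/sparse split. Embedding parameters
// (sparse gradients) go to SGD, which (at least in the Python API) can apply
// sparse updates; all other parameters go to Adam, which cannot.
void buildOptimizers(ConcatWordsNetwork & model)
{
  std::vector<torch::Tensor> sparseParams, denseParams;
  for (auto & p : model->named_parameters())
    if (p.key().find("word_embeddings") != std::string::npos)
      sparseParams.push_back(p.value());
    else
      denseParams.push_back(p.value());
  torch::optim::SGD sparseOptimizer(sparseParams, torch::optim::SGDOptions(0.01));
  torch::optim::Adam denseOptimizer(denseParams, torch::optim::AdamOptions(0.001));
  // Every training step then has to call both sparseOptimizer.step() and
  // denseOptimizer.step(); after this commit a single optimizer over
  // model->parameters() suffices.
}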
ConcatWordsNetwork.cpp 1.05 KiB
#include "ConcatWordsNetwork.hpp"
ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
{
constexpr int embeddingsSize = 100;
setLeftBorder(leftBorder);
setRightBorder(rightBorder);
setNbStackElements(nbStackElements);
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize).sparse(true)));
linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, 500));
linear2 = register_module("linear2", torch::nn::Linear(500, nbOutputs));
}
torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
{
// input dim = {batch, sequence, embeddings}
auto wordsAsEmb = wordEmbeddings(input);
// reshaped dim = {batch, sequence of embeddings}
auto reshaped = wordsAsEmb.dim() == 3 ? torch::reshape(wordsAsEmb, {wordsAsEmb.size(0), wordsAsEmb.size(1)*wordsAsEmb.size(2)}) : torch::reshape(wordsAsEmb, {wordsAsEmb.size(0)*wordsAsEmb.size(1)});
return linear2(torch::relu(linear1(reshaped)));
}
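A minimal usage sketch (not part of the repository), assuming the header declares the usual TORCH_MODULE(ConcatWordsNetwork) holder and that getContextSize() is publicly accessible: it feeds a batch of word-index contexts through the network and prints the shape of the score tensor.

#include <iostream>
#include <torch/torch.h>
#include "ConcatWordsNetwork.hpp"

int main()
{
  // All constructor values here are illustrative, not the project's real configuration.
  ConcatWordsNetwork net(/*nbOutputs=*/20, /*leftBorder=*/2, /*rightBorder=*/2, /*nbStackElements=*/2);
  // A batch of 3 contexts, each a row of getContextSize() word indices in [0, 50000).
  auto input = torch::randint(0, 50000, {3, net->getContextSize()}, torch::kLong);
  auto scores = net->forward(input);
  std::cout << scores.sizes() << std::endl; // expected: [3, 20]
}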