// ConcatWordsNetwork.cpp (992 B)
// Authored by Franck Dary
#include "ConcatWordsNetwork.hpp"
/// Builds the network: one embedding table followed by two linear layers,
/// with a dropout module that forward() applies to the flattened embeddings.
///
/// @param nbOutputs      Output width of the final linear layer.
/// @param bufferContext  Forwarded unchanged to setBufferContext().
/// @param stackContext   Forwarded unchanged to setStackContext().
ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, const std::vector<int> & bufferContext, const std::vector<int> & stackContext)
{
  // Fixed hyper-parameters of this architecture.
  constexpr int embeddingDim = 64;
  constexpr int hiddenDim = 500;
  constexpr int nbEmbeddings = 50000;
  constexpr double dropoutProbability = 0.3;

  // Context and columns are configured before getContextSize() is queried
  // below. NOTE(review): presumably getContextSize() depends on these
  // setters -- confirm against the base class before reordering.
  setBufferContext(bufferContext);
  setStackContext(stackContext);
  setColumns({"FORM", "UPOS"});

  // Registration order is kept as-is: embeddings, linear1, linear2, dropout.
  wordEmbeddings = register_module(
    "word_embeddings",
    torch::nn::Embedding(torch::nn::EmbeddingOptions(nbEmbeddings, embeddingDim)));

  // The first layer consumes one embedding per context element, concatenated.
  const auto concatenatedSize = getContextSize() * embeddingDim;
  linear1 = register_module("linear1", torch::nn::Linear(concatenatedSize, hiddenDim));
  linear2 = register_module("linear2", torch::nn::Linear(hiddenDim, nbOutputs));
  dropout = register_module("dropout", torch::nn::Dropout(dropoutProbability));
}
/// Runs the network on `input` and returns the output of the last linear layer.
///
/// A 1-dimensional input is treated as a single example and is given a
/// leading dimension of size 1 before processing.
/// NOTE(review): presumably `input` holds integer indices into the embedding
/// table -- confirm against the caller.
///
/// @param input  Tensor passed to the embedding layer.
/// @return       linear2(relu(linear1(dropout(flattened embeddings)))).
torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
{
  const bool isSingleExample = (input.dim() == 1);
  if (isSingleExample)
    input = input.unsqueeze(0);

  // Look up the embeddings, then flatten everything past the first
  // dimension so each row is the concatenation of its embeddings.
  const auto nbRows = input.size(0);
  auto embedded = wordEmbeddings(input);
  auto concatenated = embedded.view({nbRows, -1});
  auto regularized = dropout(concatenated);

  auto hidden = torch::relu(linear1(regularized));
  return linear2(hidden);
}