-
Franck Dary authored
OneWordNetwork.cpp 706 B
#include "OneWordNetwork.hpp"
/// Build a network that embeds the features of a single buffer position and
/// maps them linearly to `nbOutputs` scores.
///
/// @param nbOutputs    number of output features of the final linear layer.
/// @param focusedIndex the single buffer position handed to setBufferContext.
OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
{
  // Width of each learned embedding vector.
  static constexpr int kEmbeddingDim = 64;
  // Number of rows in the embedding table; input indices must be below this.
  static constexpr int kVocabularySize = 50000;

  // Context selection: exactly one buffer position, an empty stack context,
  // and the "FORM" and "UPOS" columns.
  setBufferContext({focusedIndex});
  setStackContext({});
  setColumns({"FORM", "UPOS"});

  auto embeddingOptions = torch::nn::EmbeddingOptions(kVocabularySize, kEmbeddingDim);
  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(embeddingOptions));

  // Each context slot contributes one embedding, so the linear layer sees them
  // concatenated. NOTE(review): getContextSize() presumably reflects the
  // contexts/columns configured above — confirm in the base class.
  const auto flattenedSize = getContextSize() * kEmbeddingDim;
  linear = register_module("linear", torch::nn::Linear(flattenedSize, nbOutputs));
}
/// Compute output scores for a batch of feature-index vectors.
///
/// @param input integer index tensor for the embedding table. A 1-D tensor is
///        treated as a single example and given a leading batch dimension.
///        NOTE(review): presumably shaped (batch, contextSize) otherwise — confirm.
/// @return scores of shape (batch, nbOutputs) produced by the linear layer.
torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
{
  // Promote a single unbatched example to a batch of one.
  if (input.dim() == 1)
    input = input.unsqueeze(0);

  // Embed every index, then merge all dimensions after the batch dimension.
  // flatten(1) yields the same shape as view({input.size(0), -1}) for
  // non-empty batches. Unlike view, it does not throw on an empty batch,
  // where a -1 dimension cannot be inferred from zero elements.
  auto wordAsEmb = wordEmbeddings(input).flatten(1);

  return linear(wordAsEmb);
}