Newer
Older
// Constructs a small network: an embedding lookup followed by a single linear layer.
//
// nbOutputs    : out_features of the final linear layer (number of scores produced).
// focusedIndex : the only entry passed to setBufferContext(); the stack context is
//                left empty, so the features come from that one buffer position
//                (semantics of the buffer/stack context are defined in the base
//                class — NOTE(review): confirm against its declaration).
OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
// Dimension of each embedding vector; also scales the linear layer's input size.
constexpr int embeddingsSize = 64;
setBufferContext({focusedIndex});
setStackContext({});
// Both the FORM and UPOS columns are requested; their indices share the single
// embedding table registered below.
setColumns({"FORM", "UPOS"});
// 50000 rows in the table — presumably an upper bound on the index vocabulary;
// NOTE(review): indices >= 50000 would be out of range for this table — verify
// the upstream dictionary never exceeds it.
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
// Input width = getContextSize() embeddings concatenated per example; this
// assumes getContextSize() counts every index fed to wordEmbeddings
// (context positions x columns) — TODO confirm.
linear = register_module("linear", torch::nn::Linear(getContextSize()*embeddingsSize, nbOutputs));
// Forward pass: embeds every index in `input`, flattens the embeddings of each
// example into one vector, and applies the linear layer.
// input : index tensor, either a single example (1-D) or a batch whose first
//         dimension is the batch size — presumably integer indices as required
//         by torch::nn::Embedding; TODO confirm the caller's dtype.
torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
// A 1-D input is treated as a batch of one by adding a leading dimension.
if (input.dim() == 1)
input = input.unsqueeze(0);
// Embedding output gains a trailing embeddingsSize dimension; view() collapses
// everything after the batch dimension so each row matches the linear layer's
// in_features.
auto wordAsEmb = wordEmbeddings(input).view({input.size(0),-1});
// Shape (batch, nbOutputs); NOTE(review): the return of `res` is not visible
// in this view — presumably it follows this line.
auto res = linear(wordAsEmb);