#include "OneWordNetwork.hpp"

/// @brief Builds a small network: one embedding lookup followed by a single linear layer.
///
/// @param nbOutputs    Number of output features produced by the final linear layer.
/// @param focusedIndex The single buffer position passed to setBufferContext().
///
/// The context is configured through base-class setters, which are not visible here:
///   - one buffer position (`focusedIndex`);
///   - no stack positions;
///   - the "FORM" and "UPOS" columns.
/// NOTE(review): presumably these setters determine what getContextSize() returns.
/// That is why they must run before `linear` is sized. Confirm against the base class.
OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
{
  // Width of each embedding vector.
  constexpr int embeddingsSize = 64;
  // Rows in the embedding table. Every index fed to forward() must lie in [0, 50000).
  constexpr int nbEmbeddings = 50000;

  // Context configuration must precede the getContextSize() call below.
  setBufferContext({focusedIndex});
  setStackContext({});
  setColumns({"FORM", "UPOS"});

  torch::nn::EmbeddingOptions embeddingOptions(nbEmbeddings, embeddingsSize);
  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(embeddingOptions));

  // After flattening, each context slot contributes `embeddingsSize` features.
  // `auto` keeps the product in whatever type getContextSize() yields.
  const auto flattenedWidth = getContextSize() * embeddingsSize;
  linear = register_module("linear", torch::nn::Linear(flattenedWidth, nbOutputs));
}

/// @brief Embeds the input indices, flattens them per example and applies the linear layer.
///
/// @param input Tensor of embedding indices. A 1-D tensor is treated as a single
///              example and gets a leading batch dimension of size 1.
/// @return Tensor of shape (batch, nbOutputs).
///
/// Everything after the first dimension is flattened after embedding. The flattened
/// width must equal getContextSize() * 64 to match `linear`'s input size.
torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
{
  torch::Tensor batched = (input.dim() == 1) ? input.unsqueeze(0) : input;
  const auto batchSize = batched.size(0);

  // The embedding output is freshly allocated and contiguous, so view() is valid here.
  torch::Tensor flattened = wordEmbeddings(batched).view({batchSize, -1});
  return linear(flattened);
}