#include "ConcatWordsNetwork.hpp"

/// @brief Builds a feed-forward scorer that looks up an embedding for every
///        context index, concatenates those embeddings per example, and maps
///        them through one ReLU hidden layer to `nbOutputs` scores.
/// @param nbOutputs     Output width of the final linear layer.
/// @param bufferContext Passed unchanged to setBufferContext().
/// @param stackContext  Passed unchanged to setStackContext().
ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, const std::vector<int> & bufferContext, const std::vector<int> & stackContext)
{
  constexpr int kVocabularySize = 50000;   // rows in the embedding table
  constexpr int kEmbeddingDim = 64;        // width of each embedding vector
  constexpr int kHiddenDim = 500;          // units in the hidden layer
  constexpr double kDropoutProbability = 0.3;

  // Configured before any layer is sized, because getContextSize() is queried
  // below. NOTE(review): presumably getContextSize() is derived from these
  // settings — confirm in the base class before reordering.
  setBufferContext(bufferContext);
  setStackContext(stackContext);
  setColumns({"FORM", "UPOS"});

  // Registration order fixes the order of named_parameters()/named_children();
  // keep it stable so saved state stays compatible.
  //
  // A single embedding table is shared by every context index, whichever
  // column the index came from.
  wordEmbeddings = register_module("word_embeddings",
      torch::nn::Embedding(torch::nn::EmbeddingOptions(kVocabularySize, kEmbeddingDim)));

  // linear1 consumes one concatenated embedding vector per example.
  const auto concatenatedDim = getContextSize() * kEmbeddingDim;
  linear1 = register_module("linear1", torch::nn::Linear(concatenatedDim, kHiddenDim));
  linear2 = register_module("linear2", torch::nn::Linear(kHiddenDim, nbOutputs));
  dropout = register_module("dropout", torch::nn::Dropout(kDropoutProbability));
}

/// @brief Scores a batch of feature-index rows.
/// @param input Index tensor. A 1-D tensor is treated as one example and gets
///              a leading batch dimension of size 1.
///              NOTE(review): assumes an integer (kLong) tensor of shape
///              (batch, getContextSize()) whose values are < 50000 — confirm
///              with callers.
/// @return One row per example, with width equal to linear2's output size
///         (nbOutputs). The scores are unnormalised: no softmax is applied here.
torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
{
  // Promote a lone example to a batch of one so the flattening below is uniform.
  torch::Tensor batch = (input.dim() == 1) ? input.unsqueeze(0) : input;

  // Embed every index, then flatten each example's embeddings into one vector.
  const auto batchSize = batch.size(0);
  torch::Tensor concatenated = wordEmbeddings(batch).view({batchSize, -1});

  // Dropout on the concatenated embeddings. torch::nn::Dropout is active only
  // while the module is in training mode.
  torch::Tensor hidden = torch::relu(linear1(dropout(concatenated)));

  return linear2(hidden);
}