#include "OneWordNetwork.hpp"

/// @brief Builds a network that scores an example from the embedding of a single word.
///
/// Every word index fed to forward() is looked up in a learned embedding table,
/// and only the embedding of the "focused" word is passed to a linear layer.
///
/// @param nbOutputs     Number of values produced by the final linear layer.
/// @param focusedIndex  Offset of the focused word relative to the current word:
///                      negative = that many words to the left, positive = that
///                      many words to the right, 0 = the current word itself.
OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
{
  // Rows of the embedding table: every word index given to forward() must be
  // strictly smaller than this value.
  constexpr int nbEmbeddings = 200000;
  constexpr int embeddingsSize = 30;

  // sparse(true): the gradient w.r.t. the embedding weights is a sparse tensor,
  // so the optimizer used with this module must support sparse gradients.
  wordEmbeddings = register_module("word_embeddings",
      torch::nn::Embedding(torch::nn::EmbeddingOptions(nbEmbeddings, embeddingsSize).sparse(true)));
  linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs));

  // Widen the context only on the side where the focused word lies.
  // NOTE(review): assumes setLeftBorder/setRightBorder set how many context words
  // are extracted on each side of the current word — confirm in the base class.
  const int leftBorder = focusedIndex < 0 ? -focusedIndex : 0;
  const int rightBorder = focusedIndex > 0 ? focusedIndex : 0;

  // Assuming the extracted window is ordered left-to-right and starts at the
  // leftmost context word, the focused word sits at index 0 when it is at or left
  // of the current word, and at index `focusedIndex` when it is to the right;
  // both cases equal rightBorder.
  this->focusedIndex = rightBorder;

  setLeftBorder(leftBorder);
  setRightBorder(rightBorder);
  setNbStackElements(0);
}

/// @brief Scores each example from the embedding of its focused word.
///
/// @param input  Word indices whose last dimension is the context window
///               (e.g. {batch, sequence}, or just {sequence} for one example).
///               NOTE(review): dtype presumably int64 as Embedding requires — confirm at call sites.
/// @return       Linear-layer output with the sequence dimension removed, e.g.
///               {batch, nbOutputs} for {batch, sequence} input, {nbOutputs} for {sequence} input.
/// @throws c10::Error if @p input has no sequence dimension, or (from select())
///         if focusedIndex is out of range for the sequence length.
torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
{
  TORCH_CHECK(input.dim() >= 1,
      "OneWordNetwork::forward expects input with at least a sequence dimension, got rank ",
      input.dim());

  // wordsAsEmb dim = {..., sequence, embeddings}
  auto wordsAsEmb = wordEmbeddings(input);

  // The sequence axis is always second to last once embeddings are appended, so
  // select along it directly; select() returns a view, no data is copied.
  auto focusedEmb = wordsAsEmb.select(-2, focusedIndex);

  return linear(focusedEmb);
}