-
Franck Dary authored
OneWordNetwork.cpp 1.01 KiB
#include "OneWordNetwork.hpp"
/// Builds a network that classifies a single word taken from a context window.
///
/// @param nbOutputs     Output size of the final linear layer.
/// @param focusedIndex  Signed offset of the word to classify. A negative value
///                      requires context on the left, a positive value requires
///                      context on the right, and 0 requires none.
OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
{
  // Width of each word vector, and number of rows in the lookup table.
  constexpr int embeddingsSize = 30;
  constexpr int vocabularySize = 200000;

  // Lookup table mapping word indices to dense vectors.
  wordEmbeddings = register_module("word_embeddings",
    torch::nn::Embedding(torch::nn::EmbeddingOptions(vocabularySize, embeddingsSize)));
  // Projects one embedded word onto nbOutputs values.
  linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs));

  // Split the signed offset into two non-negative borders: a negative offset
  // only widens the left side, a positive one only the right side.
  // NOTE(review): presumably these count the context words needed on each side
  // of the current word — confirm against setLeftBorder/setRightBorder.
  const int leftBorder = (focusedIndex < 0) ? -focusedIndex : 0;
  const int rightBorder = (focusedIndex > 0) ? focusedIndex : 0;

  // Position of the focused word along the sequence axis used by forward();
  // non-positive offsets are clamped to 0.
  this->focusedIndex = (focusedIndex > 0) ? focusedIndex : 0;

  setLeftBorder(leftBorder);
  setRightBorder(rightBorder);
  // Requests zero stack elements (presumably this network reads none).
  setNbStackElements(0);
}
/// Scores the word at position `focusedIndex` of the input window.
///
/// @param input  Word indices fed to the embedding table (not embeddings).
///               Presumably shaped {batch, sequence}, or {sequence} for a
///               single unbatched window — TODO confirm against callers.
/// @return       The `linear` layer applied to the embedding of the focused
///               word. The sequence axis is removed, so a {batch, sequence}
///               input yields {batch, <linear output size>}.
torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
{
  // The lookup appends an embedding axis: {..., sequence} -> {..., sequence, embeddings}.
  auto wordsAsEmb = wordEmbeddings(input);

  // The sequence axis is therefore always second-to-last. Selecting along it
  // directly covers both the batched and the unbatched case, and any extra
  // leading batch dims, without permuting the tensor first.
  auto focusedEmb = wordsAsEmb.select(-2, focusedIndex);

  return linear(focusedEmb);
}