#include "ConcatWordsNetwork.hpp"

/// @brief Builds the network.
///
/// The network has a word-embedding lookup. The embeddings of every word in
/// the context are concatenated and fed to a single linear layer.
/// @param nbOutputs Number of outputs (out_features of the linear layer).
///
/// The embedding table is created with sparse gradients. Its parameters are
/// therefore collected separately from the linear layer's (dense) parameters;
/// see sparseParameters() and denseParameters().
ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs)
{
  // Number of rows in the embedding table: every word index passed to
  // forward() must be < vocabularySize.
  constexpr int vocabularySize = 200000;
  constexpr int embeddingsSize = 30;

  wordEmbeddings = register_module("word_embeddings",
    torch::nn::Embedding(torch::nn::EmbeddingOptions(vocabularySize, embeddingsSize).sparse(true)));
  auto params = wordEmbeddings->parameters();
  _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());

  // forward() flattens {contextSize, embeddingsSize} per sample, so the linear
  // layer's input width is getContextSize()*embeddingsSize.
  // getContextSize() is declared outside this file.
  linear = register_module("linear", torch::nn::Linear(getContextSize()*embeddingsSize, nbOutputs));
  params = linear->parameters();
  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
}

/// @brief Parameters whose gradients are dense (those of the linear layer).
/// @return A mutable reference to the internal list, which is filled by the constructor.
std::vector<torch::Tensor> & ConcatWordsNetworkImpl::denseParameters()
{
  return _denseParameters;
}

/// @brief Parameters whose gradients are sparse (the embedding table, created with sparse(true)).
/// @return A mutable reference to the internal list, which is filled by the constructor.
/// NOTE(review): presumably kept apart so the caller can hand them to an
/// optimizer that supports sparse gradients — confirm against the training code.
std::vector<torch::Tensor> & ConcatWordsNetworkImpl::sparseParameters()
{
  return _sparseParameters;
}

/// @brief Maps a context of word indices to softmax scores.
/// @param input Integer word-index tensor of shape {batch, contextSize}, or
///        {contextSize} for a single unbatched sample. contextSize must make
///        contextSize*30 equal the linear layer's input width, i.e. it must
///        match getContextSize().
/// @return Softmax over the last dimension. The shape is {batch, nbOutputs},
///         or {nbOutputs} for unbatched input.
/// NOTE(review): if the training loss applies its own (log-)softmax, this
/// should probably return raw logits instead — confirm against the caller.
torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
{
  // Only 1-D (single sample) and 2-D (batched) index tensors are supported.
  // Other ranks previously failed inside reshape/size() with an unclear error.
  TORCH_CHECK(input.dim() == 1 || input.dim() == 2,
    "ConcatWordsNetwork::forward expects a 1-D or 2-D tensor of word indices, got ",
    input.dim(), "-D");

  // Embedding output: {batch, contextSize, embeddingsSize} or {contextSize, embeddingsSize}.
  auto wordsAsEmb = wordEmbeddings(input);

  // Concatenate the context's embeddings by merging the last two dimensions:
  // the result is {batch, contextSize*embeddingsSize} or {contextSize*embeddingsSize}.
  auto concatenated = wordsAsEmb.flatten(-2);

  // Whether or not the input is batched, the output dimension is the last one.
  return torch::softmax(linear(concatenated), -1);
}