#include "ConcatWordsNetwork.hpp"

/// Builds a feed-forward scorer over a fixed-size window of word ids:
/// embedding lookup -> concatenation of the window's embeddings
/// -> linear(hiddenSize) -> ReLU -> linear(nbOutputs).
///
/// @param nbOutputs        width of the output layer (scores produced per example).
/// @param leftBorder       forwarded to setLeftBorder().
/// @param rightBorder      forwarded to setRightBorder().
/// @param nbStackElements  forwarded to setNbStackElements().
ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
{
  constexpr int vocabularySize = 50000; // rows of the embedding table; word ids must lie in [0, vocabularySize)
  constexpr int embeddingsSize = 100;   // dimension of each word embedding
  constexpr int hiddenSize = 500;       // width of the hidden layer

  // These setters must run before linear1 is built: linear1's input width uses
  // getContextSize(), which presumably depends on these values (TODO confirm in the base class).
  setLeftBorder(leftBorder);
  setRightBorder(rightBorder);
  setNbStackElements(nbStackElements);

  // The embedding is created with sparse(false) and the linear layers are dense,
  // so every parameter goes into _denseParameters; _sparseParameters is not populated here.
  auto registerDense = [this](const std::vector<torch::Tensor> & params)
  {
    _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
  };

  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocabularySize, embeddingsSize).sparse(false)));
  registerDense(wordEmbeddings->parameters());

  // Input is the concatenation of getContextSize() embeddings.
  linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, hiddenSize));
  registerDense(linear1->parameters());

  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
  registerDense(linear2->parameters());
}

/// @return the parameters that receive dense gradients (all parameters of this network).
std::vector<torch::Tensor> & ConcatWordsNetworkImpl::denseParameters()
{
  return _denseParameters;
}

/// @return the parameters that receive sparse gradients (never populated by this class's constructor).
std::vector<torch::Tensor> & ConcatWordsNetworkImpl::sparseParameters()
{
  return _sparseParameters;
}

/// Scores word-id windows.
///
/// @param input  word ids, {batch, sequence} for a batch or {sequence} for a single example.
///               NOTE(review): sequence is expected to equal getContextSize() so the
///               concatenated width matches linear1 — confirm with callers.
/// @return       raw scores (no softmax applied), {batch, nbOutputs} or {nbOutputs} when unbatched.
torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
{
  // {..., sequence} word ids -> {..., sequence, embeddingsSize}
  auto wordsAsEmb = wordEmbeddings(input);

  // Concatenate the window's embeddings by merging the last two dimensions:
  // {batch, sequence, emb} -> {batch, sequence*emb}, or {sequence, emb} -> {sequence*emb}.
  // Any additional leading dimensions are kept and treated as batch dimensions.
  auto concatenated = wordsAsEmb.flatten(-2, -1);

  return linear2(torch::relu(linear1(concatenated)));
}