// ConcatWordsNetwork.cpp — implementation of ConcatWordsNetworkImpl.
#include "ConcatWordsNetwork.hpp"

Franck Dary's avatar
Franck Dary committed
ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
Franck Dary's avatar
Franck Dary committed
  constexpr int embeddingsSize = 100;
  setLeftBorder(leftBorder);
  setRightBorder(rightBorder);
  setNbStackElements(nbStackElements);
Franck Dary's avatar
Franck Dary committed
  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize).sparse(false)));
  auto params = wordEmbeddings->parameters();
Franck Dary's avatar
Franck Dary committed
  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
  linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, 500));
  params = linear1->parameters();
  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
  linear2 = register_module("linear2", torch::nn::Linear(500, nbOutputs));
  params = linear2->parameters();
  _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
}

std::vector<torch::Tensor> & ConcatWordsNetworkImpl::denseParameters()
{
  return _denseParameters;
}

std::vector<torch::Tensor> & ConcatWordsNetworkImpl::sparseParameters()
{
  return _sparseParameters;
}

torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
{
  // input dim = {batch, sequence, embeddings}
  auto wordsAsEmb = wordEmbeddings(input);
  // reshaped dim = {batch, sequence of embeddings}
  auto reshaped = wordsAsEmb.dim() == 3 ? torch::reshape(wordsAsEmb, {wordsAsEmb.size(0), wordsAsEmb.size(1)*wordsAsEmb.size(2)}) : torch::reshape(wordsAsEmb, {wordsAsEmb.size(0)*wordsAsEmb.size(1)});

  return linear2(torch::relu(linear1(reshaped)));