Skip to content
Snippets Groups Projects
ConcatWordsNetwork.cpp 1.31 KiB
Newer Older
  • Learn to ignore specific revisions
    #include "ConcatWordsNetwork.hpp"
    
    // Builds the network: a word-embedding table followed by a single linear layer that
    // maps the concatenated embeddings of a context (getContextSize() words) to nbOutputs scores.
    // Parameters are split into two lists: embedding parameters go to _sparseParameters,
    // linear-layer parameters go to _denseParameters.
    ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs)
    {
      // Number of rows in the embedding table, i.e. word indices must be < this value.
      constexpr int maxNbEmbeddings = 200000;
      constexpr int embeddingsSize = 30;

      // sparse(true): the embedding weight receives sparse gradients, which is presumably why
      // its parameters are kept apart (for an optimizer that accepts sparse gradients).
      auto embeddingOptions = torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize).sparse(true);
      wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(embeddingOptions));
      for (auto & parameter : wordEmbeddings->parameters())
        _sparseParameters.push_back(parameter);

      // The linear layer consumes every context word's embedding laid end to end.
      const auto linearInputSize = getContextSize()*embeddingsSize;
      linear = register_module("linear", torch::nn::Linear(linearInputSize, nbOutputs));
      for (auto & parameter : linear->parameters())
        _denseParameters.push_back(parameter);
    }
    
    std::vector<torch::Tensor> & ConcatWordsNetworkImpl::denseParameters()
    {
      return _denseParameters;
    }
    
    std::vector<torch::Tensor> & ConcatWordsNetworkImpl::sparseParameters()
    {
      return _sparseParameters;
    }
    
    // Scores word contexts: embeds each word, concatenates a context's embeddings and
    // applies the linear layer followed by a softmax.
    //
    // input  : tensor of word indices (presumably int64 — confirm against caller) whose last
    //          dimension is the context, e.g. {batch, context} or unbatched {context}.
    //          The context length must equal getContextSize() so that the concatenated
    //          embeddings match the linear layer's input size.
    // returns: softmax-normalized scores over the nbOutputs classes, with the same leading
    //          dimensions as input, e.g. {batch, nbOutputs} or {nbOutputs}.
    torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
    {
      // {..., context} -> {..., context, embeddingsSize}
      auto wordsAsEmb = wordEmbeddings(input);

      // Concatenate each context's embeddings: {..., context*embeddingsSize}.
      // flatten(-2) merges only the last two dimensions, so it gives the same result as an
      // explicit reshape for batched and unbatched input and also accepts extra leading dims.
      auto concatenated = wordsAsEmb.flatten(-2);

      // Normalize over the output (last) dimension, whatever the number of leading dims.
      // NOTE(review): this returns probabilities, not logits; make sure the loss used by the
      // caller does not apply its own (log-)softmax on top — confirm against caller.
      return torch::softmax(linear(concatenated), -1);
    }