Skip to content
Snippets Groups Projects
ConcatWordsNetwork.cpp 1.05 KiB
Newer Older
  • Learn to ignore specific revisions
  • #include "ConcatWordsNetwork.hpp"
    
    
    Franck Dary's avatar
    Franck Dary committed
    // Builds the three layers of the network:
    //   - word_embeddings: a 50000 x 100 embedding table created with
    //     sparse(true), so the gradient w.r.t. its weight is a sparse tensor
    //     (whatever optimizer trains this module must accept sparse gradients);
    //   - linear1: maps getContextSize() * 100 inputs (one embedding per
    //     context position, concatenated the way forward() flattens them)
    //     to 500 hidden units;
    //   - linear2: maps the 500 hidden units to nbOutputs values
    //     (forward() applies a ReLU between linear1 and linear2).
    //
    // nbOutputs       : output width of linear2.
    // leftBorder      : forwarded to setLeftBorder().
    // rightBorder     : forwarded to setRightBorder().
    // nbStackElements : forwarded to setNbStackElements().
    //
    // NOTE(review): the three setters are called before getContextSize() is
    // used to size linear1, presumably because the context size is derived
    // from them — confirm in the base class before reordering these calls.
    ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
    {
      constexpr int vocabularySize = 50000; // rows of the embedding table; word indices must lie in [0, vocabularySize)
      constexpr int embeddingsSize = 100;   // dimension of one word embedding
      constexpr int hiddenSize = 500;       // width of the hidden layer between linear1 and linear2

      setLeftBorder(leftBorder);
      setRightBorder(rightBorder);
      setNbStackElements(nbStackElements);

      wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(vocabularySize, embeddingsSize).sparse(true)));

      // One embedding per context position, concatenated into a single input vector.
      linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, hiddenSize));
      linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
    }
    
    torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
    {
      // input dim = {batch, sequence, embeddings}
      auto wordsAsEmb = wordEmbeddings(input);
      // reshaped dim = {batch, sequence of embeddings}
      auto reshaped = wordsAsEmb.dim() == 3 ? torch::reshape(wordsAsEmb, {wordsAsEmb.size(0), wordsAsEmb.size(1)*wordsAsEmb.size(2)}) : torch::reshape(wordsAsEmb, {wordsAsEmb.size(0)*wordsAsEmb.size(1)});
    
    
      return linear2(torch::relu(linear1(reshaped)));