Skip to content
Snippets Groups Projects
OneWordNetwork.cpp 1.01 KiB
Newer Older
  • Learn to ignore specific revisions
  • #include "OneWordNetwork.hpp"
    
    OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
    {
      // Width of every word embedding vector; also the input size of the linear layer.
      constexpr int embeddingsSize = 30;

      // Sub-modules: an embedding table (200000 rows of width embeddingsSize) and a
      // linear projection from one embedding to nbOutputs scores.
      // NOTE(review): 200000 looks like a maximum vocabulary size — confirm.
      // NOTE(review): registration names/order likely matter for saved models — keep stable.
      wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, embeddingsSize)));
      linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs));

      // Only the side that focusedIndex points to gets a non-zero border.
      // Judging by the names, a negative offset presumably reaches left of the
      // current word and a positive one right — confirm against the base class.
      const int leftBorder = focusedIndex < 0 ? -focusedIndex : 0;
      const int rightBorder = focusedIndex > 0 ? focusedIndex : 0;

      // The stored index is the non-negative part of the offset, which is exactly
      // the right border (0 whenever focusedIndex <= 0).
      this->focusedIndex = rightBorder;

      setLeftBorder(leftBorder);
      setRightBorder(rightBorder);
      setNbStackElements(0);
    }
    
    torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
    {
      // `input` holds word indices (it is fed to an Embedding), not embeddings;
      // the lookup appends one trailing dimension of size embeddingsSize.
      auto wordsAsEmb = wordEmbeddings(input);

      // A 3-D result comes from a 2-D index tensor, presumably {batch, sequence},
      // so the words run along dimension 1; otherwise they run along dimension 0.
      // Selecting there is the same as permuting to {sequence, batch, embeddings}
      // and indexing the leading dimension.
      const int wordDim = wordsAsEmb.dim() == 3 ? 1 : 0;

      // Score only the focused word: {batch, nbOutputs} for batched input.
      return linear(wordsAsEmb.select(wordDim, focusedIndex));
    }