Skip to content
Snippets Groups Projects
ConcatWordsNetwork.cpp 1002 B
Newer Older
  • Learn to ignore specific revisions
  • #include "ConcatWordsNetwork.hpp"
    
    
    /// Builds a feed-forward classifier over concatenated word embeddings.
    ///
    /// Registers, in order: one embedding table ("word_embeddings"), a first
    /// linear layer ("linear1") of width 500, a second linear layer ("linear2")
    /// producing `nbOutputs` scores, and a dropout module ("dropout", p = 0.3).
    ///
    /// @param nbOutputs     Output width of the final linear layer.
    /// @param bufferContext Passed unchanged to setBufferContext().
    /// @param stackContext  Passed unchanged to setStackContext().
    ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, const std::vector<int> & bufferContext, const std::vector<int> & stackContext)
    {
      // Fixed hyper-parameters of this architecture.
      constexpr int embeddingsSize = 64;
      constexpr int hiddenSize = 500;

      // Context and columns are configured before the layers are built;
      // getContextSize() below presumably depends on them (TODO confirm).
      setBufferContext(bufferContext);
      setStackContext(stackContext);

      // Values of both columns go through the single embedding table below.
      // NOTE(review): assumes FORM and UPOS ids share one index space — confirm.
      setColumns({"FORM", "UPOS"});

      wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));

      // forward() flattens every embedded index of an example into one vector,
      // so linear1's input width must be (indices per example) * embeddingsSize.
      // NOTE(review): assumes getContextSize() returns that index count — confirm.
      linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, hiddenSize));
      linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));

      dropout = register_module("dropout", torch::nn::Dropout(0.3));
    }
    
    torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
    {
    
      if (input.dim() == 1)
        input = input.unsqueeze(0);
      auto wordsAsEmb = dropout(wordEmbeddings(input).view({input.size(0), -1}));
      return linear2(torch::relu(linear1(wordsAsEmb)));