// ConcatWordsNetwork.hpp
#ifndef CONCATWORDSNETWORK__H
#define CONCATWORDSNETWORK__H

#include "NeuralNetwork.hpp"

/// Network module derived from NeuralNetworkImpl.
///
/// Judging only by its name and members, it presumably embeds the input
/// words (wordEmbeddings), concatenates the embeddings and passes them through
/// two linear layers (linear1, linear2). TODO confirm against the out-of-line
/// definitions, which hold the actual computation.
class ConcatWordsNetworkImpl : public NeuralNetworkImpl
{
  private :

  // Declared as empty module holders ({nullptr}); presumably instantiated and
  // registered by the out-of-line constructor. TODO confirm in the .cpp.
  torch::nn::Embedding wordEmbeddings{nullptr};
  torch::nn::Linear linear1{nullptr};
  torch::nn::Linear linear2{nullptr};

  // The constructor and forward() must be public. Without this specifier they
  // would fall under `private :` above, and the network could not be
  // constructed or run from outside the class.
  public :

  /// Builds the network.
  /// @param nbOutputs        presumably the size of the final output layer (TODO confirm)
  /// @param leftBorder       presumably the number of context words taken to the left (TODO confirm)
  /// @param rightBorder      presumably the number of context words taken to the right (TODO confirm)
  /// @param nbStackElements  presumably the number of stack elements fed to the network (TODO confirm)
  ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);

  /// Runs the network on `input` and returns the resulting tensor.
  /// Overrides the virtual forward() inherited through NeuralNetworkImpl.
  torch::Tensor forward(torch::Tensor input) override;
};

#endif