// ConcatWordsNetwork.hpp — declaration of ConcatWordsNetworkImpl.
#ifndef CONCATWORDSNETWORK__H
#define CONCATWORDSNETWORK__H

#include "NeuralNetwork.hpp"

/// @brief Neural network module built from a word-embedding table and a
///        linear layer, implementing the NeuralNetworkImpl interface.
///
/// Only the declaration lives in this header; the constructor and the
/// overridden member functions are defined elsewhere.
/// NOTE(review): judging by the class name, forward() presumably
/// concatenates word embeddings before the linear layer — confirm against
/// the implementation file.
class ConcatWordsNetworkImpl : public NeuralNetworkImpl
{
  private :

  /// Embedding sub-module. Initialised from nullptr here, so it must be
  /// given a real module before use (presumably in the constructor — TODO confirm).
  torch::nn::Embedding wordEmbeddings{nullptr};
  /// Linear sub-module. Initialised from nullptr here, same caveat as above.
  torch::nn::Linear linear{nullptr};

  /// Tensor storage; presumably what denseParameters() returns — TODO confirm.
  std::vector<torch::Tensor> _denseParameters;
  /// Tensor storage; presumably what sparseParameters() returns — TODO confirm.
  std::vector<torch::Tensor> _sparseParameters;

  public :

  /// @brief Builds the network.
  /// @param nbOutputs Presumably the size of the output layer — confirm
  ///        against the constructor definition.
  ///
  /// Marked explicit so that an int is never silently converted into a
  /// network object.
  explicit ConcatWordsNetworkImpl(int nbOutputs);

  /// @brief Runs the network on @p input and returns the resulting tensor.
  /// @param input Input tensor; expected shape and dtype are not visible
  ///        from this header.
  torch::Tensor forward(torch::Tensor input) override;

  /// @brief Returns a mutable reference to a vector of tensors
  ///        (override of the NeuralNetworkImpl interface).
  std::vector<torch::Tensor> & denseParameters() override;

  /// @brief Returns a mutable reference to a vector of tensors
  ///        (override of the NeuralNetworkImpl interface).
  std::vector<torch::Tensor> & sparseParameters() override;
};

#endif