// OneWordNetwork.hpp
#ifndef ONEWORDNETWORK__H
#define ONEWORDNETWORK__H

#include "NeuralNetwork.hpp"

/// Neural network module derived from NeuralNetworkImpl that holds one
/// embedding layer and one linear layer.
///
/// NOTE(review): the class name suggests the prediction is based on a single
/// word of the input; the implementation is not visible from this header, so
/// confirm against the corresponding source file.
class OneWordNetworkImpl : public NeuralNetworkImpl
{
  public:

  /// Builds the network.
  ///
  /// @param nbOutputs    presumably the number of output units of the linear
  ///                     layer -- TODO confirm in the implementation.
  /// @param focusedIndex presumably the position of the input element this
  ///                     network looks at -- TODO confirm in the implementation.
  OneWordNetworkImpl(int nbOutputs, int focusedIndex);

  /// Forward pass; overrides the virtual declared by the base class.
  ///
  /// @param input tensor handed to the network (expected shape is not
  ///              specified in this header).
  /// @return the tensor produced by the network.
  torch::Tensor forward(torch::Tensor input) override;

  private:

  // Both module holders are initialised with nullptr here; presumably the
  // constructor assigns the actual modules -- verify in the implementation.
  torch::nn::Embedding wordEmbeddings{nullptr};
  torch::nn::Linear linear{nullptr};
};

#endif