Skip to content
Snippets Groups Projects
TestNetwork.hpp 413 B
#ifndef TESTNETWORK__H
#define TESTNETWORK__H

#include <torch/torch.h>
#include "Config.hpp"

class TestNetworkImpl : public torch::nn::Module
{
  private :

  torch::nn::Embedding wordEmbeddings{nullptr};
  torch::nn::Linear linear{nullptr};
  int focusedIndex;

  public :

  TestNetworkImpl(int nbOutputs, int focusedIndex);
  torch::Tensor forward(torch::Tensor input);
};
TORCH_MODULE(TestNetwork);

#endif