Skip to content
Snippets Groups Projects
MLP.hpp 398 B
#ifndef MLP__H
#define MLP__H

#include <cstddef>
#include <string>
#include <vector>

#include <torch/torch.h>

/// LibTorch module implementation holding a list of torch::nn::Linear layers
/// and a list of torch::nn::Dropout modules (presumably a multi-layer
/// perceptron, as the name suggests). Callers normally use it through the
/// `MLP` holder generated by TORCH_MODULE right after this class.
class MLPImpl : public torch::nn::Module
{
  private :

  // Linear layers of the network.
  // NOTE(review): presumably created and registered with register_module()
  // in the constructor — confirm in the implementation file.
  std::vector<torch::nn::Linear> layers;
  // Dropout modules; how they pair with `layers` is decided in the .cpp.
  std::vector<torch::nn::Dropout> dropouts;
  // Value reported by outputSize(); 0 until assigned (presumably by the ctor).
  std::size_t outSize{0};

  public :

  /// @brief Builds the network.
  /// @param inputSize  presumably the number of input features fed to the
  ///                   first layer — TODO confirm against the implementation.
  /// @param definition textual description of the layers; its grammar is
  ///                   parsed in the implementation file — TODO document it here.
  MLPImpl(int inputSize, std::string definition);

  /// @brief Runs `input` through the network.
  /// @param input tensor to transform; the expected shape is not visible from
  ///              this header (presumably (..., inputSize)) — TODO confirm.
  /// @return the resulting tensor.
  torch::Tensor forward(torch::Tensor input);

  /// @return the stored outSize value, presumably the feature width of
  ///         forward()'s result. [[nodiscard]]: calling this getter only to
  ///         ignore its value is always a mistake.
  [[nodiscard]] std::size_t outputSize() const;
};

// Generates the `MLP` module holder (a std::shared_ptr-based wrapper around
// MLPImpl), which is the type callers construct and pass around.
TORCH_MODULE(MLP);

#endif