// TestNetwork.hpp (author: Franck Dary)
#ifndef TESTNETWORK__H
#define TESTNETWORK__H
#include <torch/torch.h>
#include "Config.hpp"
class TestNetworkImpl : public torch::nn::Module
{
private :
torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Linear linear{nullptr};
int focusedIndex;
public :
TestNetworkImpl(int nbOutputs, int focusedIndex);
torch::Tensor forward(torch::Tensor input);
};
// TORCH_MODULE (PyTorch C++ frontend) declares `TestNetwork`, a ModuleHolder
// wrapping a std::shared_ptr<TestNetworkImpl>.
TORCH_MODULE(TestNetwork);
#endif