// Declaration of OneWordNetworkImpl, a NeuralNetworkImpl specialisation.
#ifndef ONEWORDNETWORK__H
#define ONEWORDNETWORK__H
#include "NeuralNetwork.hpp"
/// Network made of a word-embedding module followed (presumably) by a single
/// linear module.
///
/// Derives from NeuralNetworkImpl and overrides its forward(). Only the
/// declarations are given here; the constructor and forward() are defined
/// in a separate translation unit.
class OneWordNetworkImpl : public NeuralNetworkImpl
{
  public:

  /// Builds the network.
  ///
  /// @param nbOutputs    presumably the size of the output produced by the
  ///                     linear module -- TODO confirm against the definition.
  /// @param focusedIndex presumably the position, within the input, of the
  ///                     single word this network looks at -- TODO confirm.
  OneWordNetworkImpl(int nbOutputs, int focusedIndex);

  /// Forward pass; overrides the virtual declared in NeuralNetworkImpl.
  ///
  /// @param input tensor handed to the network (layout not visible here).
  /// @return the tensor computed by the network.
  torch::Tensor forward(torch::Tensor input) override;

  private:

  // Both sub-modules start out null-initialised; they are expected to be
  // set up by the constructor (definition not visible in this header).
  torch::nn::Embedding wordEmbeddings{nullptr};
  torch::nn::Linear linear{nullptr};
};
#endif