#ifndef ONEWORDNETWORK__H #define ONEWORDNETWORK__H #include "NeuralNetwork.hpp" class OneWordNetworkImpl : public NeuralNetworkImpl { private : torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Linear linear{nullptr}; int focusedIndex; std::vector<torch::Tensor> _denseParameters; std::vector<torch::Tensor> _sparseParameters; public : OneWordNetworkImpl(int nbOutputs, int focusedIndex); torch::Tensor forward(torch::Tensor input) override; std::vector<torch::Tensor> & denseParameters() override; std::vector<torch::Tensor> & sparseParameters() override; }; #endif