#ifndef CONCATWORDSNETWORK__H #define CONCATWORDSNETWORK__H #include "NeuralNetwork.hpp" class ConcatWordsNetworkImpl : public NeuralNetworkImpl { private : torch::nn::Embedding wordEmbeddings{nullptr}; torch::nn::Linear linear1{nullptr}; torch::nn::Linear linear2{nullptr}; torch::nn::Dropout dropout{nullptr}; public : ConcatWordsNetworkImpl(int nbOutputs, const std::vector<int> & bufferContext, const std::vector<int> & stackContext); torch::Tensor forward(torch::Tensor input) override; }; #endif