#ifndef CONFIGDATASET__H #define CONFIGDATASET__H #include <torch/torch.h> #include "Config.hpp" class ConfigDataset : public torch::data::Dataset<ConfigDataset> { private : torch::Tensor data; std::size_t size_{0}; std::size_t contextSize{0}; public : explicit ConfigDataset(const std::vector<torch::Tensor> & contexts, const std::vector<torch::Tensor> & classes); torch::optional<size_t> size() const override; torch::data::Example<> get(size_t index) override; }; #endif