Newer
Older
#ifndef CONFIGDATASET__H
#define CONFIGDATASET__H
#include <torch/torch.h>
#include "Config.hpp"
/// Map-style libtorch dataset built from a list of context tensors and a
/// matching list of class tensors.
///
/// Only declarations live in this header; the member definitions are
/// elsewhere. Descriptions of the stored data below are inferred from member
/// names and must be confirmed against the implementation.
class ConfigDataset : public torch::data::Dataset<ConfigDataset>
{
  public :

  /// Builds the dataset from the given tensors.
  ///
  /// The constructor, size() and get() must be public. Callers construct the
  /// dataset directly, and libtorch's templated helpers (e.g. `map(...)`,
  /// `make_data_loader`) call size()/get() on the concrete type, which fails
  /// to compile when they are private.
  ///
  /// @param contexts Context tensors, presumably one per example — TODO confirm.
  /// @param classes  Class tensors, presumably paired index-wise with
  ///                 @p contexts — TODO confirm.
  explicit ConfigDataset(const std::vector<torch::Tensor> & contexts, const std::vector<torch::Tensor> & classes);

  /// @return The dataset size reported to libtorch (overrides the Dataset interface).
  torch::optional<size_t> size() const override;

  /// @param index Position of the requested example.
  /// @return The example at @p index (overrides the Dataset interface).
  torch::data::Example<> get(size_t index) override;

  private :

  /// Backing tensor storage for the examples; layout is defined by the
  /// implementation — TODO confirm.
  torch::Tensor data;
  /// Cached element count, presumably what size() returns — TODO confirm.
  std::size_t size_{0};
  /// Presumably the width of a single context — TODO confirm.
  std::size_t contextSize{0};
};
#endif