#ifndef CONFIGDATASET__H #define CONFIGDATASET__H #include <torch/torch.h> #include "Config.hpp" class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDataset, std::tuple<torch::Tensor,torch::Tensor,std::string>> { private : struct Holder { std::string state; std::vector<std::string> files; torch::Tensor loadedTensor; int loadedTensorIndex{0}; int nextIndexToGive{0}; std::size_t size_{0}; std::size_t nbGiven{0}; Holder(std::string state); void addFile(std::string filename, int filesize); void reset(); std::size_t size() const; std::size_t sizeLeft() const; c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> get_batch(std::size_t batchSize); }; private : std::size_t size_{0}; std::map<std::string,Holder> holders; std::map<std::string,int> nbToGive; std::vector<std::string> order; public : explicit ConfigDataset(std::filesystem::path dir); c10::optional<std::size_t> size() const override; c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> get_batch(std::size_t batchSize) override; void reset() override; void load(torch::serialize::InputArchive &) override; void save(torch::serialize::OutputArchive &) const override; void computeNbToGive(); }; #endif