-
Franck Dary authored
ConfigDataset.hpp 1.28 KiB
#ifndef CONFIGDATASET__H
#define CONFIGDATASET__H
#include <cstddef>
#include <filesystem>
#include <map>
#include <string>
#include <tuple>
#include <vector>

#include <torch/torch.h>

#include "Config.hpp"
/// @brief libtorch StatefulDataset whose batches are (Tensor, Tensor, string) tuples.
///
/// The data is split across per-`state` Holder objects stored in `holders`.
/// NOTE(review): the string element of a batch presumably names the state the
/// batch was drawn from — confirm against ConfigDataset.cpp.
class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDataset, std::tuple<torch::Tensor,torch::Tensor,std::string>>
{
  private :

  /// Element type produced by get_batch(). It must stay identical to the batch
  /// template argument given to StatefulDataset above, because the get_batch
  /// override has to return exactly c10::optional of that type.
  using Batch = std::tuple<torch::Tensor, torch::Tensor, std::string>;

  /// Bookkeeping for the files that belong to a single `state`.
  struct Holder
  {
    std::string state;               ///< Presumably the state name given to the constructor — confirm.
    std::vector<std::string> files;  ///< Presumably the file names registered through addFile() — confirm.
    torch::Tensor loadedTensor;      ///< NOTE(review): presumably the tensor currently loaded from one of `files` — confirm.
    int loadedTensorIndex{0};        ///< NOTE(review): presumably the index in `files` of loadedTensor — confirm.
    int nextIndexToGive{0};          ///< NOTE(review): presumably the next position in loadedTensor to hand out — confirm.
    std::size_t size_{0};            ///< NOTE(review): presumably the total of the sizes passed to addFile() — confirm.
    std::size_t nbGiven{0};          ///< NOTE(review): presumably the number of elements handed out since reset() — confirm.

    // NOTE(review): single-argument constructor is not `explicit`; marking it
    // explicit is advisable once call sites in ConfigDataset.cpp are checked.
    Holder(std::string state);
    /// Registers `filename` with this holder; `filesize` is its declared size.
    void addFile(std::string filename, int filesize);
    /// Resets this holder's iteration state.
    void reset();
    /// Size of this holder.
    std::size_t size() const;
    /// NOTE(review): presumably the number of elements not yet handed out — confirm.
    std::size_t sizeLeft() const;
    /// Next batch of `batchSize` taken from this holder, or an empty optional.
    c10::optional<Batch> get_batch(std::size_t batchSize);
  };

  std::size_t size_{0};                   ///< NOTE(review): presumably the summed size of all holders — confirm.
  std::map<std::string,Holder> holders;   ///< Holders keyed by a string, presumably their state name.
  std::map<std::string,int> nbToGive;     ///< NOTE(review): presumably per-state quotas filled by computeNbToGive() — confirm.
  std::vector<std::string> order;         ///< NOTE(review): presumably the order in which states are visited — confirm.

  public :

  /// @brief Builds the dataset from the directory `dir`.
  /// @param dir Directory whose contents feed the holders (the discovery logic lives in ConfigDataset.cpp).
  explicit ConfigDataset(std::filesystem::path dir);

  /// @brief Dataset size reported to libtorch.
  c10::optional<std::size_t> size() const override;

  /// @brief Returns the next batch of size `batchSize`.
  /// @return The batch, or an empty optional when no batch is left, which per the
  ///         libtorch StatefulDataset contract ends the current epoch.
  c10::optional<Batch> get_batch(std::size_t batchSize) override;

  /// @brief StatefulDataset hook that rewinds the dataset for a new epoch.
  void reset() override;

  /// @brief StatefulDataset checkpoint hook: restores state from the archive.
  void load(torch::serialize::InputArchive &) override;

  /// @brief StatefulDataset checkpoint hook: writes state to the archive.
  void save(torch::serialize::OutputArchive &) const override;

  /// NOTE(review): presumably (re)computes `nbToGive` from the holders — confirm.
  void computeNbToGive();
};
#endif