// ConfigDataset.hpp
#ifndef CONFIGDATASET__H
#define CONFIGDATASET__H

#include <cstddef>
#include <filesystem>
#include <map>
#include <string>
#include <tuple>
#include <vector>

#include <torch/torch.h>
#include "Config.hpp"

/// Stateful libtorch dataset built from the files of a directory (see the constructor).
///
/// Every batch is a `std::tuple<torch::Tensor, torch::Tensor, std::string>`,
/// wrapped in `c10::optional` by get_batch().
/// NOTE(review): judging by the members, files are grouped into per-"state"
/// Holders keyed by a string, and batches are drawn from them following
/// `order` / `nbToGive` — confirm against ConfigDataset.cpp.
class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDataset, std::tuple<torch::Tensor,torch::Tensor,std::string>>
{
  private :

  /// Bookkeeping for one group of files that share the same `state` string.
  /// NOTE(review): the field meanings below are inferred from their names;
  /// verify them against the out-of-line definitions in the .cpp.
  struct Holder
  {
    /// Key identifying this holder (presumably the same key used in `holders`).
    std::string state;
    /// File names registered through addFile().
    std::vector<std::string> files;
    /// Tensor currently held in memory (presumably loaded from one entry of `files`).
    torch::Tensor loadedTensor;
    /// Presumably the index into `files` of the file backing `loadedTensor`.
    int loadedTensorIndex{0};
    /// Presumably the next position inside `loadedTensor` to hand out.
    int nextIndexToGive{0};
    /// Total size of this holder (presumably the sum of addFile()'s `filesize`s).
    std::size_t size_{0};
    /// Presumably how many elements were handed out since the last reset().
    std::size_t nbGiven{0};

    // NOTE(review): single-argument constructor is intentionally left
    // non-`explicit`; map insertion code may rely on the implicit conversion.
    Holder(std::string state);
    /// Registers `filename`, whose size is given by `filesize`.
    void addFile(std::string filename, int filesize);
    void reset();
    [[nodiscard]] std::size_t size() const;
    [[nodiscard]] std::size_t sizeLeft() const;
    [[nodiscard]] c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> get_batch(std::size_t batchSize);
  };

  /// Total size of the dataset (presumably the sum over all holders).
  std::size_t size_{0};
  /// One Holder per state string.
  std::map<std::string,Holder> holders;
  /// Per-state count filled by computeNbToGive() — semantics to confirm in the .cpp.
  std::map<std::string,int> nbToGive;
  /// Sequence of state strings (presumably the order in which holders are visited).
  std::vector<std::string> order;

  public :

  /// Builds the dataset from the files found in `dir`.
  explicit ConfigDataset(std::filesystem::path dir);
  /// StatefulDataset override: size of the dataset.
  [[nodiscard]] c10::optional<std::size_t> size() const override;
  /// StatefulDataset override: next batch of (at most) `batchSize` elements;
  /// per the StatefulDataset contract an empty optional presumably marks exhaustion.
  [[nodiscard]] c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> get_batch(std::size_t batchSize) override;
  /// StatefulDataset override: resets the iteration state.
  void reset() override;
  /// StatefulDataset override: restores state from an archive.
  void load(torch::serialize::InputArchive &) override;
  /// StatefulDataset override: writes state to an archive.
  void save(torch::serialize::OutputArchive &) const override;
  /// Recomputes `nbToGive`.
  void computeNbToGive();
};

#endif