// ConfigDataset.cpp
    #include "ConfigDataset.hpp"

    #include <algorithm>
    #include <cstdint>
    #include <cstdlib>
    #include <random>
    
    /// Index every regular file found directly inside `dir`.
    ///
    /// The stem of each file is split on '_' and the second piece on '-', i.e.
    /// stems are read as `<state>_<first>-<last>`. One entry
    /// (first, last, path, state) is appended to exampleLocations per file and
    /// size_ grows by `last - first + 1`. The file whose stem is "extracted" is
    /// ignored.
    ///
    /// @param dir Directory to scan.
    ///
    /// NOTE(review): stems are not validated. A stem lacking '_' or '-' indexes
    /// past the end of the split result, and std::stoi throws on a non-numeric
    /// bound -- confirm that only well-formed files can live in `dir`.
    ConfigDataset::ConfigDataset(std::filesystem::path dir)
    {
      for (const auto & entry : std::filesystem::directory_iterator(dir))
      {
        if (!entry.is_regular_file())
          continue;

        const auto stem = entry.path().stem().string();
        if (stem == "extracted")
          continue;

        // Split the stem once and reuse both halves.
        const auto stemParts = util::split(stem, '_');
        const auto bounds = util::split(stemParts[1], '-');
        const int first = std::stoi(bounds[0]);
        const int last = std::stoi(bounds[1]);

        exampleLocations.emplace_back(first, last, entry.path(), stemParts[0]);

        // The bounds are inclusive, hence the +1.
        size_ += 1 + last - first;
      }
    }
    
    /// Report the number of examples counted while the dataset was indexed.
    ///
    /// @return size_ wrapped in an optional; the optional is always engaged.
    c10::optional<std::size_t> ConfigDataset::size() const
    {
      const c10::optional<std::size_t> exampleCount = size_;
      return exampleCount;
    }
    
    /// Hand out the next slice of at most `batchSize` rows from the indexed files.
    ///
    /// Files are loaded one at a time into loadedTensor, in the order of
    /// exampleLocations. A batch never spans two files: the last batch of a
    /// file simply holds the remaining rows.
    ///
    /// @param batchSize Maximum number of rows in the returned batch.
    /// @return A tuple (all columns but the last, the last column, state string
    ///         of the current file), or an empty optional when no file is
    ///         indexed or every file has been consumed. Once exhausted, further
    ///         calls keep returning an empty optional until reset() is called.
    ///
    /// NOTE(review): the last column is presumably the target and the others
    /// the network input -- confirm against the code that writes these files.
    c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::get_batch(std::size_t batchSize)
    {
      using Batch = std::tuple<torch::Tensor, torch::Tensor, std::string>;

      if (!loadedTensorIndex.has_value())
      {
        // Nothing was indexed: there is no file to load.
        if (exampleLocations.empty())
          return c10::optional<Batch>();

        loadedTensorIndex = 0;
        nextIndexToGive = 0;
        torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device);
      }
      else if (loadedTensorIndex.value() >= exampleLocations.size())
      {
        // Already exhausted by an earlier call.
        return c10::optional<Batch>();
      }

      // Move on to the next file once the current one is used up; looping also
      // steps over files that contain no rows.
      while (static_cast<std::int64_t>(nextIndexToGive) >= loadedTensor.size(0))
      {
        nextIndexToGive = 0;
        loadedTensorIndex = loadedTensorIndex.value() + 1;

        if (loadedTensorIndex.value() >= exampleLocations.size())
          return c10::optional<Batch>();

        torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device);
      }

      const std::int64_t rowCount = loadedTensor.size(0);
      const std::int64_t lastColumn = loadedTensor.size(1) - 1;
      const auto firstRow = static_cast<std::int64_t>(nextIndexToGive);

      // Clamp to what is left in the current file.
      const std::int64_t length = std::min(static_cast<std::int64_t>(batchSize), rowCount - firstRow);
      const auto rows = loadedTensor.narrow(0, firstRow, length);

      // Advance the cursor so the next call continues after this batch.
      nextIndexToGive += length;

      return Batch(rows.narrow(1, 0, lastColumn),
                   rows.narrow(1, lastColumn, 1),
                   std::get<3>(exampleLocations[loadedTensorIndex.value()]));
    }
    
    /// Rewind the dataset: shuffle the order of the indexed files and forget
    /// which file is loaded, so the next get_batch() call starts from the
    /// first entry of the new order.
    void ConfigDataset::reset()
    {
      // std::random_shuffle was removed in C++17; std::shuffle needs an explicit
      // engine. It is seeded from std::rand() so that a program-wide
      // std::srand() seed still controls the order.
      // NOTE(review): confirm this is the reproducibility behaviour wanted.
      std::mt19937 engine(static_cast<std::mt19937::result_type>(std::rand()));
      std::shuffle(exampleLocations.begin(), exampleLocations.end(), engine);

      loadedTensorIndex.reset();
      nextIndexToGive = 0;
    }
    
    /// Deserialization hook; intentionally does nothing.
    ///
    /// The archive parameter is unnamed, so nothing is read from it.
    /// NOTE(review): presumably present only to satisfy the dataset interface
    /// declared in ConfigDataset.hpp -- confirm.
    void ConfigDataset::load(torch::serialize::InputArchive &)
    {
    }
    
    void ConfigDataset::save(torch::serialize::OutputArchive &) const