    // ConfigDataset.cpp

    #include "ConfigDataset.hpp"
    
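    // Builds the dataset from a directory of serialized tensor files. Apart from the
    // file named "extracted", each file stem is assumed to follow the pattern
    // <state>_<firstIndex>-<lastIndex>, so the number of examples in a file can be
    // read from its name alone; one Holder is kept per state, in first-seen order.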
    ConfigDataset::ConfigDataset(std::filesystem::path dir)
    {
      for (auto & entry : std::filesystem::directory_iterator(dir))
        if (entry.is_regular_file())
        {
    
          auto stem = entry.path().stem().string();
          if (stem == "extracted")
            continue;

          auto state = util::split(stem, '_')[0];
          auto splited = util::split(util::split(stem, '_')[1], '-');
    
          int fileSize = 1 + std::stoi(splited[1]) - std::stoi(splited[0]);
          size_ += fileSize;
          if (!holders.count(state))
          {
            holders.emplace(state, Holder(state));
            order.emplace_back(state);
          }
          holders.at(state).addFile(entry.path().string(), fileSize);
    
        }
    }
    
    c10::optional<std::size_t> ConfigDataset::size() const
    {
      return size_;
    }
    
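    // Returns the next batch, visiting the states in a freshly shuffled order. Each
    // state has a quota (nbToGive) that is decremented whenever it serves a batch and
    // zeroed once its holder is exhausted; when no state can serve, the quotas are
    // recomputed once more before giving up and returning an empty optional.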
    c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::get_batch(std::size_t batchSize)
    {
      std::random_shuffle(order.begin(), order.end());
    
      for (auto & state : order)
      {
        if (nbToGive.at(state) > 0)
        {
          nbToGive.at(state)--;
          auto res = holders.at(state).get_batch(batchSize);
          if (res.has_value())
            return res;
          else
            nbToGive.at(state) = 0;
        }
      }

      computeNbToGive();
    
      for (auto & state : order)
      {
        if (nbToGive.at(state) > 0)
        {
          nbToGive.at(state)--;
          auto res = holders.at(state).get_batch(batchSize);
          if (res.has_value())
            return res;
          else
            nbToGive.at(state) = 0;
        }
      }

      return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
    }

    // Start a new epoch: every holder reshuffles its files and reloads the first one,
    // then the per-state quotas are recomputed.
    void ConfigDataset::reset()
    {
      for (auto & it : holders)
        it.second.reset();
    
      computeNbToGive();
    
    }
    
    // Nothing is serialized: the dataset is rebuilt from its directory by the constructor.
    void ConfigDataset::load(torch::serialize::InputArchive &)
    {
    }

    void ConfigDataset::save(torch::serialize::OutputArchive &) const
    {
    }

    void ConfigDataset::Holder::addFile(std::string filename, int filesize)
    {
      files.emplace_back(filename);
      size_ += filesize;
    }
    
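    // Begins a new pass over this state's files: shuffle the file order, rewind the
    // cursors, and load the first file's tensor onto the network device.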
    void ConfigDataset::Holder::reset()
    {
      std::random_shuffle(files.begin(), files.end());
      loadedTensorIndex = 0;
      nextIndexToGive = 0;
      nbGiven = 0;
      torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
    }
    
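    // Returns up to batchSize rows from the currently loaded tensor, loading the next
    // file when the current one is used up and an empty optional when none remain.
    // Each result is the tuple (all columns but the last, the last column, this
    // holder's state).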
    c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize)
    {
      if (loadedTensorIndex >= (int)files.size())
        return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
      if (nextIndexToGive >= loadedTensor.size(0))
      {
        loadedTensorIndex++;
        if (loadedTensorIndex >= (int)files.size())
          return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
        nextIndexToGive = 0;
        torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
      }
    
      int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive);
      nbGiven += nbElementsToGive;
      auto batch = loadedTensor.narrow(0, nextIndexToGive, nbElementsToGive);
      nextIndexToGive += nbElementsToGive;
      return std::make_tuple(batch.narrow(1, 0, batch.size(1)-1), batch.narrow(1, batch.size(1)-1, 1), state);
    }
    
    ConfigDataset::Holder::Holder(std::string state) : state(state)
    {
    }
    
    std::size_t ConfigDataset::Holder::size() const
    {
      return size_;
    }
    
    std::size_t ConfigDataset::Holder::sizeLeft() const
    {
      return size_-nbGiven;
    }
    
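    // Recomputes each state's quota as floor(sizeLeft / smallest non-zero sizeLeft),
    // so states with more remaining examples are drawn from proportionally more often.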
    void ConfigDataset::computeNbToGive()
    {
      int smallestSize = std::numeric_limits<int>::max();
      for (auto & it : holders)
      {
        int sizeLeft = it.second.sizeLeft();
        if (sizeLeft > 0 and sizeLeft < smallestSize)
          smallestSize = sizeLeft;
      }
      for (auto & it : holders)
        nbToGive[it.first] = std::floor(1.0*it.second.sizeLeft()/smallestSize);
    }
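
    // Example (not part of the original file): a minimal sketch of how this dataset
    // might drive a training loop. ConfigDataset.hpp is assumed to declare the members
    // defined above; `trainOnBatch` and the directory path are illustrative only.
    //
    //   ConfigDataset dataset("path/to/tensor_directory");
    //   for (int epoch = 0; epoch < 10; ++epoch)
    //   {
    //     dataset.reset();
    //     while (auto batch = dataset.get_batch(64))
    //     {
    //       auto & [inputs, targets, state] = *batch;
    //       trainOnBatch(state, inputs, targets);
    //     }
    //   }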