Skip to content
Snippets Groups Projects
ConfigDataset.cpp 3.7 KiB
Newer Older
#include "ConfigDataset.hpp"

#include <algorithm>
#include <cstdlib>
#include <random>
#include <stdexcept>
#include <utility>

/// Scans `dir` for tensor files and groups them by state.
///
/// Every regular file except the one whose stem is "extracted" must be named
/// `<state>_<first>-<last>.<ext>`. It is registered with the Holder for
/// `<state>` and counted as `last - first + 1` rows in the dataset size.
/// States are appended to `order` in the order they are first encountered.
///
/// @param dir directory containing the dataset files.
/// @throws std::invalid_argument if a file name does not follow the pattern
///         above (std::stoi also throws it for non-numeric bounds).
ConfigDataset::ConfigDataset(std::filesystem::path dir)
{
  for (auto & entry : std::filesystem::directory_iterator(dir))
  {
    if (!entry.is_regular_file())
      continue;

    auto stem = entry.path().stem().string();
    // The "extracted" file is not one of the per-state tensor files.
    if (stem == "extracted")
      continue;

    // Validate the name before indexing the split results; a malformed name
    // would otherwise be read out of bounds.
    auto nameParts = util::split(stem, '_');
    if (nameParts.size() < 2)
      throw std::invalid_argument("ConfigDataset: unexpected file name '" + entry.path().string() + "'");
    auto bounds = util::split(nameParts[1], '-');
    if (bounds.size() < 2)
      throw std::invalid_argument("ConfigDataset: unexpected file name '" + entry.path().string() + "'");

    auto state = nameParts[0];
    // Inclusive range [first, last] -> number of rows in this file.
    int fileSize = 1 + std::stoi(bounds[1]) - std::stoi(bounds[0]);
    size_ += fileSize;

    // Create the holder on first sight of this state and remember its position.
    auto [slot, inserted] = holders.try_emplace(state, state);
    if (inserted)
      order.emplace_back(state);
    slot->second.addFile(entry.path().string(), fileSize);
  }
}

/// Total number of rows across all files registered at construction time.
/// Always returns a value (never nullopt).
c10::optional<std::size_t> ConfigDataset::size() const
{
  const std::size_t total = static_cast<std::size_t>(this->size_);
  return c10::optional<std::size_t>{total};
}
c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::get_batch(std::size_t batchSize)
  std::random_shuffle(order.begin(), order.end());

  for (auto & state : order)
    if (nbToGive.at(state) > 0)
    {
      nbToGive.at(state)--;
      auto res = holders.at(state).get_batch(batchSize);
      if (res.has_value())
        return res;
      else
        nbToGive.at(state) = 0;
    }
  computeNbToGive();
  for (auto & state : order)
    if (nbToGive.at(state) > 0)
    {
      nbToGive.at(state)--;
      auto res = holders.at(state).get_batch(batchSize);
      if (res.has_value())
        return res;
      else
        nbToGive.at(state) = 0;
    }
  return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
  for (auto & it : holders)
    it.second.reset();

  computeNbToGive();
}

/// Restoring the iteration state from a checkpoint is not implemented:
/// this is a no-op and the archive is ignored.
void ConfigDataset::load(torch::serialize::InputArchive &)
{
}
/// Saving the iteration state to a checkpoint is not implemented:
/// this is a no-op and nothing is written to the archive.
void ConfigDataset::save(torch::serialize::OutputArchive &) const
{
}
/// Registers a tensor file with this holder and adds its row count to size_.
///
/// @param filename path of the file, later read with torch::load.
/// @param filesize number of rows the file is declared to contain.
void ConfigDataset::Holder::addFile(std::string filename, int filesize)
{
  // `filename` is already our own copy (taken by value); move it into place.
  files.emplace_back(std::move(filename));
  size_ += filesize;
}

/// Rewinds this holder for a new epoch.
///
/// Shuffles the file order, clears the read position and given-count, and
/// loads the first file onto NeuralNetworkImpl::device. Nothing is loaded when
/// no files are registered; get_batch() then returns nullopt immediately.
void ConfigDataset::Holder::reset()
{
  // Seeding from std::rand() keeps shuffles reproducible under std::srand();
  // std::random_shuffle is deprecated in C++14 and removed in C++17.
  std::shuffle(files.begin(), files.end(), std::mt19937(std::rand()));
  loadedTensorIndex = 0;
  nextIndexToGive = 0;
  nbGiven = 0;
  // Guard files[0]: indexing an empty vector is undefined behavior.
  if (files.empty())
    return;
  torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
}

/// Returns up to `batchSize` consecutive rows from the currently loaded file.
///
/// When the current file is consumed, the next file in `files` is loaded;
/// files with no rows are skipped. The rows are split column-wise into
/// (all columns but the last, the last column), followed by this holder's
/// state name. Presumably the last column is the target — TODO confirm.
///
/// @return nullopt once every file has been consumed.
c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize)
{
  if (loadedTensorIndex >= static_cast<int>(files.size()))
    return c10::nullopt;

  // A loop (not a single step) so an empty file does not produce a 0-row batch.
  while (nextIndexToGive >= loadedTensor.size(0))
  {
    loadedTensorIndex++;
    if (loadedTensorIndex >= static_cast<int>(files.size()))
      return c10::nullopt;
    nextIndexToGive = 0;
    torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
  }

  int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive);
  nbGiven += nbElementsToGive;
  auto batch = loadedTensor.narrow(0, nextIndexToGive, nbElementsToGive);
  nextIndexToGive += nbElementsToGive;

  auto lastColumn = batch.size(1)-1;
  return std::make_tuple(batch.narrow(1, 0, lastColumn), batch.narrow(1, lastColumn, 1), state);
}

/// Creates an empty holder for the given state name; files are added later
/// with addFile().
ConfigDataset::Holder::Holder(std::string state) : state(std::move(state))
{
}

std::size_t ConfigDataset::Holder::size() const
{
  return size_;
}

std::size_t ConfigDataset::Holder::sizeLeft() const
{
  return size_-nbGiven;
}

void ConfigDataset::computeNbToGive()
{
  int smallestSize = std::numeric_limits<int>::max();
  for (auto & it : holders)
  {
    int sizeLeft = it.second.sizeLeft();
    if (sizeLeft > 0 and sizeLeft < smallestSize)
      smallestSize = sizeLeft;
  }
  for (auto & it : holders)
  {
    nbToGive[it.first] = std::max<int>(1,std::floor(1.0*it.second.sizeLeft()/smallestSize));
    if (it.second.sizeLeft() == 0)
      nbToGive[it.first] = 0;
  }