    // ConfigDataset.cpp
    // Member function definitions for ConfigDataset and ConfigDataset::Holder.
    #include "ConfigDataset.hpp"

    #include <algorithm>
    #include <cstdlib>
    #include <random>
    #include <stdexcept>
    
    /// Indexes every regular file found directly inside `dir`.
    ///
    /// A file name is read as `<state>-<nbClasses>_<first>-<last>`; anything from
    /// the first '.' of the stem onward is ignored:
    ///  - the last '_'-separated field is an inclusive range, so the file accounts
    ///    for `last - first + 1` examples;
    ///  - the preceding fields, joined back with '_', are split on '-' to obtain
    ///    the state name and its number of classes.
    /// Files sharing a state are grouped under a single Holder, and `order`
    /// records each state the first time it is encountered.
    ///
    /// @param dir Directory to scan (not recursive).
    /// @throws std::invalid_argument if a file name lacks one of the expected
    ///         fields, or (from std::stoi) if a numeric field is not a number.
    /// @throws std::out_of_range (from std::stoi) if a numeric field overflows int.
    ConfigDataset::ConfigDataset(std::filesystem::path dir)
    {
      for (auto & entry : std::filesystem::directory_iterator(dir))
      {
        if (!entry.is_regular_file())
          continue;

        const auto filename = entry.path().string();
        // Single place to build the error so every field check reports the same way.
        const auto malformed = [&filename]()
        {
          return std::invalid_argument("ConfigDataset: cannot parse file name '" + filename + "'");
        };

        auto dotSplit = util::split(entry.path().stem().string(), '.');
        if (dotSplit.empty())
          throw malformed();
        auto stem = dotSplit[0];

        // At least two fields are needed: the state part and the trailing range.
        auto underSplit = util::split(stem, '_');
        if (underSplit.size() < 2)
          throw malformed();

        auto stateAndNbClasses = util::split(util::join("_", std::vector<std::string>(underSplit.begin(), underSplit.end()-1)), '-');
        if (stateAndNbClasses.size() < 2)
          throw malformed();
        auto state = stateAndNbClasses[0];
        auto nbClasses = std::stoi(stateAndNbClasses[1]);

        auto bounds = util::split(underSplit.back(), '-');
        if (bounds.size() < 2)
          throw malformed();

        // The range is inclusive on both ends, hence the +1.
        int fileSize = 1 + std::stoi(bounds[1]) - std::stoi(bounds[0]);
        size_ += fileSize;

        auto holder = holders.find(state);
        if (holder == holders.end())
        {
          holder = holders.emplace(state, Holder(state, nbClasses)).first;
          order.emplace_back(state);
        }
        holder->second.addFile(filename, fileSize);
      }
    }
    
    /// Reports the value of `size_`, which the constructor accumulates from the
    /// per-file example counts.
    ///
    /// @return An optional that always holds a value.
    c10::optional<std::size_t> ConfigDataset::size() const
    {
      c10::optional<std::size_t> total(size_);
      return total;
    }
    
    c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::get_batch(std::size_t batchSize)
    
      std::random_shuffle(order.begin(), order.end());
    
      for (auto & state : order)
    
        if (nbToGive.at(state) > 0)
        {
          nbToGive.at(state)--;
          auto res = holders.at(state).get_batch(batchSize);
          if (res.has_value())
            return res;
          else
            nbToGive.at(state) = 0;
        }
    
      computeNbToGive();
    
      for (auto & state : order)
    
        if (nbToGive.at(state) > 0)
        {
          nbToGive.at(state)--;
          auto res = holders.at(state).get_batch(batchSize);
          if (res.has_value())
            return res;
          else
            nbToGive.at(state) = 0;
        }
    
      return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
    
      for (auto & it : holders)
        it.second.reset();
    
      computeNbToGive();
    
    }
    
    /// Deserialization hook; does nothing.
    ///
    /// NOTE(review): the body was missing from the source as received. An empty
    /// body is assumed because the archive parameter is unnamed — confirm.
    void ConfigDataset::load(torch::serialize::InputArchive &)
    {
    }
    
    /// Serialization hook; does nothing.
    ///
    /// NOTE(review): the body was missing from the source as received. An empty
    /// body is assumed because the archive parameter is unnamed — confirm.
    void ConfigDataset::save(torch::serialize::OutputArchive &) const
    {
    }
    
    void ConfigDataset::Holder::addFile(std::string filename, int filesize)
    {
      files.emplace_back(filename);
      size_ += filesize;
    }
    
    /// Rewinds this holder for a new pass: shuffles the file list, clears the
    /// progress counters and loads the first file into `loadedTensor`.
    ///
    /// When no file has been registered, only the counters are cleared;
    /// get_batch then reports exhaustion straight away.
    void ConfigDataset::Holder::reset()
    {
      // std::random_shuffle was removed in C++17. Seeding the engine from
      // std::rand() keeps the order dependent on any std::srand() seed.
      std::shuffle(files.begin(), files.end(), std::mt19937(static_cast<std::mt19937::result_type>(std::rand())));
      loadedTensorIndex = 0;
      nextIndexToGive = 0;
      nbGiven = 0;

      // files[0] would be out of bounds on an empty list.
      if (files.empty())
        return;

      torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
    }
    
    /// Returns up to `batchSize` consecutive rows of the currently loaded tensor,
    /// moving on to the next file once the current one has been fully handed out.
    ///
    /// Each row is split along dimension 1: the last `nbClasses` columns form the
    /// second tensor of the result and all preceding columns form the first
    /// (presumably inputs and per-class targets — confirm against the writer of
    /// these files).
    ///
    /// NOTE(review): assumes reset() ran beforehand so that `loadedTensor` holds
    /// the file at `loadedTensorIndex` — confirm.
    ///
    /// @param batchSize Maximum number of rows to return.
    /// @return (first tensor, second tensor, state name), or an empty optional
    ///         once every file has been consumed.
    c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize)
    {
      using Batch = c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>;

      if (loadedTensorIndex >= static_cast<int>(files.size()))
        return Batch();

      // Current file fully handed out: advance to the next one, if any.
      if (nextIndexToGive >= loadedTensor.size(0))
      {
        loadedTensorIndex++;
        if (loadedTensorIndex >= static_cast<int>(files.size()))
          return Batch();
        nextIndexToGive = 0;
        torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
      }

      // A batch never spans two files, so it may be shorter than batchSize.
      int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive);
      nbGiven += nbElementsToGive;
      auto batch = loadedTensor.narrow(0, nextIndexToGive, nbElementsToGive);
      nextIndexToGive += nbElementsToGive;

      return std::make_tuple(batch.narrow(1, 0, batch.size(1)-nbClasses), batch.narrow(1, batch.size(1)-nbClasses, nbClasses), state);
    }

    /// Creates a holder with no files for the given state.
    ///
    /// @param state     Stored as-is; get_batch returns it with every batch.
    /// @param nbClasses Stored as-is; get_batch uses it as the number of trailing
    ///                  columns to split off each batch.
    ConfigDataset::Holder::Holder(std::string state, int nbClasses)
      : state(state),
        nbClasses(nbClasses)
    {
    }
    
    /// Reports this holder's `size_`, i.e. the sum of the sizes passed to addFile.
    std::size_t ConfigDataset::Holder::size() const
    {
      const std::size_t total = size_;
      return total;
    }
    
    /// Reports how much of this holder has not been handed out yet: `size_` minus
    /// the count accumulated in `nbGiven` by get_batch.
    std::size_t ConfigDataset::Holder::sizeLeft() const
    {
      const std::size_t remaining = size_-nbGiven;
      return remaining;
    }
    
    void ConfigDataset::computeNbToGive()
    {
      int smallestSize = std::numeric_limits<int>::max();
      for (auto & it : holders)
      {
        int sizeLeft = it.second.sizeLeft();
        if (sizeLeft > 0 and sizeLeft < smallestSize)
          smallestSize = sizeLeft;
      }
      for (auto & it : holders)
    
      {
        nbToGive[it.first] = std::max<int>(1,std::floor(1.0*it.second.sizeLeft()/smallestSize));
        if (it.second.sizeLeft() == 0)
          nbToGive[it.first] = 0;
      }