#include "ConfigDataset.hpp"
#include "NeuralNetwork.hpp"

#include <algorithm>
#include <random>

namespace
{

/// Shared RNG for every shuffle in this translation unit.
/// std::random_shuffle was removed in C++17 (which std::filesystem requires),
/// so std::shuffle with an explicit engine is used instead.
std::mt19937 & shuffleEngine()
{
  static std::mt19937 engine{std::random_device{}()};
  return engine;
}

} // namespace

/// Build the dataset by scanning `dir` for tensor files.
/// Expected filename stem format: "<state>_<firstIndex>-<lastIndex>" (a stem
/// equal to "extracted" is skipped). One Holder is created per distinct state;
/// `order` remembers the states in discovery order and is reshuffled at each
/// get_batch call.
ConfigDataset::ConfigDataset(std::filesystem::path dir)
{
  for (auto & entry : std::filesystem::directory_iterator(dir))
    if (entry.is_regular_file())
    {
      auto stem = entry.path().stem().string();
      if (stem == "extracted")
        continue;
      // Split once instead of twice: parts[0] = state, parts[1] = "first-last".
      auto parts = util::split(stem, '_');
      auto state = parts[0];
      auto range = util::split(parts[1], '-');
      // The index range is inclusive, hence the +1.
      int fileSize = 1 + std::stoi(range[1]) - std::stoi(range[0]);
      size_ += fileSize;
      if (!holders.count(state))
      {
        holders.emplace(state, Holder(state));
        order.emplace_back(state);
      }
      holders.at(state).addFile(entry.path().string(), fileSize);
    }
}

/// Total number of examples across all states and files.
c10::optional<std::size_t> ConfigDataset::size() const
{
  return size_;
}

/// Return the next batch, drawn from a randomly-ordered state whose quota
/// (nbToGive) is still positive. When every quota is spent, quotas are
/// recomputed once and the scan is retried; an empty optional means the whole
/// dataset is exhausted.
/// NOTE(review): nbToGive.at() assumes reset()/computeNbToGive() ran before the
/// first call — confirm the training loop guarantees this.
c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::get_batch(std::size_t batchSize)
{
  std::shuffle(order.begin(), order.end(), shuffleEngine());
  // One pass over the (shuffled) states, honoring per-state quotas.
  auto tryGive = [this, batchSize]() -> c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>
  {
    for (auto & state : order)
    {
      if (nbToGive.at(state) > 0)
      {
        nbToGive.at(state)--;
        auto res = holders.at(state).get_batch(batchSize);
        if (res.has_value())
          return res;
        // Holder exhausted: drop its quota so we stop asking this round.
        nbToGive.at(state) = 0;
      }
    }
    return {};
  };
  if (auto res = tryGive(); res.has_value())
    return res;
  // All quotas spent: recompute them and try once more.
  computeNbToGive();
  if (auto res = tryGive(); res.has_value())
    return res;
  return {};
}

/// Start a new epoch: reshuffle every holder's files and reset the quotas.
void ConfigDataset::reset()
{
  for (auto & it : holders)
    it.second.reset();
  computeNbToGive();
}

/// Serialization is intentionally a no-op for this dataset.
void ConfigDataset::load(torch::serialize::InputArchive &)
{
}

/// Serialization is intentionally a no-op for this dataset.
void ConfigDataset::save(torch::serialize::OutputArchive &) const
{
}

/// Register one on-disk tensor file holding `filesize` examples.
void ConfigDataset::Holder::addFile(std::string filename, int filesize)
{
  files.emplace_back(filename);
  size_ += filesize;
}

/// Start a new epoch for this state: shuffle the file order, rewind the
/// cursors and preload the first file's tensor onto the network device.
void ConfigDataset::Holder::reset()
{
  // Guard against a holder with no files (would otherwise index files[0]).
  if (files.empty())
    return;
  std::shuffle(files.begin(), files.end(), shuffleEngine());
  loadedTensorIndex = 0;
  nextIndexToGive = 0;
  nbGiven = 0;
  torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
}
c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize) { if (loadedTensorIndex >= (int)files.size()) return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>(); if (nextIndexToGive >= loadedTensor.size(0)) { loadedTensorIndex++; if (loadedTensorIndex >= (int)files.size()) return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>(); nextIndexToGive = 0; torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device); } int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive); nbGiven += nbElementsToGive; auto batch = loadedTensor.narrow(0, nextIndexToGive, nbElementsToGive); nextIndexToGive += nbElementsToGive; return std::make_tuple(batch.narrow(1, 0, batch.size(1)-1), batch.narrow(1, batch.size(1)-1, 1), state); } ConfigDataset::Holder::Holder(std::string state) : state(state) { } std::size_t ConfigDataset::Holder::size() const { return size_; } std::size_t ConfigDataset::Holder::sizeLeft() const { return size_-nbGiven; } void ConfigDataset::computeNbToGive() { int smallestSize = std::numeric_limits<int>::max(); for (auto & it : holders) { int sizeLeft = it.second.sizeLeft(); if (sizeLeft > 0 and sizeLeft < smallestSize) smallestSize = sizeLeft; } for (auto & it : holders) { nbToGive[it.first] = std::max<int>(1,std::floor(1.0*it.second.sizeLeft()/smallestSize)); if (it.second.sizeLeft() == 0) nbToGive[it.first] = 0; } }