#include "ConfigDataset.hpp"

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <random>
#include <stdexcept>

#include "NeuralNetwork.hpp"

/// @brief Index every tensor file found directly inside @p dir.
///
/// Each regular file (except the one whose stem is "extracted") must have a
/// stem of the form `<state>_<first>-<last>`. For every such file the pair
/// (first, last), the file path and the state name are recorded in
/// `exampleLocations`, and `size_` is increased by `last - first + 1`.
///
/// @param dir Directory to scan (non-recursively).
/// @throws std::runtime_error if a file stem lacks the `_` or `-` separator.
/// @throws std::invalid_argument / std::out_of_range (from std::stoi) if the
///         bounds are not valid integers.
ConfigDataset::ConfigDataset(std::filesystem::path dir)
{
  for (const auto & entry : std::filesystem::directory_iterator(dir))
  {
    if (!entry.is_regular_file())
      continue;

    const auto stem = entry.path().stem().string();
    if (stem == "extracted")
      continue;

    // Validate the pieces before indexing them: a stray file in the directory
    // would otherwise trigger an out-of-bounds access.
    const auto nameParts = util::split(stem, '_');
    if (nameParts.size() < 2)
      throw std::runtime_error("ConfigDataset: unexpected file name '" + entry.path().string() + "' (expected <state>_<first>-<last>)");

    const auto bounds = util::split(nameParts[1], '-');
    if (bounds.size() < 2)
      throw std::runtime_error("ConfigDataset: unexpected index range in file name '" + entry.path().string() + "' (expected <state>_<first>-<last>)");

    const auto state = nameParts[0];

    exampleLocations.emplace_back(std::make_tuple(std::stoi(bounds[0]), std::stoi(bounds[1]), entry.path(), state));
    size_ += 1 + std::get<1>(exampleLocations.back()) - std::get<0>(exampleLocations.back());
  }
}

/// @brief Total number of examples, as accumulated from the file names by the constructor.
c10::optional<std::size_t> ConfigDataset::size() const
{
  return size_;
}

/// @brief Return the next batch of at most @p batchSize rows.
///
/// Tensor files are loaded one at a time, in the current order of
/// `exampleLocations`, onto `NeuralNetworkImpl::device`. A batch never spans
/// two files, so the last batch taken from a file may hold fewer than
/// @p batchSize rows.
///
/// @param batchSize Maximum number of rows in the returned batch.
/// @return A tuple of (all columns but the last, the last column, state name
///         of the file the rows came from), or an empty optional once every
///         file has been consumed (until reset() is called) or when the
///         dataset contains no file at all.
///         NOTE(review): the last column is presumably the target class of
///         each example - confirm against the code that writes these files.
c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::get_batch(std::size_t batchSize)
{
  using Batch = std::tuple<torch::Tensor, torch::Tensor, std::string>;

  if (!loadedTensorIndex.has_value())
  {
    // Nothing to load: avoid indexing into an empty vector.
    if (exampleLocations.empty())
      return c10::optional<Batch>();

    loadedTensorIndex = 0;
    nextIndexToGive = 0;
    torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device);
  }
  else if (loadedTensorIndex.value() >= exampleLocations.size())
  {
    // Already exhausted and reset() has not been called since: stay exhausted
    // instead of reading past the end of exampleLocations.
    return c10::optional<Batch>();
  }

  // Move on to the next file once the current tensor is used up. This is a
  // loop so that a file holding zero rows is skipped rather than producing an
  // empty batch.
  while (static_cast<std::int64_t>(nextIndexToGive) >= loadedTensor.size(0))
  {
    nextIndexToGive = 0;
    loadedTensorIndex = loadedTensorIndex.value() + 1;
    if (loadedTensorIndex.value() >= exampleLocations.size())
      return c10::optional<Batch>();
    torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device);
  }

  const auto firstRow = static_cast<std::int64_t>(nextIndexToGive);
  // Clamp the batch to the rows left in the currently loaded tensor.
  const auto rowCount = std::min(static_cast<std::int64_t>(batchSize), loadedTensor.size(0) - firstRow);
  const auto lastColumn = loadedTensor.size(1) - 1;
  const auto rows = loadedTensor.narrow(0, firstRow, rowCount);

  Batch batch;
  std::get<0>(batch) = rows.narrow(1, 0, lastColumn);
  std::get<1>(batch) = rows.narrow(1, lastColumn, 1);
  std::get<2>(batch) = std::get<3>(exampleLocations[loadedTensorIndex.value()]);

  nextIndexToGive += rowCount;

  return batch;
}

/// @brief Shuffle the order of the tensor files and rewind to the beginning.
///
/// The next call to get_batch() will load the first file of the new order.
void ConfigDataset::reset()
{
  // std::random_shuffle was removed in C++17. The engine is seeded from
  // std::rand() so that the order still depends on the global srand() seed.
  std::mt19937 generator(static_cast<std::mt19937::result_type>(std::rand()));
  std::shuffle(exampleLocations.begin(), exampleLocations.end(), generator);

  loadedTensorIndex.reset();
  nextIndexToGive = 0;
}

/// @brief Intentionally does nothing: no state is read from the archive.
void ConfigDataset::load(torch::serialize::InputArchive &)
{
}

/// @brief Intentionally does nothing: no state is written to the archive.
void ConfigDataset::save(torch::serialize::OutputArchive &) const
{
}