// ConfigDataset.cpp
// Author: Franck Dary
#include "ConfigDataset.hpp"
#include <algorithm>
#include <random>
#include "NeuralNetwork.hpp"
// Build the dataset by scanning `dir` for pre-extracted tensor files.
// Each regular file (except "extracted") is named "<state>_<first>-<last>...";
// the state name selects the Holder, and the inclusive index range gives the
// number of examples stored in that file.
ConfigDataset::ConfigDataset(std::filesystem::path dir)
{
  for (auto & entry : std::filesystem::directory_iterator(dir))
  {
    if (!entry.is_regular_file())
      continue;
    auto stem = entry.path().stem().string();
    if (stem == "extracted")
      continue;
    auto parts = util::split(stem, '_');
    auto state = parts[0];
    auto range = util::split(parts[1], '-');
    // Inclusive range: "<first>-<last>" holds last-first+1 examples.
    int nbExamples = std::stoi(range[1]) - std::stoi(range[0]) + 1;
    size_ += nbExamples;
    if (holders.count(state) == 0)
    {
      holders.emplace(state, Holder(state));
      order.push_back(state);
    }
    holders.at(state).addFile(entry.path().string(), nbExamples);
  }
}
// Total number of examples across all holders (torch dataset API).
c10::optional<std::size_t> ConfigDataset::size() const
{
  return size_;
}
// Return the next batch, interleaving states according to their quotas
// (nbToGive). States are visited in a freshly shuffled order; a state with
// remaining quota serves one batch from its Holder. If every quota is
// exhausted, quotas are recomputed once and the scan retried; an empty
// optional means the whole dataset is exhausted for this epoch.
c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::get_batch(std::size_t batchSize)
{
  // std::random_shuffle was deprecated in C++14 and removed in C++17
  // (this file already relies on C++17 std::filesystem); use std::shuffle
  // with an explicit engine instead.
  static std::mt19937 rng(std::random_device{}());
  std::shuffle(order.begin(), order.end(), rng);
  // Single implementation of the scan the original duplicated twice.
  auto serveOne = [this, batchSize]() -> c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>
  {
    for (auto & state : order)
    {
      if (nbToGive.at(state) <= 0)
        continue;
      nbToGive.at(state)--;
      auto res = holders.at(state).get_batch(batchSize);
      if (res.has_value())
        return res;
      // Holder exhausted: zero its quota so it is skipped until recompute.
      nbToGive.at(state) = 0;
    }
    return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
  };
  auto res = serveOne();
  if (res.has_value())
    return res;
  // All quotas spent: rebalance them against what is left and retry once.
  computeNbToGive();
  return serveOne();
}
void ConfigDataset::reset()
{
for (auto & it : holders)
it.second.reset();
computeNbToGive();
}
// Intentionally a no-op: this dataset keeps no serializable state
// (required override of the torch dataset interface).
void ConfigDataset::load(torch::serialize::InputArchive &)
{
}
// Intentionally a no-op: this dataset keeps no serializable state
// (required override of the torch dataset interface).
void ConfigDataset::save(torch::serialize::OutputArchive &) const
{
}
// Register one tensor file containing `filesize` examples for this state.
void ConfigDataset::Holder::addFile(std::string filename, int filesize)
{
  size_ += filesize;
  files.push_back(std::move(filename));
}
// Start a new epoch for this state: reshuffle the file order, rewind all
// cursors, and preload the first file's tensor onto the network device.
void ConfigDataset::Holder::reset()
{
  // std::random_shuffle was deprecated in C++14 and removed in C++17;
  // use std::shuffle with an explicit engine instead.
  static std::mt19937 rng(std::random_device{}());
  std::shuffle(files.begin(), files.end(), rng);
  loadedTensorIndex = 0;
  nextIndexToGive = 0;
  nbGiven = 0;
  // Guard: a holder with no registered files would otherwise index
  // files[0] out of range.
  if (!files.empty())
    torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
}
// Serve up to `batchSize` rows from the currently loaded tensor, loading the
// next file when the current one is consumed. Returns an empty optional once
// every file has been served. Each row is split into (features, label, state):
// the last column is the label, the rest are features.
c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize)
{
  // All files already consumed for this epoch.
  if (loadedTensorIndex >= (int)files.size())
    return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
  // Current tensor exhausted: advance to the next file (if any) and load it.
  if (nextIndexToGive >= loadedTensor.size(0))
  {
    loadedTensorIndex++;
    if (loadedTensorIndex >= (int)files.size())
      return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
    nextIndexToGive = 0;
    torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
  }
  // Final batch of a file may be smaller than batchSize.
  int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive);
  nbGiven += nbElementsToGive;
  // narrow() returns a view into loadedTensor (no copy).
  auto batch = loadedTensor.narrow(0, nextIndexToGive, nbElementsToGive);
  nextIndexToGive += nbElementsToGive;
  // Columns [0, n-1) are the features, column n-1 is the gold label
  // (matches the layout produced when the files were extracted —
  // TODO confirm against the extraction code).
  return std::make_tuple(batch.narrow(1, 0, batch.size(1)-1), batch.narrow(1, batch.size(1)-1, 1), state);
}
// A Holder owns the tensor files of one parser state.
ConfigDataset::Holder::Holder(std::string state) : state(std::move(state))
{
}
// Total number of examples registered for this state (sum of file sizes).
std::size_t ConfigDataset::Holder::size() const
{
  return size_;
}
// Number of examples not yet served this epoch.
// NOTE(review): assumes nbGiven <= size_; otherwise the unsigned
// subtraction wraps around — verify the serving loop keeps this invariant.
std::size_t ConfigDataset::Holder::sizeLeft() const
{
  return size_-nbGiven;
}
void ConfigDataset::computeNbToGive()
{
int smallestSize = std::numeric_limits<int>::max();
for (auto & it : holders)
{
int sizeLeft = it.second.sizeLeft();
if (sizeLeft > 0 and sizeLeft < smallestSize)
smallestSize = sizeLeft;
}
for (auto & it : holders)
{
nbToGive[it.first] = std::max<int>(1,std::floor(1.0*it.second.sizeLeft()/smallestSize));
if (it.second.sizeLeft() == 0)
nbToGive[it.first] = 0;
}
}