Newer
Older
Franck Dary
committed
#include "NeuralNetwork.hpp"

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <random>
#include <stdexcept>
Franck Dary
committed
ConfigDataset::ConfigDataset(std::filesystem::path dir)
{
  // Each regular file in `dir` whose stem has the form "<first>-<last>" (two integers) is
  // recorded as (first, last, path), and size_ accumulates the inclusive count last - first + 1.
  // Files whose names do not follow this pattern are ignored.
  for (auto & entry : std::filesystem::directory_iterator(dir))
  {
    if (!entry.is_regular_file())
      continue;

    auto splited = util::split(entry.path().stem().string(), '-');
    if (splited.size() != 2)
      continue;

    int firstIndex = 0;
    int lastIndex = 0;
    try
    {
      firstIndex = std::stoi(splited[0]);
      lastIndex = std::stoi(splited[1]);
    }
    catch (const std::invalid_argument &)
    {
      // Non-numeric part: an unrelated file. Skip it like any other malformed name
      // instead of letting the exception abort construction of the whole dataset.
      continue;
    }
    catch (const std::out_of_range &)
    {
      continue;
    }

    // An inverted range would contribute a negative count and corrupt size_.
    if (lastIndex < firstIndex)
      continue;

    exampleLocations.emplace_back(std::make_tuple(firstIndex, lastIndex, entry.path()));
    size_ += 1 + lastIndex - firstIndex;
  }
}
c10::optional<std::size_t> ConfigDataset::size() const
{
  // The total is always known: it is accumulated from the file names in the constructor.
  c10::optional<std::size_t> total(size_);
  return total;
}
Franck Dary
committed
c10::optional<std::pair<torch::Tensor,torch::Tensor>> ConfigDataset::get_batch(std::size_t batchSize)
{
  // Lazily load the first file on the first call after construction or reset().
  if (!loadedTensorIndex.has_value())
  {
    // Nothing to serve; indexing exampleLocations[0] would be out of bounds.
    if (exampleLocations.empty())
      return c10::optional<std::pair<torch::Tensor,torch::Tensor>>();
    loadedTensorIndex = 0;
    nextIndexToGive = 0;
    torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device);
  }

  // Move on to the next file while the current one is exhausted. Using a loop (rather than a
  // single check) also skips files that contain zero rows instead of returning an empty batch.
  while (static_cast<std::int64_t>(nextIndexToGive) >= loadedTensor.size(0))
  {
    std::size_t nextTensorIndex = loadedTensorIndex.value() + 1;
    // End of data. State is left untouched so that any further call keeps reporting the end
    // instead of silently replaying the last loaded file.
    if (nextTensorIndex >= exampleLocations.size())
      return c10::optional<std::pair<torch::Tensor,torch::Tensor>>();
    loadedTensorIndex = nextTensorIndex;
    nextIndexToGive = 0;
    torch::load(loadedTensor, std::get<2>(exampleLocations[nextTensorIndex]), NeuralNetworkImpl::device);
  }

  // Take up to batchSize rows; the last batch of a file may be smaller.
  const std::int64_t start = static_cast<std::int64_t>(nextIndexToGive);
  const std::int64_t nbRows = std::min<std::int64_t>(static_cast<std::int64_t>(batchSize), loadedTensor.size(0) - start);
  const std::int64_t nbColumns = loadedTensor.size(1);
  torch::Tensor rows = loadedTensor.narrow(0, start, nbRows);

  // All columns but the last form batch.first; the last column is returned separately as a
  // one-column tensor (presumably the gold label; confirm against the code writing these files).
  std::pair<torch::Tensor, torch::Tensor> batch;
  batch.first = rows.narrow(1, 0, nbColumns-1);
  batch.second = rows.narrow(1, nbColumns-1, 1);
  nextIndexToGive += static_cast<std::size_t>(nbRows);
  return batch;
}
void ConfigDataset::reset()
{
  // Shuffle the file order and rewind so the next get_batch() reloads from the first file.
  // std::random_shuffle was deprecated in C++14 and removed in C++17; std::shuffle replaces it.
  // The engine is seeded from std::rand() so the order stays deterministic under a fixed
  // std::srand() seed, as it typically was with random_shuffle.
  std::mt19937 generator(static_cast<std::mt19937::result_type>(std::rand()));
  std::shuffle(exampleLocations.begin(), exampleLocations.end(), generator);
  // reset() works whether loadedTensorIndex is a std::optional or a c10::optional.
  loadedTensorIndex.reset();
  nextIndexToGive = 0;
}
void ConfigDataset::load(torch::serialize::InputArchive &)
Franck Dary
committed
void ConfigDataset::save(torch::serialize::OutputArchive &) const