Newer
Older
Franck Dary
committed
#include "NeuralNetwork.hpp"
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <random>
#include <stdexcept>
#include <string>
Franck Dary
committed
// Index every example file found directly in `dir`.
//
// Each regular file (except one whose stem is "extracted") is expected to be
// named "<state>_<first>-<last>.<ext>". For each one, this records the tuple
// (first, last, path, state) in exampleLocations and adds the number of
// examples it holds (last - first + 1) to size_.
//
// Throws std::runtime_error if a file name does not follow that pattern,
// instead of indexing past the end of the split result (undefined behavior).
// std::stoi may also throw (std::invalid_argument / std::out_of_range) when a
// bound is not an integer.
ConfigDataset::ConfigDataset(std::filesystem::path dir)
{
  for (auto & entry : std::filesystem::directory_iterator(dir))
  {
    if (!entry.is_regular_file())
      continue;

    auto stem = entry.path().stem().string();
    // A file with this stem is deliberately not treated as an example file.
    if (stem == "extracted")
      continue;

    // Split once and validate before indexing.
    auto const parts = util::split(stem, '_');
    if (parts.size() < 2)
      throw std::runtime_error("ConfigDataset: malformed example file name '" + entry.path().string() + "' (expected <state>_<first>-<last>)");

    auto const bounds = util::split(parts[1], '-');
    if (bounds.size() < 2)
      throw std::runtime_error("ConfigDataset: malformed example range in '" + entry.path().string() + "' (expected <first>-<last>)");

    int const first = std::stoi(bounds[0]);
    int const last = std::stoi(bounds[1]);
    // A reversed range would make the size computation below wrap around.
    if (last < first)
      throw std::runtime_error("ConfigDataset: reversed example range in '" + entry.path().string() + "'");

    exampleLocations.emplace_back(std::make_tuple(first, last, entry.path(), parts[0]));
    size_ += 1 + last - first;
  }
}
// Total number of examples across all indexed files.
// The count is accumulated once in the constructor; this only reports it,
// always as an engaged optional.
c10::optional<std::size_t> ConfigDataset::size() const
{
  return c10::optional<std::size_t>(size_);
}
// Return the next batch of at most `batchSize` rows, or an empty optional once
// every file in exampleLocations has been consumed.
//
// Files are loaded lazily, one at a time, onto NeuralNetworkImpl::device.
// Each batch is taken from the currently loaded tensor only, so the last batch
// of a file may be shorter than `batchSize`. The returned tuple holds:
//   0: all columns except the last of the selected rows,
//   1: the last column of the selected rows (kept as a 1-wide column),
//   2: the state string recorded for the file the rows come from.
// NOTE(review): assumes each stored tensor has at least 2 dimensions — TODO confirm.
c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::get_batch(std::size_t batchSize)
{
  using Batch = std::tuple<torch::Tensor,torch::Tensor,std::string>;

  // Nothing to iterate over; avoid indexing an empty exampleLocations.
  if (exampleLocations.empty())
    return c10::optional<Batch>();

  // Already exhausted and not reset: keep reporting the end instead of
  // indexing exampleLocations past its last element.
  if (loadedTensorIndex.has_value() && loadedTensorIndex.value() >= exampleLocations.size())
    return c10::optional<Batch>();

  // First call since construction / reset(): load the first file.
  if (!loadedTensorIndex.has_value())
  {
    loadedTensorIndex = 0;
    nextIndexToGive = 0;
    torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device);
  }

  // Move on to the next file while the current one is exhausted. This is a loop
  // (not a single `if`) so that empty files are skipped rather than producing
  // an empty batch.
  while (static_cast<std::int64_t>(nextIndexToGive) >= loadedTensor.size(0))
  {
    nextIndexToGive = 0;
    loadedTensorIndex = loadedTensorIndex.value() + 1;

    if (loadedTensorIndex.value() >= exampleLocations.size())
      return c10::optional<Batch>();

    torch::load(loadedTensor, std::get<2>(exampleLocations[loadedTensorIndex.value()]), NeuralNetworkImpl::device);
  }

  // Work in int64_t, the type torch uses for sizes, instead of narrowing to int.
  auto const start = static_cast<std::int64_t>(nextIndexToGive);
  auto const nbRows = std::min<std::int64_t>(static_cast<std::int64_t>(batchSize), loadedTensor.size(0) - start);
  auto const nbCols = loadedTensor.size(1);

  auto rows = loadedTensor.narrow(0, start, nbRows);

  Batch batch;
  std::get<0>(batch) = rows.narrow(1, 0, nbCols - 1);
  std::get<1>(batch) = rows.narrow(1, nbCols - 1, 1);
  std::get<2>(batch) = std::get<3>(exampleLocations[loadedTensorIndex.value()]);

  nextIndexToGive = start + nbRows;

  return batch;
}
// Start a new pass over the dataset: shuffle the order in which files are
// visited and forget the currently loaded file, so the next get_batch() call
// loads the first file of the new order.
void ConfigDataset::reset()
{
  // std::random_shuffle was deprecated in C++14 and removed in C++17.
  // Seeding the engine from std::rand() keeps the order reproducible for
  // callers that seed the C RNG with srand(), as random_shuffle's
  // (implementation-defined, typically rand()-based) source did.
  std::mt19937 generator(static_cast<std::mt19937::result_type>(std::rand()));
  std::shuffle(exampleLocations.begin(), exampleLocations.end(), generator);

  loadedTensorIndex.reset();
  nextIndexToGive = 0;
}
void ConfigDataset::load(torch::serialize::InputArchive &)
Franck Dary
committed
void ConfigDataset::save(torch::serialize::OutputArchive &) const