Newer
Older
Franck Dary
committed
#include <algorithm>
#include <random>

#include "NeuralNetwork.hpp"
Franck Dary
committed
// Build the dataset index from a directory of pre-extracted tensor files.
// Each regular file's stem encodes "<state>-<nbClasses>_<first>-<last>";
// files named "extracted*" are bookkeeping files and are skipped.
// @param dir directory containing the per-state tensor files.
ConfigDataset::ConfigDataset(std::filesystem::path dir)
{
  for (auto & entry : std::filesystem::directory_iterator(dir))
    if (entry.is_regular_file())
    {
      auto stem = util::split(entry.path().stem().string(), '.')[0];
      if (stem == "extracted")
        continue;
      // NOTE(review): the two declarations below were missing (lost lines);
      // reconstructed from their uses — confirm against repository history.
      auto underSplit = util::split(stem, '_');
      auto splited = util::split(underSplit.back(), '-');
      // Everything before the final '_' is "<state>-<nbClasses>".
      auto stateAndNbClasses = util::split(util::join("_", std::vector<std::string>(underSplit.begin(), underSplit.end()-1)), '-');
      auto state = stateAndNbClasses[0];
      auto nbClasses = std::stoi(stateAndNbClasses[1]);
      // The file covers examples [splited[0], splited[1]], both inclusive.
      int fileSize = 1 + std::stoi(splited[1]) - std::stoi(splited[0]);
      size_ += fileSize;
      if (!holders.count(state))
      {
        // First file for this state: create its holder before using it —
        // without this, holders.at(state) below throws std::out_of_range.
        holders.emplace(state, Holder(state, nbClasses));
        order.emplace_back(state);
      }
      holders.at(state).addFile(entry.path().string(), fileSize);
    }
}
// Total number of training examples across all states' files, as accumulated
// by the constructor. Wrapped in c10::optional to satisfy the dataset API.
c10::optional<std::size_t> ConfigDataset::size() const
{
  return size_;
}
// Produce the next training batch, balancing across states via per-state
// quotas (nbToGive). Returns (inputs, gold classes, state name), or an empty
// optional once every holder is exhausted for this epoch.
// @param batchSize maximum number of examples to return.
c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::get_batch(std::size_t batchSize)
{
  // std::random_shuffle was removed in C++17; std::shuffle is the replacement.
  static std::mt19937 rng(std::random_device{}());
  std::shuffle(order.begin(), order.end(), rng);
  // First pass: serve any state that still has quota left.
  for (auto & state : order)
  {
    if (nbToGive.at(state) > 0)
    {
      nbToGive.at(state)--;
      auto res = holders.at(state).get_batch(batchSize);
      if (res.has_value())
        return res;
      else
        nbToGive.at(state) = 0; // holder exhausted: stop asking it this round
    }
  }
  // All quotas spent: refresh them and make one more pass before giving up.
  // NOTE(review): this call + loop header were lost in the history view;
  // reconstructed from the duplicated loop body — confirm against repository.
  computeNbToGive();
  for (auto & state : order)
  {
    if (nbToGive.at(state) > 0)
    {
      nbToGive.at(state)--;
      auto res = holders.at(state).get_batch(batchSize);
      if (res.has_value())
        return res;
      else
        nbToGive.at(state) = 0;
    }
  }
  // Every holder is exhausted: signal end of epoch.
  return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
}
void ConfigDataset::reset()
{
for (auto & it : holders)
it.second.reset();
computeNbToGive();
Franck Dary
committed
}
void ConfigDataset::load(torch::serialize::InputArchive &)
Franck Dary
committed
void ConfigDataset::save(torch::serialize::OutputArchive &) const
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
// Register one more tensor file with this holder and grow its total
// example count accordingly.
// @param filename path of the tensor file on disk.
// @param filesize number of examples the file contains.
void ConfigDataset::Holder::addFile(std::string filename, int filesize)
{
  size_ += filesize;
  files.push_back(std::move(filename));
}
// Rewind this holder for a new epoch: reshuffle the file order, reset the
// cursors, and load the first file's tensor onto the network device.
void ConfigDataset::Holder::reset()
{
  // std::random_shuffle was removed in C++17; std::shuffle is the replacement.
  static std::mt19937 rng(std::random_device{}());
  std::shuffle(files.begin(), files.end(), rng);
  loadedTensorIndex = 0;
  nextIndexToGive = 0;
  nbGiven = 0;
  // Guard: a holder with no files would otherwise index files[0] out of
  // bounds. get_batch already handles loadedTensorIndex >= files.size().
  if (files.empty())
    return;
  torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
}
// Serve up to batchSize examples from the currently loaded tensor, loading
// the next file when the current one is exhausted. Returns an empty optional
// when all files have been consumed for this epoch.
// The returned tuple is (inputs, gold classes, state): each row's last
// nbClasses columns are the gold-class part, the rest are features.
// (Fix: the function's closing brace was lost in the history view — restored.)
c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize)
{
  if (loadedTensorIndex >= (int)files.size())
    return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
  if (nextIndexToGive >= loadedTensor.size(0))
  {
    // Current tensor fully served: advance to the next file, if any.
    loadedTensorIndex++;
    if (loadedTensorIndex >= (int)files.size())
      return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
    nextIndexToGive = 0;
    torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
  }
  // Serve a partial batch at the tail of a file rather than spilling over.
  int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive);
  nbGiven += nbElementsToGive;
  auto batch = loadedTensor.narrow(0, nextIndexToGive, nbElementsToGive);
  nextIndexToGive += nbElementsToGive;
  return std::make_tuple(batch.narrow(1, 0, batch.size(1)-nbClasses), batch.narrow(1, batch.size(1)-nbClasses, nbClasses), state);
}
// Construct a holder bound to one state.
// @param state name of the state whose examples this holder serves.
// @param nbClasses number of gold-class columns at the end of each example
//        row (used by get_batch to split inputs from targets).
ConfigDataset::Holder::Holder(std::string state, int nbClasses) : state(state), nbClasses(nbClasses)
{
}
// Total number of examples registered via addFile across all files.
std::size_t ConfigDataset::Holder::size() const
{
  return size_;
}
// Number of examples not yet served this epoch.
// NOTE(review): unsigned arithmetic — if nbGiven ever exceeded size_ this
// would wrap; appears impossible given get_batch's bookkeeping, but verify.
std::size_t ConfigDataset::Holder::sizeLeft() const
{
  return size_-nbGiven;
}
// Recompute the per-state serving quotas so that each round of get_batch
// draws from every state proportionally to how much data it has left.
// (The function's closing brace lies beyond the visible chunk.)
void ConfigDataset::computeNbToGive()
{
  // Find the smallest non-zero remaining size; it becomes the unit quota.
  int smallestSize = std::numeric_limits<int>::max();
  for (auto & it : holders)
  {
    int sizeLeft = it.second.sizeLeft();
    if (sizeLeft > 0 and sizeLeft < smallestSize)
      smallestSize = sizeLeft;
  }
  // Quota = floor(remaining / smallest), at least 1 while data remains,
  // and exactly 0 once a state is exhausted.
  for (auto & it : holders)
  {
    nbToGive[it.first] = std::max<int>(1,std::floor(1.0*it.second.sizeLeft()/smallestSize));
    if (it.second.sizeLeft() == 0)
      nbToGive[it.first] = 0;
  }