#include "NeuralNetwork.hpp"

#include <algorithm>
#include <cmath>
#include <limits>
#include <random>
#include <string>
#include <utility>
// Builds the dataset index by scanning `dir` for pre-extracted tensor files.
// Each regular file's stem is expected to look like "<state>_<first>-<last>":
// the parser state name, then the inclusive range of example indices the file
// holds. A file whose stem is exactly "extracted" is a marker, not data, and
// is skipped. One Holder per state is created lazily, in discovery order.
ConfigDataset::ConfigDataset(std::filesystem::path dir)
{
  for (auto & entry : std::filesystem::directory_iterator(dir))
    if (entry.is_regular_file())
    {
      auto stem = entry.path().stem().string();
      if (stem == "extracted")
        continue;
      auto state = util::split(stem, '_')[0];
      auto splited = util::split(util::split(stem, '_')[1], '-');
      // The range is inclusive on both ends, hence the +1.
      int fileSize = 1 + std::stoi(splited[1]) - std::stoi(splited[0]);
      size_ += fileSize;
      if (!holders.count(state))
      {
        holders.emplace(state, Holder(state));
        order.emplace_back(state);
      }
      holders.at(state).addFile(entry.path().string(), fileSize);
    }
}
// Total number of examples across all holders (known up front, so the
// optional is always engaged).
c10::optional<std::size_t> ConfigDataset::size() const
{
  return c10::optional<std::size_t>(size_);
}
// Returns the next training batch, cycling over parser states so that each
// state contributes in proportion to its remaining data. States are visited
// in a freshly shuffled order; the first state with a positive quota
// (nbToGive) is asked for a batch, and a state whose holder turns out to be
// exhausted has its quota zeroed. If the first pass yields nothing, quotas
// are recomputed from the remaining sizes and a second pass is made.
// An empty optional means the dataset is exhausted for this epoch.
// The tuple is (features, targets, state name) — see Holder::get_batch.
c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::get_batch(std::size_t batchSize)
{
  // std::random_shuffle was removed in C++17; std::shuffle with an explicit
  // engine is the supported replacement.
  static std::mt19937 rng(std::random_device{}());
  std::shuffle(order.begin(), order.end(), rng);
  for (auto & state : order)
  {
    if (nbToGive.at(state) > 0)
    {
      nbToGive.at(state)--;
      auto res = holders.at(state).get_batch(batchSize);
      if (res.has_value())
        return res;
      else
        nbToGive.at(state) = 0; // holder exhausted: stop asking it this pass
    }
  }
  // Every state ran out of quota; refresh quotas and try once more before
  // declaring the dataset exhausted.
  computeNbToGive();
  for (auto & state : order)
  {
    if (nbToGive.at(state) > 0)
    {
      nbToGive.at(state)--;
      auto res = holders.at(state).get_batch(batchSize);
      if (res.has_value())
        return res;
      else
        nbToGive.at(state) = 0;
    }
  }
  return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
}
// Starts a new epoch: rewinds every holder (reshuffling its files and
// resetting its cursors) and recomputes each state's batch quota.
void ConfigDataset::reset()
{
  for (auto & it : holders)
    it.second.reset();
  computeNbToGive();
}
// NOTE(review): the body was lost in extraction and is restored here as a
// no-op — the dataset's state is rebuilt from the files on disk, so there
// appears to be nothing to restore from the archive. Confirm against the
// original source.
void ConfigDataset::load(torch::serialize::InputArchive &)
{
}
// NOTE(review): the body was lost in extraction and is restored here as a
// no-op — symmetric with load(); the dataset lives entirely in the files on
// disk, so there appears to be nothing to serialize. Confirm against the
// original source.
void ConfigDataset::save(torch::serialize::OutputArchive &) const
{
}
// Registers one on-disk tensor file for this state and grows the holder's
// total example count by the number of examples that file contains.
void ConfigDataset::Holder::addFile(std::string filename, int filesize)
{
  size_ += filesize;
  files.emplace_back(filename);
}
// Rewinds this holder for a new epoch: shuffles the file order, resets the
// paging cursors, and loads the first file's tensor onto the network device.
void ConfigDataset::Holder::reset()
{
  // std::random_shuffle was removed in C++17; std::shuffle with an explicit
  // engine is the supported replacement.
  static std::mt19937 rng(std::random_device{}());
  std::shuffle(files.begin(), files.end(), rng);
  loadedTensorIndex = 0;
  nextIndexToGive = 0;
  nbGiven = 0;
  // Guard against a holder that was never given a file — files[0] would be
  // undefined behavior otherwise.
  if (files.empty())
    return;
  torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
}
// Hands out the next slice of up to batchSize rows from the currently loaded
// tensor, paging in the next file once the current one is fully served.
// Returns an empty optional when every file has been consumed.
// Each row's last column is the target; the tuple is (features, targets,
// state name).
c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize)
{
  using Batch = std::tuple<torch::Tensor,torch::Tensor,std::string>;
  if (loadedTensorIndex >= (int)files.size())
    return c10::optional<Batch>();
  if (nextIndexToGive >= loadedTensor.size(0))
  {
    // Current file fully served: page in the next one, if any remains.
    ++loadedTensorIndex;
    if (loadedTensorIndex >= (int)files.size())
      return c10::optional<Batch>();
    nextIndexToGive = 0;
    torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
  }
  int given = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive);
  nbGiven += given;
  auto rows = loadedTensor.narrow(0, nextIndexToGive, given);
  nextIndexToGive += given;
  int lastCol = rows.size(1)-1;
  return std::make_tuple(rows.narrow(1, 0, lastCol), rows.narrow(1, lastCol, 1), state);
}
// A Holder serves batches for a single parser state. `state` is a sink
// parameter: taken by value and moved into the member instead of copied.
ConfigDataset::Holder::Holder(std::string state) : state(std::move(state))
{
}
std::size_t ConfigDataset::Holder::size() const
{
return size_;
}
// Number of examples not yet handed out this epoch.
std::size_t ConfigDataset::Holder::sizeLeft() const
{
  return size_ - nbGiven;
}
// Recomputes each state's batch quota (nbToGive) for the next round of
// get_batch calls. A state's quota is proportional to how much data it has
// left relative to the smallest non-empty state, so data-rich states are
// drawn from more often; exhausted states get a quota of 0.
void ConfigDataset::computeNbToGive()
{
  // Smallest non-zero amount of remaining data across all states.
  int smallestSize = std::numeric_limits<int>::max();
  for (auto & it : holders)
  {
    int sizeLeft = it.second.sizeLeft();
    if (sizeLeft > 0 and sizeLeft < smallestSize)
      smallestSize = sizeLeft;
  }
  for (auto & it : holders)
  {
    // Floor of the ratio, but at least 1 so every non-empty state is served.
    nbToGive[it.first] = std::max<int>(1,std::floor(1.0*it.second.sizeLeft()/smallestSize));
    // Exhausted states are excluded outright.
    if (it.second.sizeLeft() == 0)
      nbToGive[it.first] = 0;
  }