Commit 81aecb05 authored by Franck Dary's avatar Franck Dary
Browse files

Shuffling batches Tensor in ConfigDataset

parent 6cdb5e16
......@@ -94,6 +94,7 @@ void ConfigDataset::Holder::reset()
nextIndexToGive = 0;
nbGiven = 0;
torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), at::kLong));
}
c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize)
......@@ -107,6 +108,7 @@ c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset
return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
nextIndexToGive = 0;
torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), at::kLong));
}
int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive);
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment