Skip to content
Snippets Groups Projects
Commit 81aecb05 authored by Franck Dary's avatar Franck Dary
Browse files

Shuffling batches Tensor in ConfigDataset

parent 6cdb5e16
No related branches found
No related tags found
No related merge requests found
......@@ -94,6 +94,7 @@ void ConfigDataset::Holder::reset()
nextIndexToGive = 0;
nbGiven = 0;
torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), at::kLong));
}
c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize)
......@@ -107,6 +108,7 @@ c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset
return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
nextIndexToGive = 0;
torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), at::kLong));
}
int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment