diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp index 2b42eefe624646cb17a7c869d881c5289801ad25..30954a3c816e71d8ca9f5778eb2c1ad9dcce1316 100644 --- a/torch_modules/src/ConfigDataset.cpp +++ b/torch_modules/src/ConfigDataset.cpp @@ -94,6 +94,7 @@ void ConfigDataset::Holder::reset() nextIndexToGive = 0; nbGiven = 0; torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device); + loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), at::kLong)); } c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize) @@ -107,6 +108,7 @@ c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>(); nextIndexToGive = 0; torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device); + loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), at::kLong)); } int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive);