Commit ffb2e08b authored by Franck Dary

Shuffle dataset on CPU to avoid CUDA sync error

parent 6d627fa1
@@ -93,8 +93,8 @@ void ConfigDataset::Holder::reset()
   loadedTensorIndex = 0;
   nextIndexToGive = 0;
   nbGiven = 0;
-  torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::getDevice());
-  loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::getDevice())));
+  torch::load(loadedTensor, files[loadedTensorIndex]);
+  loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong))).to(NeuralNetworkImpl::getDevice());
 }
 
 c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize)
@@ -107,8 +107,8 @@ c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset
   if (loadedTensorIndex >= (int)files.size())
     return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
   nextIndexToGive = 0;
-  torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::getDevice());
-  loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::getDevice())));
+  torch::load(loadedTensor, files[loadedTensorIndex]);
+  loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong))).to(NeuralNetworkImpl::getDevice());
 }
 int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive);
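For context, a minimal standalone sketch of the pattern this commit switches to: build the shuffle permutation with default (CPU) tensor options, index-select on CPU, and only then move the result to the target device. The example tensor and the device check below are illustrative stand-ins, not code from the repository.

#include <torch/torch.h>
#include <iostream>

int main()
{
  // Target device; in the repository this role is played by NeuralNetworkImpl::getDevice().
  torch::Device device = torch::cuda::is_available() ? torch::Device(torch::kCUDA)
                                                     : torch::Device(torch::kCPU);

  // Hypothetical stand-in for the dataset tensor loaded from disk.
  torch::Tensor loadedTensor = torch::arange(12, torch::TensorOptions(at::kLong)).reshape({4, 3});

  // Build the permutation with default (CPU) options, so the shuffle itself
  // involves no CUDA work and no device synchronization.
  torch::Tensor perm = torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong));

  // Shuffle rows on CPU, then transfer the shuffled tensor to the device in one step.
  loadedTensor = torch::index_select(loadedTensor, 0, perm).to(device);

  std::cout << loadedTensor << std::endl;
  return 0;
}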