Skip to content
Snippets Groups Projects
Commit ffb2e08b authored by Franck Dary's avatar Franck Dary
Browse files

Shuffle dataset on CPU to avoid CUDA sync error

parent 6d627fa1
No related branches found
No related tags found
No related merge requests found
...@@ -93,8 +93,8 @@ void ConfigDataset::Holder::reset() ...@@ -93,8 +93,8 @@ void ConfigDataset::Holder::reset()
loadedTensorIndex = 0; loadedTensorIndex = 0;
nextIndexToGive = 0; nextIndexToGive = 0;
nbGiven = 0; nbGiven = 0;
torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::getDevice()); torch::load(loadedTensor, files[loadedTensorIndex]);
loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::getDevice()))); loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong))).to(NeuralNetworkImpl::getDevice());
} }
c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize) c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize)
...@@ -107,8 +107,8 @@ c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset ...@@ -107,8 +107,8 @@ c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset
if (loadedTensorIndex >= (int)files.size()) if (loadedTensorIndex >= (int)files.size())
return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>(); return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
nextIndexToGive = 0; nextIndexToGive = 0;
torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::getDevice()); torch::load(loadedTensor, files[loadedTensorIndex]);
loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::getDevice()))); loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong))).to(NeuralNetworkImpl::getDevice());
} }
int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive); int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive);
......
0% — Loading, or an error occurred while loading the file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment