From ffb2e08b70bc3b264075c68592fa0e5548aae701 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sun, 18 Apr 2021 13:15:01 +0200
Subject: [PATCH] Shuffle dataset on CPU to avoid CUDA sync error

---
 torch_modules/src/ConfigDataset.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp
index e8f40be..8fa7946 100644
--- a/torch_modules/src/ConfigDataset.cpp
+++ b/torch_modules/src/ConfigDataset.cpp
@@ -93,8 +93,8 @@ void ConfigDataset::Holder::reset()
   loadedTensorIndex = 0;
   nextIndexToGive = 0;
   nbGiven = 0;
-  torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::getDevice());
-  loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::getDevice())));
+  torch::load(loadedTensor, files[loadedTensorIndex]);
+  loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong))).to(NeuralNetworkImpl::getDevice());
 }
 
 c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize)
@@ -107,8 +107,8 @@ c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset
     if (loadedTensorIndex >= (int)files.size())
       return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
     nextIndexToGive = 0;
-    torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::getDevice());
-    loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::getDevice())));
+    torch::load(loadedTensor, files[loadedTensorIndex]);
+    loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong))).to(NeuralNetworkImpl::getDevice());
   }
 
   int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive);
-- 
GitLab