diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp
index 30954a3c816e71d8ca9f5778eb2c1ad9dcce1316..91695ad8749e4c23afed51445c2ba3c83441a4ed 100644
--- a/torch_modules/src/ConfigDataset.cpp
+++ b/torch_modules/src/ConfigDataset.cpp
@@ -94,7 +94,7 @@ void ConfigDataset::Holder::reset()
   nextIndexToGive = 0;
   nbGiven = 0;
   torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
-  loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), at::kLong));
+  loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device)));
 }
 
 c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize)
@@ -108,7 +108,7 @@ c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset
       return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
     nextIndexToGive = 0;
     torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
-    loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), at::kLong));
+    loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device)));
   }
 
   int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive);