From 04dd8e563a8870bfa003af539001aa9b44922845 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 30 Jun 2020 20:49:56 +0200 Subject: [PATCH] Corrected a bug where a tensor was not send to the correct device --- torch_modules/src/ConfigDataset.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp index 30954a3..91695ad 100644 --- a/torch_modules/src/ConfigDataset.cpp +++ b/torch_modules/src/ConfigDataset.cpp @@ -94,7 +94,7 @@ void ConfigDataset::Holder::reset() nextIndexToGive = 0; nbGiven = 0; torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device); - loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), at::kLong)); + loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device))); } c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize) @@ -108,7 +108,7 @@ c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>(); nextIndexToGive = 0; torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device); - loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), at::kLong)); + loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device))); } int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive); -- GitLab