diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp index e8f40befef18e5b44173ad4aec91d3d3565e3171..8fa79465fabf6a5d3862c72ac50cd7e8bec1e649 100644 --- a/torch_modules/src/ConfigDataset.cpp +++ b/torch_modules/src/ConfigDataset.cpp @@ -93,8 +93,8 @@ void ConfigDataset::Holder::reset() loadedTensorIndex = 0; nextIndexToGive = 0; nbGiven = 0; - torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::getDevice()); - loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::getDevice()))); + torch::load(loadedTensor, files[loadedTensorIndex]); + loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong))).to(NeuralNetworkImpl::getDevice()); } c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize) @@ -107,8 +107,8 @@ c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset if (loadedTensorIndex >= (int)files.size()) return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>(); nextIndexToGive = 0; - torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::getDevice()); - loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::getDevice()))); + torch::load(loadedTensor, files[loadedTensorIndex]); + loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong))).to(NeuralNetworkImpl::getDevice()); } int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive);