Skip to content
Snippets Groups Projects
Commit 04dd8e56 authored by Franck Dary's avatar Franck Dary
Browse files

Corrected a bug where a tensor was not send to the correct device

parent 81aecb05
Branches
No related tags found
No related merge requests found
......@@ -94,7 +94,7 @@ void ConfigDataset::Holder::reset()
nextIndexToGive = 0;
nbGiven = 0;
torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), at::kLong));
loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device)));
}
c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize)
......@@ -108,7 +108,7 @@ c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset
return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
nextIndexToGive = 0;
torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), at::kLong));
loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device)));
}
int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment