From ee3d2d5e18fbf3eb81c0ba6070ec5d660fdd6f57 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Fri, 5 Mar 2021 09:40:56 +0100 Subject: [PATCH] Corrected bug where tensor was not initialized to the correct device --- torch_modules/src/LossFunction.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_modules/src/LossFunction.cpp b/torch_modules/src/LossFunction.cpp index e90b901..2f8f1be 100644 --- a/torch_modules/src/LossFunction.cpp +++ b/torch_modules/src/LossFunction.cpp @@ -1,5 +1,6 @@ #include "LossFunction.hpp" #include "util.hpp" +#include "NeuralNetwork.hpp" void LossFunction::init(std::string name) { @@ -50,13 +51,13 @@ torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std:: if (index == 0 or index == 2 or index == 4) { - auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong)); + auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device)); gold[0] = goldIndexes.at(0); return gold; } if (index == 1 or index == 3) { - auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong)); + auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device)); for (auto goldIndex : goldIndexes) gold[goldIndex] = 1; return gold; -- GitLab