From ee3d2d5e18fbf3eb81c0ba6070ec5d660fdd6f57 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 5 Mar 2021 09:40:56 +0100
Subject: [PATCH] Corrected bug where tensor was not initialized to the correct
 device

---
 torch_modules/src/LossFunction.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/torch_modules/src/LossFunction.cpp b/torch_modules/src/LossFunction.cpp
index e90b901..2f8f1be 100644
--- a/torch_modules/src/LossFunction.cpp
+++ b/torch_modules/src/LossFunction.cpp
@@ -1,5 +1,6 @@
 #include "LossFunction.hpp"
 #include "util.hpp"
+#include "NeuralNetwork.hpp"
 
 void LossFunction::init(std::string name)
 {
@@ -50,13 +51,13 @@ torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::
 
   if (index == 0 or index == 2 or index == 4)
   {
-    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
+    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
     gold[0] = goldIndexes.at(0);
     return gold;
   }
   if (index == 1 or index == 3)
   {
-    auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong));
+    auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
     for (auto goldIndex : goldIndexes)
       gold[goldIndex] = 1;
     return gold;
-- 
GitLab