diff --git a/torch_modules/src/LossFunction.cpp b/torch_modules/src/LossFunction.cpp
index e90b9014ffedd00c7823d829f55ce506d6bc28c6..2f8f1be9dff754f407c5d3e4bdf5fa14e6cf825a 100644
--- a/torch_modules/src/LossFunction.cpp
+++ b/torch_modules/src/LossFunction.cpp
@@ -1,5 +1,6 @@
 #include "LossFunction.hpp"
 #include "util.hpp"
+#include "NeuralNetwork.hpp"
 
 void LossFunction::init(std::string name)
 {
@@ -50,13 +51,13 @@ torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::
 
   if (index == 0 or index == 2 or index == 4)
   {
-    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
+    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
     gold[0] = goldIndexes.at(0);
     return gold;
   }
   if (index == 1 or index == 3)
   {
-    auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong));
+    auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
     for (auto goldIndex : goldIndexes)
       gold[goldIndex] = 1;
     return gold;