diff --git a/torch_modules/src/LossFunction.cpp b/torch_modules/src/LossFunction.cpp index e90b9014ffedd00c7823d829f55ce506d6bc28c6..2f8f1be9dff754f407c5d3e4bdf5fa14e6cf825a 100644 --- a/torch_modules/src/LossFunction.cpp +++ b/torch_modules/src/LossFunction.cpp @@ -1,5 +1,6 @@ #include "LossFunction.hpp" #include "util.hpp" +#include "NeuralNetwork.hpp" void LossFunction::init(std::string name) { @@ -50,13 +51,13 @@ torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std:: if (index == 0 or index == 2 or index == 4) { - auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong)); + auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device)); gold[0] = goldIndexes.at(0); return gold; } if (index == 1 or index == 3) { - auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong)); + auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device)); for (auto goldIndex : goldIndexes) gold[goldIndex] = 1; return gold;