Skip to content
Snippets Groups Projects
Commit ee3d2d5e authored by Franck Dary's avatar Franck Dary
Browse files

Corrected bug where tensor was not initialized to the correct device

parent 4487af1d
No related branches found
No related tags found
No related merge requests found
#include "LossFunction.hpp" #include "LossFunction.hpp"
#include "util.hpp" #include "util.hpp"
#include "NeuralNetwork.hpp"
void LossFunction::init(std::string name) void LossFunction::init(std::string name)
{ {
...@@ -50,13 +51,13 @@ torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std:: ...@@ -50,13 +51,13 @@ torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::
if (index == 0 or index == 2 or index == 4) if (index == 0 or index == 2 or index == 4)
{ {
auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong)); auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
gold[0] = goldIndexes.at(0); gold[0] = goldIndexes.at(0);
return gold; return gold;
} }
if (index == 1 or index == 3) if (index == 1 or index == 3)
{ {
auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong)); auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
for (auto goldIndex : goldIndexes) for (auto goldIndex : goldIndexes)
gold[goldIndex] = 1; gold[goldIndex] = 1;
return gold; return gold;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment