Commit ee3d2d5e authored by Franck Dary's avatar Franck Dary
Browse files

Corrected bug where tensor was not initialized to the correct device

parent 4487af1d
#include "LossFunction.hpp"
#include "util.hpp"
#include "NeuralNetwork.hpp"
void LossFunction::init(std::string name)
{
......@@ -50,13 +51,13 @@ torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::
if (index == 0 or index == 2 or index == 4)
{
auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
gold[0] = goldIndexes.at(0);
return gold;
}
if (index == 1 or index == 3)
{
auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong));
auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
for (auto goldIndex : goldIndexes)
gold[goldIndex] = 1;
return gold;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment