Commit cf56fefc authored by Franck Dary's avatar Franck Dary
Browse files

Corrected bug on CustomHingeLoss where it would not run on gpu

parent 817bb319
#include "CustomHingeLoss.hpp"
#include "NeuralNetwork.hpp"
torch::Tensor CustomHingeLoss::operator()(torch::Tensor prediction, torch::Tensor gold)
{
torch::Tensor loss = torch::zeros(1);
torch::Tensor loss = torch::zeros(1).to(NeuralNetworkImpl::device);
for (unsigned int i = 0; i < prediction.size(0); i++)
{
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment