Newer
Older
#include "CustomHingeLoss.hpp"

#include <cstdint>
/// @brief Computes a hinge-style loss averaged over the batch (dimension 0).
///
/// For every sample i the term
///   relu(1 - max(gold[i] * prediction[i]) + max((1 - gold[i]) * prediction[i]))
/// is accumulated, and the sum is divided by prediction.size(0).
///
/// @param prediction Scores, indexed per sample along dimension 0.
/// @param gold       Targets, indexed the same way as prediction. Presumably a
///                   0/1 mask marking the gold classes -- TODO confirm against callers.
/// @return A tensor of shape {1} holding the averaged loss, allocated on the
///         same device as prediction and using the default dtype. For an empty
///         batch (prediction.size(0) == 0) the final division is 0/0, so the
///         result is NaN.
///
/// NOTE(review): if gold really is a 0/1 mask, the masked-out entries contribute
/// 0 to each max, so all-negative scores are effectively clamped at 0 -- confirm
/// this is intended.
torch::Tensor CustomHingeLoss::operator()(torch::Tensor prediction, torch::Tensor gold)
{
  // Allocate the accumulator on the device of the predictions: an in-place add
  // into a CPU tensor from a non-CPU operand is rejected by LibTorch.
  torch::Tensor loss = torch::zeros({1}, torch::TensorOptions().device(prediction.device()));

  // size(0) is a signed 64-bit value; use the same type for the index to avoid
  // a signed/unsigned comparison and a too-narrow counter.
  const int64_t batchSize = prediction.size(0);

  for (int64_t i = 0; i < batchSize; ++i)
  {
    const torch::Tensor sampleGold = gold[i];
    const torch::Tensor samplePrediction = prediction[i];

    // Margin between the best gold-weighted score and the best non-gold score.
    loss += torch::relu(1 - torch::max(sampleGold * samplePrediction)
                        + torch::max((1 - sampleGold) * samplePrediction));
  }

  loss /= batchSize;

  return loss;
}