Skip to content
Snippets Groups Projects
CustomHingeLoss.cpp 466 B
Newer Older
#include "NeuralNetwork.hpp"

#include <limits>

/// Multi-class hinge (margin) loss, averaged over the batch.
///
/// For every sample i the contribution is
///   max(0, 1 - bestGold_i + bestWrong_i)
/// where bestGold_i is the highest score at a position flagged by gold[i] and
/// bestWrong_i is the highest score at any other position, i.e. the best
/// correct score must beat every wrong score by a margin of 1.
///
/// @param prediction Scores. Dimension 0 indexes samples; all remaining
///                   dimensions are flattened into one score row per sample.
/// @param gold       Target indicator with the same number of elements per
///                   sample as prediction; a non-zero entry marks a correct
///                   position (presumably one-hot — TODO confirm with callers).
/// @return 1-element tensor on NeuralNetworkImpl::device holding the mean
///         loss. Samples with no correct position contribute 0; an empty
///         batch yields 0 instead of NaN.
torch::Tensor CustomHingeLoss::operator()(torch::Tensor prediction, torch::Tensor gold)
{
  const int64_t batchSize = prediction.size(0);

  // Empty batch: the mean would be 0/0 = NaN. Summing the (empty) prediction
  // gives 0 while keeping the result attached to the autograd graph.
  if (batchSize == 0)
    return prediction.sum().reshape({1}).to(NeuralNetworkImpl::device);

  // One row of scores / gold flags per sample.
  const torch::Tensor scores = prediction.reshape({batchSize, -1});
  const torch::Tensor isGold = gold.reshape({batchSize, -1}) != 0;

  // Exclude positions by filling them with -inf rather than multiplying by the
  // 0/1 mask: multiplication turns excluded entries into 0, which wins the max
  // whenever the real scores are negative (e.g. raw logits) and corrupts the
  // margin.
  const double negInf = -std::numeric_limits<double>::infinity();
  const torch::Tensor bestGold =
    std::get<0>(scores.masked_fill(isGold.logical_not(), negInf).max(1));
  const torch::Tensor bestWrong =
    std::get<0>(scores.masked_fill(isGold, negInf).max(1));

  // If every position is gold, bestWrong is -inf and relu clamps to 0.
  torch::Tensor perSample = torch::relu(1 - bestGold + bestWrong);

  // A row with no gold position has bestGold = -inf (infinite loss); there is
  // no margin to enforce, so it contributes nothing to the mean.
  perSample = torch::where(isGold.any(1), perSample, torch::zeros_like(perSample));

  // Same as summing per-sample losses and dividing by the batch size; keep the
  // original 1-element shape and target device.
  return perSample.mean().reshape({1}).to(NeuralNetworkImpl::device);
}