#include "CustomHingeLoss.hpp"
#include "NeuralNetwork.hpp"

#include <limits>
#include <tuple>

/// Multi-class (Crammer-Singer style) hinge loss, averaged over the batch.
///
/// For each sample i the per-sample loss is
///   max(0, 1 - max_{c : gold[i][c] != 0} prediction[i][c]
///             + max_{c : gold[i][c] == 0} prediction[i][c])
/// i.e. the best-scoring gold class must beat the best-scoring non-gold
/// class by a margin of 1.
///
/// @param prediction Scores; dimension 0 is the batch, all remaining
///        dimensions are flattened per sample (presumably (batch, classes) —
///        TODO confirm with callers).
/// @param gold Assumed to have the same number of elements per sample as
///        `prediction`; nonzero entries mark gold classes (presumably one-hot).
/// @return One-element tensor holding the mean loss over the batch
///         (0 for an empty batch).
torch::Tensor CustomHingeLoss::operator()(torch::Tensor prediction, torch::Tensor gold)
{
  const int64_t batchSize = prediction.size(0);
  if (batchSize == 0)
    return torch::zeros({1}, prediction.options());

  // One row per sample so each "max over classes" is a single batched reduction.
  auto scores = prediction.reshape({batchSize, -1});
  auto goldMask = gold.reshape({batchSize, -1}) != 0;

  // Mask excluded entries with -inf rather than multiplying by gold / (1-gold):
  // multiplication turns excluded entries into 0, so max() used to return 0
  // whenever the relevant scores were all negative, producing a wrong loss
  // and a zero gradient for the gold class.
  const double negInf = -std::numeric_limits<double>::infinity();
  auto bestGold = std::get<0>(scores.masked_fill(goldMask.logical_not(), negInf).max(1));
  auto bestOther = std::get<0>(scores.masked_fill(goldMask, negInf).max(1));

  // A row without any gold class would otherwise yield +inf loss; fall back to
  // a gold score of 0 for such rows (what the multiplicative form produced).
  bestGold = torch::where(goldMask.any(1), bestGold, torch::zeros_like(bestGold));

  // A row whose classes are all gold has bestOther = -inf, so relu gives 0.
  return torch::relu(1 - bestGold + bestOther).mean().reshape({1});
}