#include "LossFunction.hpp"
#include "util.hpp"

/// Select the loss function from its (case-insensitive) name.
///
/// Accepted names:
///  - "crossentropy" -> torch::nn::CrossEntropyLoss, mean reduction
///  - "bce"          -> torch::nn::BCELoss, mean reduction
///  - "mse"          -> torch::nn::MSELoss, sum reduction
///  - "l1"           -> torch::nn::L1Loss, sum reduction
///  - "hinge"        -> CustomHingeLoss
/// The name is stored with its original spelling so error messages show what
/// the caller passed.
/// @param name  loss name, compared after util::lower.
/// @throws (through util::myThrow) when the name is not one of the above.
void LossFunction::init(std::string name)
{
  this->name = name;

  // Lower-case once rather than once per comparison.
  auto lowerName = util::lower(name);

  if (lowerName == "crossentropy")
    fct = torch::nn::CrossEntropyLoss(torch::nn::CrossEntropyLossOptions().reduction(torch::kMean));
  else if (lowerName == "bce")
    fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kMean));
  else if (lowerName == "mse")
    fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kSum));
  else if (lowerName == "l1")
    fct = torch::nn::L1Loss(torch::nn::L1LossOptions().reduction(torch::kSum));
  else if (lowerName == "hinge")
    fct = CustomHingeLoss();
  else
    util::myThrow(fmt::format("unknown loss function name '{}' available losses are 'crossentropy, bce, mse, l1, hinge'", name));
}

/// Compute the loss of `prediction` against `gold` with the selected loss.
///
/// Dispatch is done on the held alternative's *type* (not its index), so it
/// does not depend on the declaration order of the variant in the header.
/// Per-loss preprocessing:
///  - cross-entropy: gold is reshaped to 1-D (a 0-dim gold becomes size 1);
///  - BCE: softmax over dim 1 of the prediction, gold converted to float;
///  - MSE / L1: prediction and gold are passed through unchanged;
///  - hinge: softmax over dim 1 of the prediction, gold unchanged.
/// NOTE(review): assumes prediction is at least 2-D for the softmax(…, 1)
/// paths — confirm against callers.
/// @throws (through util::myThrow) wrapping any std::exception raised by
///         libtorch, or when no loss has been selected.
torch::Tensor LossFunction::operator()(torch::Tensor prediction, torch::Tensor gold)
{
  try
  {
    if (std::holds_alternative<torch::nn::CrossEntropyLoss>(fct))
      return std::get<torch::nn::CrossEntropyLoss>(fct)(prediction, gold.reshape(gold.dim() == 0 ? 1 : gold.size(0)));
    if (std::holds_alternative<torch::nn::BCELoss>(fct))
      return std::get<torch::nn::BCELoss>(fct)(torch::softmax(prediction, 1), gold.to(torch::kFloat));
    if (std::holds_alternative<torch::nn::MSELoss>(fct))
      return std::get<torch::nn::MSELoss>(fct)(prediction, gold);
    // L1 is an element-wise regression loss like MSE: same inputs.
    if (std::holds_alternative<torch::nn::L1Loss>(fct))
      return std::get<torch::nn::L1Loss>(fct)(prediction, gold);
    if (std::holds_alternative<CustomHingeLoss>(fct))
      return std::get<CustomHingeLoss>(fct)(torch::softmax(prediction, 1), gold);
  }
  catch (const std::exception & e)
  {
    util::myThrow(fmt::format("computing loss '{}' caught '{}'", name, e.what()));
  }

  util::myThrow("loss is not defined");
  return torch::Tensor();
}

/// Build the gold tensor expected by the selected loss from class indexes.
///
///  - cross-entropy, MSE, L1: a 1-element kLong tensor holding goldIndexes[0]
///    (throws std::out_of_range if goldIndexes is empty);
///  - BCE, hinge: a kLong multi-hot vector of size nbClasses with a 1 at every
///    index listed in goldIndexes.
/// @param nbClasses    number of output classes (size of the multi-hot vector).
/// @param goldIndexes  indexes of the gold classes.
/// @throws (through util::myThrow) when no loss has been selected.
torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::vector<long> & goldIndexes) const
{
  const bool singleIndexGold = std::holds_alternative<torch::nn::CrossEntropyLoss>(fct)
                            or std::holds_alternative<torch::nn::MSELoss>(fct)
                            or std::holds_alternative<torch::nn::L1Loss>(fct);

  if (singleIndexGold)
  {
    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
    gold[0] = goldIndexes.at(0);
    return gold;
  }

  const bool multiHotGold = std::holds_alternative<torch::nn::BCELoss>(fct)
                         or std::holds_alternative<CustomHingeLoss>(fct);

  if (multiHotGold)
  {
    auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong));
    for (auto goldIndex : goldIndexes)
      gold[goldIndex] = 1;
    return gold;
  }

  util::myThrow("loss is not defined");
  return torch::Tensor();
}