Skip to content
Snippets Groups Projects
LossFunction.cpp 2.05 KiB
#include "LossFunction.hpp"
#include "util.hpp"

/// Select the loss stored in `fct` from its name, matched case-insensitively
/// (via util::lower).
///
/// Recognised names and the module each one stores:
///  - "crossentropy" -> torch::nn::CrossEntropyLoss, mean reduction
///  - "bce"          -> torch::nn::BCELoss, mean reduction
///  - "mse"          -> torch::nn::MSELoss, sum reduction
///  - "l1"           -> torch::nn::L1Loss, sum reduction
///  - "hinge"        -> CustomHingeLoss
///
/// The name is kept as given (not lower-cased) in `this->name`; operator()
/// uses it in its error messages.
///
/// @param name loss function name.
/// An unrecognised name is reported through util::myThrow (presumably throws).
void LossFunction::init(std::string name)
{
  this->name = name;

  // Lower-case once rather than once per comparison.
  auto const lowerName = util::lower(name);

  if (lowerName == "crossentropy")
    fct = torch::nn::CrossEntropyLoss(torch::nn::CrossEntropyLossOptions().reduction(torch::kMean));
  else if (lowerName == "bce")
    fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kMean));
  else if (lowerName == "mse")
    fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kSum));
  else if (lowerName == "l1")
    fct = torch::nn::L1Loss(torch::nn::L1LossOptions().reduction(torch::kSum));
  else if (lowerName == "hinge")
    fct = CustomHingeLoss();
  else
    // The message now lists every supported name, including 'l1'.
    util::myThrow(fmt::format("unknown loss function name '{}' available losses are 'crossentropy, bce, mse, l1, hinge'", name));
}

/// Evaluate the configured loss on `prediction` against `gold`.
///
/// Each handled variant alternative receives its inputs differently:
///  - 0: raw `prediction`; `gold` reshaped to a 1-D tensor of length
///       gold.size(0) (a 0-d gold becomes shape {1}).
///  - 1: softmax(prediction, 1); `gold` converted to float.
///  - 2: raw `prediction` and `gold`, unchanged.
///  - 3: softmax(prediction, 1); `gold` unchanged.
/// A std::exception thrown while computing the loss is re-reported through
/// util::myThrow, prefixed with the loss name.
///
/// NOTE(review): init() can store five distinct alternatives (crossentropy,
/// bce, mse, l1, hinge), but only indices 0-3 are dispatched here. At least
/// one of them therefore ends in "loss is not defined". Check the variant
/// order in LossFunction.hpp to see which one.
torch::Tensor LossFunction::operator()(torch::Tensor prediction, torch::Tensor gold)
{
  try
  {
    switch (fct.index())
    {
      case 0:
      {
        auto const flatLength = gold.dim() == 0 ? 1 : gold.size(0);
        return std::get<0>(fct)(prediction, gold.reshape(flatLength));
      }
      case 1:
      {
        auto floatGold = gold.to(torch::kFloat);
        return std::get<1>(fct)(torch::softmax(prediction, 1), floatGold);
      }
      case 2:
        return std::get<2>(fct)(prediction, gold);
      case 3:
      {
        auto probabilities = torch::softmax(prediction, 1);
        return std::get<3>(fct)(probabilities, gold);
      }
      default:
        break;
    }
  }
  catch (const std::exception & error)
  {
    util::myThrow(fmt::format("computing loss '{}' caught '{}'", name, error.what()));
  }

  // Reached for an unhandled variant index, or if util::myThrow returns.
  util::myThrow("loss is not defined");
  return torch::Tensor();
}

/// Build the gold tensor expected by the configured loss from class indexes.
///
///  - Variant alternatives 0 and 2: a 1-element kLong tensor holding
///    goldIndexes.at(0). Any further indexes are ignored; an empty vector makes
///    .at(0) throw std::out_of_range.
///  - Variant alternatives 1 and 3: a kLong tensor of length `nbClasses`, with
///    1 at every listed index and 0 elsewhere.
///
/// NOTE(review): as in operator(), variant indices above 3 are not handled and
/// fall through to "loss is not defined". Confirm this against
/// LossFunction.hpp.
///
/// @param nbClasses number of output classes (only used for alternatives 1 and 3).
/// @param goldIndexes the gold class indexes.
torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::vector<long> & goldIndexes) const
{
  switch (fct.index())
  {
    case 0:
    case 2:
    {
      auto single = torch::zeros(1, torch::TensorOptions(torch::kLong));
      single[0] = goldIndexes.at(0);
      return single;
    }
    case 1:
    case 3:
    {
      auto multiHot = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong));
      for (long classIndex : goldIndexes)
        multiHot[classIndex] = 1;
      return multiHot;
    }
    default:
      break;
  }

  util::myThrow("loss is not defined");
  return torch::Tensor();
}