-
Franck Dary authoredFranck Dary authored
LossFunction.cpp 2.05 KiB
#include "LossFunction.hpp"
#include <variant>
#include "util.hpp"
/// Selects the loss function from its name.
///
/// The name is matched case-insensitively (through util::lower) against
/// "crossentropy", "bce", "mse", "l1" and "hinge". The name is stored as given
/// (not lower-cased) in `this->name`, which operator() reuses in error messages.
///
/// @param name Name of the loss function to use.
/// An unknown name is reported through util::myThrow.
void LossFunction::init(std::string name)
{
  this->name = name;

  // Lower-case once rather than once per comparison.
  const auto lowered = util::lower(name);

  // Note: crossentropy and bce use mean reduction whereas mse and l1 use sum
  // reduction; kept as-is.
  if (lowered == "crossentropy")
    fct = torch::nn::CrossEntropyLoss(torch::nn::CrossEntropyLossOptions().reduction(torch::kMean));
  else if (lowered == "bce")
    fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kMean));
  else if (lowered == "mse")
    fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kSum));
  else if (lowered == "l1")
    fct = torch::nn::L1Loss(torch::nn::L1LossOptions().reduction(torch::kSum));
  else if (lowered == "hinge")
    fct = CustomHingeLoss();
  else
    util::myThrow(fmt::format("unknown loss function name '{}' available losses are 'crossentropy, bce, mse, l1, hinge'", name));
}
/// Computes the loss of `prediction` against `gold` using the loss chosen by init().
///
/// How the inputs are adapted depends on the selected loss:
///  - crossentropy: `gold` is reshaped to 1-D (a 0-dim gold becomes a 1-element tensor);
///  - bce: softmax over dimension 1 of `prediction`, `gold` converted to float;
///  - mse, l1: both tensors are passed unchanged;
///  - hinge: softmax over dimension 1 of `prediction`, `gold` unchanged.
///
/// Dispatch is done on the type of the stored loss instead of its variant
/// index, so every loss accepted by init() is handled regardless of the
/// order of alternatives declared in the header.
///
/// Any std::exception raised while computing the loss is re-reported through
/// util::myThrow together with the loss name. util::myThrow is also called if
/// `fct` holds none of the supported losses.
torch::Tensor LossFunction::operator()(torch::Tensor prediction, torch::Tensor gold)
{
  try
  {
    if (auto * crossEntropy = std::get_if<torch::nn::CrossEntropyLoss>(&fct))
      return (*crossEntropy)(prediction, gold.reshape(gold.dim() == 0 ? 1 : gold.size(0)));
    if (auto * bce = std::get_if<torch::nn::BCELoss>(&fct))
      return (*bce)(torch::softmax(prediction, 1), gold.to(torch::kFloat));
    if (auto * mse = std::get_if<torch::nn::MSELoss>(&fct))
      return (*mse)(prediction, gold);
    if (auto * hinge = std::get_if<CustomHingeLoss>(&fct))
      return (*hinge)(torch::softmax(prediction, 1), gold);
    // l1 is given the same inputs as mse.
    // NOTE(review): confirm this is the intended input format for l1.
    if (auto * l1 = std::get_if<torch::nn::L1Loss>(&fct))
      return (*l1)(prediction, gold);
  } catch (std::exception & e)
  {
    util::myThrow(fmt::format("computing loss '{}' caught '{}'", name, e.what()));
  }

  util::myThrow("loss is not defined");
  return torch::Tensor();
}
/// Builds the gold tensor in the format operator() passes to the selected loss.
///
///  - crossentropy, mse, l1: a 1-element kLong tensor holding goldIndexes.at(0)
///    (only the first index is used);
///  - bce, hinge: a kLong tensor of size nbClasses that is 1 at every index of
///    goldIndexes and 0 elsewhere.
///
/// Dispatch is done on the type of the stored loss, mirroring operator(), so
/// every loss accepted by init() is handled.
///
/// @param nbClasses Number of classes, i.e. the length of the multi-hot vector.
/// @param goldIndexes Gold class indexes. For single-index losses it must not be
///        empty (`at(0)` throws std::out_of_range otherwise).
/// @return The gold tensor. util::myThrow is called if `fct` holds none of the
///         supported losses.
torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::vector<long> & goldIndexes) const
{
  // Losses whose target is a single class index.
  // NOTE(review): l1 is grouped with mse here; confirm this is the intended gold format for l1.
  const bool singleIndex = std::holds_alternative<torch::nn::CrossEntropyLoss>(fct)
    or std::holds_alternative<torch::nn::MSELoss>(fct)
    or std::holds_alternative<torch::nn::L1Loss>(fct);
  // Losses whose target is a multi-hot vector over all classes.
  const bool multiHot = std::holds_alternative<torch::nn::BCELoss>(fct)
    or std::holds_alternative<CustomHingeLoss>(fct);

  if (singleIndex)
  {
    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
    gold[0] = goldIndexes.at(0);
    return gold;
  }
  if (multiHot)
  {
    auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong));
    for (auto goldIndex : goldIndexes)
      gold[goldIndex] = 1;
    return gold;
  }

  util::myThrow("loss is not defined");
  return torch::Tensor();
}