diff --git a/torch_modules/src/LossFunction.cpp b/torch_modules/src/LossFunction.cpp index 037716d6c0e562dee6e77db3e49996d21d37789a..a162031d47339332519d7d8a4750f8d50b437a37 100644 --- a/torch_modules/src/LossFunction.cpp +++ b/torch_modules/src/LossFunction.cpp @@ -33,6 +33,8 @@ torch::Tensor LossFunction::operator()(torch::Tensor prediction, torch::Tensor g return std::get<2>(fct)(prediction, gold); if (index == 3) return std::get<3>(fct)(torch::softmax(prediction, 1), gold); + if (index == 4) + return std::get<4>(fct)(prediction, gold); } catch (std::exception & e) { util::myThrow(fmt::format("computing loss '{}' caught '{}'", name, e.what())); @@ -46,7 +48,7 @@ torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std:: { auto index = fct.index(); - if (index == 0 or index == 2) + if (index == 0 or index == 2 or index == 4) { auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong)); gold[0] = goldIndexes.at(0);