diff --git a/torch_modules/src/LossFunction.cpp b/torch_modules/src/LossFunction.cpp
index 037716d6c0e562dee6e77db3e49996d21d37789a..a162031d47339332519d7d8a4750f8d50b437a37 100644
--- a/torch_modules/src/LossFunction.cpp
+++ b/torch_modules/src/LossFunction.cpp
@@ -33,6 +33,8 @@ torch::Tensor LossFunction::operator()(torch::Tensor prediction, torch::Tensor g
       return std::get<2>(fct)(prediction, gold);
     if (index == 3)
       return std::get<3>(fct)(torch::softmax(prediction, 1), gold);
+    if (index == 4)
+      return std::get<4>(fct)(prediction, gold);
   } catch (std::exception & e)
   {
     util::myThrow(fmt::format("computing loss '{}' caught '{}'", name, e.what()));
@@ -46,7 +48,7 @@ torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::
 {
   auto index = fct.index();
 
-  if (index == 0 or index == 2)
+  if (index == 0 or index == 2 or index == 4)
   {
     auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
     gold[0] = goldIndexes.at(0);