From ed5ae141b5f110f56f7aafdef59ce533d166ad76 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 15 Jan 2021 22:35:58 +0100
Subject: [PATCH] Fixed L1 loss

---
 torch_modules/src/LossFunction.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torch_modules/src/LossFunction.cpp b/torch_modules/src/LossFunction.cpp
index 037716d..a162031 100644
--- a/torch_modules/src/LossFunction.cpp
+++ b/torch_modules/src/LossFunction.cpp
@@ -33,6 +33,8 @@ torch::Tensor LossFunction::operator()(torch::Tensor prediction, torch::Tensor g
       return std::get<2>(fct)(prediction, gold);
     if (index == 3)
       return std::get<3>(fct)(torch::softmax(prediction, 1), gold);
+    if (index == 4)
+      return std::get<4>(fct)(prediction, gold);
   } catch (std::exception & e)
   {
     util::myThrow(fmt::format("computing loss '{}' caught '{}'", name, e.what()));
@@ -46,7 +48,7 @@ torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::
 {
   auto index = fct.index();
 
-  if (index == 0 or index == 2)
+  if (index == 0 or index == 2 or index == 4)
   {
     auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
     gold[0] = goldIndexes.at(0);
-- 
GitLab