From 04780fb6de9b520b93a8d39c0f4f7fccd332e027 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 15 Jan 2021 21:44:33 +0100
Subject: [PATCH] Added L1 loss and removed mean from regression losses

---
 torch_modules/include/LossFunction.hpp | 2 +-
 torch_modules/src/LossFunction.cpp     | 4 +++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/torch_modules/include/LossFunction.hpp b/torch_modules/include/LossFunction.hpp
index b845ab3..b6dfe52 100644
--- a/torch_modules/include/LossFunction.hpp
+++ b/torch_modules/include/LossFunction.hpp
@@ -10,7 +10,7 @@ class LossFunction
   private :
 
   std::string name{"_undefined_loss_"};
-  std::variant<torch::nn::CrossEntropyLoss, torch::nn::BCELoss, torch::nn::MSELoss, CustomHingeLoss> fct;
+  std::variant<torch::nn::CrossEntropyLoss, torch::nn::BCELoss, torch::nn::MSELoss, CustomHingeLoss, torch::nn::L1Loss> fct;
 
   public :
 
diff --git a/torch_modules/src/LossFunction.cpp b/torch_modules/src/LossFunction.cpp
index d39203b..037716d 100644
--- a/torch_modules/src/LossFunction.cpp
+++ b/torch_modules/src/LossFunction.cpp
@@ -10,7 +10,9 @@ void LossFunction::init(std::string name)
   else if (util::lower(name) == "bce")
     fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kMean));
   else if (util::lower(name) == "mse")
-    fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kMean));
+    fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kSum));
+  else if (util::lower(name) == "l1")
+    fct = torch::nn::L1Loss(torch::nn::L1LossOptions().reduction(torch::kSum));
   else if (util::lower(name) == "hinge")
     fct = CustomHingeLoss();
   else
-- 
GitLab