From 01f7d410741fe7606eae9ed142dac55ed1b71759 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sun, 17 Jan 2021 19:05:18 +0100
Subject: [PATCH] All losses are reduced to sum instead of mean (to give
 consistent values regardless of batch size)

---
 torch_modules/src/LossFunction.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torch_modules/src/LossFunction.cpp b/torch_modules/src/LossFunction.cpp
index a162031..e90b901 100644
--- a/torch_modules/src/LossFunction.cpp
+++ b/torch_modules/src/LossFunction.cpp
@@ -6,9 +6,9 @@ void LossFunction::init(std::string name)
   this->name = name;
 
   if (util::lower(name) == "crossentropy")
-    fct = torch::nn::CrossEntropyLoss(torch::nn::CrossEntropyLossOptions().reduction(torch::kMean));
+    fct = torch::nn::CrossEntropyLoss(torch::nn::CrossEntropyLossOptions().reduction(torch::kSum));
   else if (util::lower(name) == "bce")
-    fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kMean));
+    fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kSum));
   else if (util::lower(name) == "mse")
     fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kSum));
   else if (util::lower(name) == "l1")
-- 
GitLab