From 01f7d410741fe7606eae9ed142dac55ed1b71759 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sun, 17 Jan 2021 19:05:18 +0100 Subject: [PATCH] All losses are reduced to sum instead of mean (to give consistent values regardless of batch size) --- torch_modules/src/LossFunction.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_modules/src/LossFunction.cpp b/torch_modules/src/LossFunction.cpp index a162031..e90b901 100644 --- a/torch_modules/src/LossFunction.cpp +++ b/torch_modules/src/LossFunction.cpp @@ -6,9 +6,9 @@ void LossFunction::init(std::string name) this->name = name; if (util::lower(name) == "crossentropy") - fct = torch::nn::CrossEntropyLoss(torch::nn::CrossEntropyLossOptions().reduction(torch::kMean)); + fct = torch::nn::CrossEntropyLoss(torch::nn::CrossEntropyLossOptions().reduction(torch::kSum)); else if (util::lower(name) == "bce") - fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kMean)); + fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kSum)); else if (util::lower(name) == "mse") fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kSum)); else if (util::lower(name) == "l1") -- GitLab