From 04780fb6de9b520b93a8d39c0f4f7fccd332e027 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Fri, 15 Jan 2021 21:44:33 +0100 Subject: [PATCH] Added L1 loss and removed mean from regression losses --- torch_modules/include/LossFunction.hpp | 2 +- torch_modules/src/LossFunction.cpp | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/torch_modules/include/LossFunction.hpp b/torch_modules/include/LossFunction.hpp index b845ab3..b6dfe52 100644 --- a/torch_modules/include/LossFunction.hpp +++ b/torch_modules/include/LossFunction.hpp @@ -10,7 +10,7 @@ class LossFunction private : std::string name{"_undefined_loss_"}; - std::variant<torch::nn::CrossEntropyLoss, torch::nn::BCELoss, torch::nn::MSELoss, CustomHingeLoss> fct; + std::variant<torch::nn::CrossEntropyLoss, torch::nn::BCELoss, torch::nn::MSELoss, CustomHingeLoss, torch::nn::L1Loss> fct; public : diff --git a/torch_modules/src/LossFunction.cpp b/torch_modules/src/LossFunction.cpp index d39203b..037716d 100644 --- a/torch_modules/src/LossFunction.cpp +++ b/torch_modules/src/LossFunction.cpp @@ -10,7 +10,9 @@ void LossFunction::init(std::string name) else if (util::lower(name) == "bce") fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kMean)); else if (util::lower(name) == "mse") - fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kMean)); + fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kSum)); + else if (util::lower(name) == "l1") + fct = torch::nn::L1Loss(torch::nn::L1LossOptions().reduction(torch::kSum)); else if (util::lower(name) == "hinge") fct = CustomHingeLoss(); else -- GitLab