diff --git a/torch_modules/include/LossFunction.hpp b/torch_modules/include/LossFunction.hpp index b845ab3104fc7af65ec05072956f52566be4b8c3..b6dfe522b108e5b46ee467a8c4036d7e601c89d4 100644 --- a/torch_modules/include/LossFunction.hpp +++ b/torch_modules/include/LossFunction.hpp @@ -10,7 +10,7 @@ class LossFunction private : std::string name{"_undefined_loss_"}; - std::variant<torch::nn::CrossEntropyLoss, torch::nn::BCELoss, torch::nn::MSELoss, CustomHingeLoss> fct; + std::variant<torch::nn::CrossEntropyLoss, torch::nn::BCELoss, torch::nn::MSELoss, CustomHingeLoss, torch::nn::L1Loss> fct; public : diff --git a/torch_modules/src/LossFunction.cpp b/torch_modules/src/LossFunction.cpp index d39203b11978d6e26c01ae5ee1c1c524fb9dad67..037716d6c0e562dee6e77db3e49996d21d37789a 100644 --- a/torch_modules/src/LossFunction.cpp +++ b/torch_modules/src/LossFunction.cpp @@ -10,7 +10,9 @@ void LossFunction::init(std::string name) else if (util::lower(name) == "bce") fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kMean)); else if (util::lower(name) == "mse") - fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kMean)); + fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kSum)); + else if (util::lower(name) == "l1") + fct = torch::nn::L1Loss(torch::nn::L1LossOptions().reduction(torch::kSum)); else if (util::lower(name) == "hinge") fct = CustomHingeLoss(); else