From ed5ae141b5f110f56f7aafdef59ce533d166ad76 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Fri, 15 Jan 2021 22:35:58 +0100 Subject: [PATCH] Fixed L1 loss --- torch_modules/src/LossFunction.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_modules/src/LossFunction.cpp b/torch_modules/src/LossFunction.cpp index 037716d..a162031 100644 --- a/torch_modules/src/LossFunction.cpp +++ b/torch_modules/src/LossFunction.cpp @@ -33,6 +33,8 @@ torch::Tensor LossFunction::operator()(torch::Tensor prediction, torch::Tensor g return std::get<2>(fct)(prediction, gold); if (index == 3) return std::get<3>(fct)(torch::softmax(prediction, 1), gold); + if (index == 4) + return std::get<4>(fct)(prediction, gold); } catch (std::exception & e) { util::myThrow(fmt::format("computing loss '{}' caught '{}'", name, e.what())); @@ -46,7 +48,7 @@ torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std:: { auto index = fct.index(); - if (index == 0 or index == 2) + if (index == 0 or index == 2 or index == 4) { auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong)); gold[0] = goldIndexes.at(0); -- GitLab