Skip to content
Snippets Groups Projects
Commit ed5ae141 authored by Franck Dary's avatar Franck Dary
Browse files

Fixed L1 loss

parent 04780fb6
No related branches found
No related tags found
No related merge requests found
......@@ -33,6 +33,8 @@ torch::Tensor LossFunction::operator()(torch::Tensor prediction, torch::Tensor g
return std::get<2>(fct)(prediction, gold);
if (index == 3)
return std::get<3>(fct)(torch::softmax(prediction, 1), gold);
if (index == 4)
return std::get<4>(fct)(prediction, gold);
} catch (std::exception & e)
{
util::myThrow(fmt::format("computing loss '{}' caught '{}'", name, e.what()));
......@@ -46,7 +48,7 @@ torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::
{
auto index = fct.index();
if (index == 0 or index == 2)
if (index == 0 or index == 2 or index == 4)
{
auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
gold[0] = goldIndexes.at(0);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment