Skip to content
Snippets Groups Projects
Commit 01f7d410 authored by Franck Dary's avatar Franck Dary
Browse files

All losses are reduced to sum instead of mean (to give consistent values regardless of batch size)

parent ed5ae141
No related branches found
No related tags found
No related merge requests found
......@@ -6,9 +6,9 @@ void LossFunction::init(std::string name)
this->name = name;
if (util::lower(name) == "crossentropy")
fct = torch::nn::CrossEntropyLoss(torch::nn::CrossEntropyLossOptions().reduction(torch::kMean));
fct = torch::nn::CrossEntropyLoss(torch::nn::CrossEntropyLossOptions().reduction(torch::kSum));
else if (util::lower(name) == "bce")
fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kMean));
fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kSum));
else if (util::lower(name) == "mse")
fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kSum));
else if (util::lower(name) == "l1")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment