diff --git a/torch_modules/src/LossFunction.cpp b/torch_modules/src/LossFunction.cpp index a162031d47339332519d7d8a4750f8d50b437a37..e90b9014ffedd00c7823d829f55ce506d6bc28c6 100644 --- a/torch_modules/src/LossFunction.cpp +++ b/torch_modules/src/LossFunction.cpp @@ -6,9 +6,9 @@ void LossFunction::init(std::string name) this->name = name; if (util::lower(name) == "crossentropy") - fct = torch::nn::CrossEntropyLoss(torch::nn::CrossEntropyLossOptions().reduction(torch::kMean)); + fct = torch::nn::CrossEntropyLoss(torch::nn::CrossEntropyLossOptions().reduction(torch::kSum)); else if (util::lower(name) == "bce") - fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kMean)); + fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kSum)); else if (util::lower(name) == "mse") fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kSum)); else if (util::lower(name) == "l1")