diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index f34f966ef669e94f8d1fb32ab6d73b5e54346d51..c583491f50cb80588cccac6a088e975a503c005f 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -11,9 +11,13 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
   private :
 
   static torch::Device device;
+  // One learnable loss-scaling parameter per state, created on demand.
+  std::map<std::string, torch::Tensor> lossParameters;
 
   public :
 
+  // Returns the loss parameter for `state`, creating and registering it on first use.
+  torch::Tensor getLossParameter(std::string state);
   virtual torch::Tensor forward(torch::Tensor input, const std::string & state) = 0;
   virtual torch::Tensor extractContext(Config & config) = 0;
   virtual void registerEmbeddings(bool loadPretrained) = 0;
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index 84c0e13998ab30a9524f28e4808492d8954a5b1e..a3c072390995ed828b3a7e0c06199769e97ec412 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -81,7 +81,12 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
   mlp = register_module("mlp", MLP(currentOutputSize, mlpDef));
 
   for (auto & it : nbOutputsPerState)
+  {
     outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second)));
+    // Eagerly create the per-state loss parameter so that it is already
+    // registered when the optimizer collects this module's parameters.
+    getLossParameter(it.first);
+  }
 }
 
 torch::Tensor ModularNetworkImpl::forward(torch::Tensor input, const std::string & state)
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index fe3727dc18827225b17bbf6db9133b7dd67ce54e..55a00d5538e35b1e1ee95010b98e2a8985208c3a 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -2,6 +2,18 @@
 
 torch::Device NeuralNetworkImpl::device(getPreferredDevice());
 
+// Lazily create, register and return the learnable per-state loss
+// parameter. Registering via register_parameter exposes it to the
+// optimizer so it is trained alongside the network weights.
+torch::Tensor NeuralNetworkImpl::getLossParameter(std::string state)
+{
+  auto it = lossParameters.find(state);
+  if (it == lossParameters.end())
+    it = lossParameters.emplace(state, register_parameter(fmt::format("lossParam_{}", state), torch::ones(1, torch::TensorOptions().device(NeuralNetworkImpl::getDevice()).requires_grad(true)))).first;
+
+  return it->second;
+}
+
 float NeuralNetworkImpl::entropy(torch::Tensor probabilities)
 {
   if (probabilities.dim() != 1)
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index da9fb730ca33d601758439495a8e1bc233f15109..8051cca2a01b2c877e4be944e1988e5f3ac573eb 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -247,7 +247,11 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
       labels /= util::float2longScale;
     }
 
-    auto loss = machine.getClassifier(state)->getLossMultiplier(state)*machine.getClassifier(state)->getLossFunction()(prediction, labels);
+    auto lossParameter = machine.getClassifier(state)->getNN()->getLossParameter(state);
+
+    // Uncertainty-style weighting: scale the task loss by exp(-s) and add s as a
+    // regularizer; exp(-s) avoids the overflow-prone intermediate of 1/exp(s).
+    auto loss = machine.getClassifier(state)->getLossMultiplier(state)*machine.getClassifier(state)->getLossFunction()(prediction, labels)*torch::exp(-lossParameter) + lossParameter;
     float lossAsFloat = 0.0;
     try
     {