From 666d485595dbd7563aeec47b48ce18dc6172b937 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 25 Nov 2021 14:35:52 +0100
Subject: [PATCH] Added learnable coefs for the loss

---
 torch_modules/include/NeuralNetwork.hpp | 2 ++
 torch_modules/src/ModularNetwork.cpp    | 3 +++
 torch_modules/src/NeuralNetwork.cpp     | 8 ++++++++
 trainer/src/Trainer.cpp                 | 4 +++-
 4 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index f34f966..c583491 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -11,9 +11,11 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
   private :
 
   static torch::Device device;
+  std::map<std::string, torch::Tensor> lossParameters;
 
   public :
 
+  torch::Tensor getLossParameter(std::string state);
   virtual torch::Tensor forward(torch::Tensor input, const std::string & state) = 0;
   virtual torch::Tensor extractContext(Config & config) = 0;
   virtual void registerEmbeddings(bool loadPretrained) = 0;
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index 84c0e13..a3c0723 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -81,7 +81,10 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
   mlp = register_module("mlp", MLP(currentOutputSize, mlpDef));
 
   for (auto & it : nbOutputsPerState)
+  {
     outputLayersPerState.emplace(it.first,register_module(fmt::format("output_{}",it.first), torch::nn::Linear(mlp->outputSize(), it.second)));
+    getLossParameter(it.first);
+  }
 }
 
 torch::Tensor ModularNetworkImpl::forward(torch::Tensor input, const std::string & state)
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index fe3727d..55a00d5 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -2,6 +2,14 @@
 
 torch::Device NeuralNetworkImpl::device(getPreferredDevice());
 
+torch::Tensor NeuralNetworkImpl::getLossParameter(std::string state)
+{
+  auto found = lossParameters.find(state); // Loss coefficient of `state`, if one is already stored (single lookup).
+  if (found == lossParameters.end()) // First request: register a 1-element tensor of ones that requires grad; the map is only touched once registration succeeded.
+    found = lossParameters.emplace(state, register_parameter(fmt::format("lossParam_{}", state), torch::ones(1, torch::TensorOptions().device(NeuralNetworkImpl::getDevice()).requires_grad(true)))).first;
+  return found->second; // Tensor stored for this state, registered on the module as "lossParam_<state>".
+}
+
 float NeuralNetworkImpl::entropy(torch::Tensor probabilities)
 {
   if (probabilities.dim() != 1)
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index da9fb73..8051cca 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -247,7 +247,9 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
       labels /= util::float2longScale;
     }
 
-    auto loss = machine.getClassifier(state)->getLossMultiplier(state)*machine.getClassifier(state)->getLossFunction()(prediction, labels);
+    auto lossParameter = machine.getClassifier(state)->getNN()->getLossParameter(state);
+
+    auto loss = machine.getClassifier(state)->getLossMultiplier(state)*machine.getClassifier(state)->getLossFunction()(prediction, labels)*(1.0/torch::exp(lossParameter)) + lossParameter;
     float lossAsFloat = 0.0;
     try
     {
-- 
GitLab