From f76a0bb73f4c9a303f3afed1e610f2cfdd9b41e4 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 7 Jul 2021 14:11:16 +0200
Subject: [PATCH] Loss

---
 Rl.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/Rl.py b/Rl.py
index 5f65adc..8e320df 100644
--- a/Rl.py
+++ b/Rl.py
@@ -56,6 +56,10 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra
 
 ################################################################################
 def optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma) :
+  #lossFct = torch.nn.MSELoss()
+  #lossFct = torch.nn.L1Loss()
+  lossFct = torch.nn.SmoothL1Loss()
+
   totalLoss = 0.0
   for fromState in range(len(memory)) :
     for toState in range(len(memory[fromState])) :
@@ -74,7 +78,7 @@ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma) :
     
       expectedReward = gamma*nextQ + rewards
     
-      loss = F.smooth_l1_loss(predictedQ, expectedReward)
+      loss = lossFct(predictedQ, expectedReward)
       optimizer.zero_grad()
       loss.backward()
       for param in policy_net.parameters() :
-- 
GitLab