From f76a0bb73f4c9a303f3afed1e610f2cfdd9b41e4 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 7 Jul 2021 14:11:16 +0200 Subject: [PATCH] Loss --- Rl.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/Rl.py b/Rl.py index 5f65adc..8e320df 100644 --- a/Rl.py +++ b/Rl.py @@ -56,6 +56,10 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra ################################################################################ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma) : + #lossFct = torch.nn.MSELoss() + #lossFct = torch.nn.L1Loss() + lossFct = torch.nn.SmoothL1Loss() + totalLoss = 0.0 for fromState in range(len(memory)) : for toState in range(len(memory[fromState])) : @@ -74,7 +78,7 @@ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma) : expectedReward = gamma*nextQ + rewards - loss = F.smooth_l1_loss(predictedQ, expectedReward) + loss = lossFct(predictedQ, expectedReward) optimizer.zero_grad() loss.backward() for param in policy_net.parameters() : -- GitLab