diff --git a/Rl.py b/Rl.py
index 5f65adcc5768ca82d10bbf8cfcd2c7ca4a2b1498..8e320dfb49298ccbb5c07497089f64b74af96de4 100644
--- a/Rl.py
+++ b/Rl.py
@@ -56,6 +56,10 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra
 
 ################################################################################
 def optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma) :
+    #lossFct = torch.nn.MSELoss()
+    #lossFct = torch.nn.L1Loss()
+    lossFct = torch.nn.SmoothL1Loss()
+
     totalLoss = 0.0
     for fromState in range(len(memory)) :
         for toState in range(len(memory[fromState])) :
@@ -74,7 +78,7 @@ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma) :
 
             expectedReward = gamma*nextQ + rewards
 
-            loss = F.smooth_l1_loss(predictedQ, expectedReward)
+            loss = lossFct(predictedQ, expectedReward)
             optimizer.zero_grad()
             loss.backward()
             for param in policy_net.parameters() :
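
Note on the change above: swapping the functional call F.smooth_l1_loss for a torch.nn.SmoothL1Loss module is behavior-preserving, since both use beta=1.0 and reduction='mean' by default; the module form simply makes the loss easy to swap out (see the commented-out MSELoss and L1Loss alternatives). A minimal standalone sketch confirming the equivalence follows; the tensor values are made up for illustration and only stand in for the Q-values computed in optimizeModel:

import torch
import torch.nn.functional as F

# Module form introduced by the diff.
lossFct = torch.nn.SmoothL1Loss()

predictedQ = torch.tensor([1.0, 2.0, 3.0])      # stand-in for the policy net's Q-values
expectedReward = torch.tensor([1.5, 1.0, 3.2])  # stand-in for gamma*nextQ + rewards

# Both forms default to beta=1.0 and reduction='mean', so they agree.
assert torch.isclose(lossFct(predictedQ, expectedReward),
                     F.smooth_l1_loss(predictedQ, expectedReward))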