From 363424441350164800e66a863bbeee8d7f52731e Mon Sep 17 00:00:00 2001
From: "maxime.petit" <maxime.petit@sms.liscluster>
Date: Tue, 1 Jun 2021 14:48:42 +0200
Subject: [PATCH] Added new reward func

---
 Rl.py | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/Rl.py b/Rl.py
index 8641f6c..f870bcc 100644
--- a/Rl.py
+++ b/Rl.py
@@ -2,6 +2,7 @@ import sys
 import random
 import torch
 import torch.nn.functional as F
+import numpy as np
 from Util import getDevice
 
 ################################################################################
@@ -152,3 +153,32 @@ def rewardE(appliable, config, action, missingLinks):
   return reward
 ################################################################################
 
+################################################################################
+def rewardF(appliable, config, action, missingLinks):
+  """Return 10x the raw reward: -3.0 if not appliable, minus the oracle score for a regular action, (offset - size) for a BACK."""
+  if not appliable:
+    return 10*-3.0
+  if "BACK" not in action.name:
+    return 10*(-1.0*action.getOracleScore(config, missingLinks))
+  nbCanceled = action.size
+  # Offsets from the end of historyPop (1 = last entry) whose index-3 field is negative. NOTE(review): offset == size is never inspected - confirm intended.
+  negativeOffsets = [offset for offset in range(1, nbCanceled) if config.historyPop[-offset][3] < 0]
+  # Keep the largest such offset (presumably the oldest cancelled error), or 0 when there is none.
+  deepest = negativeOffsets[-1] if negativeOffsets else 0
+  return 10*(deepest - nbCanceled)
+################################################################################
+
+################################################################################
+def rewardG(appliable, config, action, missingLinks):
+  """Return -3.0 if not appliable, minus the oracle score for a regular action, log(1 - total) or -1 for a BACK."""
+  if not appliable:
+    return -3.0
+  if "BACK" not in action.name:
+    return -action.getOracleScore(config, missingLinks)
+  # Total of the index-3 field (presumably the stored reward) over the last action.size entries of historyPop.
+  canceledTotal = sum(entry[3] for entry in config.historyPop[-action.size:])
+  if canceledTotal < 0:  # strictly negative total: log(1 - total) is then positive
+    return np.log(1-canceledTotal)
+  return -1
+################################################################################
+
-- 
GitLab