diff --git a/Rl.py b/Rl.py index 8641f6c75297ae8293fb8ed5fe9b936d95cf5aab..f870bcc29624ca364c1a338ac2dc78badd907162 100644 --- a/Rl.py +++ b/Rl.py @@ -2,6 +2,7 @@ import sys import random import torch import torch.nn.functional as F +import numpy as np from Util import getDevice ################################################################################ @@ -152,3 +153,32 @@ def rewardE(appliable, config, action, missingLinks): return reward ################################################################################ +################################################################################ +def rewardF(appliable, config, action, missingLinks): + if appliable: + if "BACK" not in action.name : + reward = -1.0*action.getOracleScore(config, missingLinks) + else : + back = action.size + error_in_pop = [i for i in range(1,back) if config.historyPop[-i][3] < 0] + last_error = error_in_pop[-1] if len(error_in_pop) > 0 else 0 + reward = last_error - back + else: + reward = -3.0 + return 10*reward +################################################################################ + +################################################################################ +def rewardG(appliable, config, action, missingLinks): + if appliable: + if "BACK" not in action.name : + reward = -action.getOracleScore(config, missingLinks) + else : + back = action.size + canceledRewards = [h[3] for h in config.historyPop[-back:]] + reward = np.log(1-sum(canceledRewards)) if -sum(canceledRewards) > 0 else -1 + else: + reward = -3.0 + return reward +################################################################################ +