import random

import torch
import torch.nn.functional as F

from Util import getDevice

################################################################################
class ReplayMemory() :
    def __init__(self, capacity, stateSize) :
        self.capacity = capacity
        self.states = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
        self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
        self.actions = torch.zeros(capacity, 1, dtype=torch.long, device=getDevice())
        self.rewards = torch.zeros(capacity, 1, device=getDevice())
        self.noNewStates = torch.zeros(capacity, dtype=torch.bool, device=getDevice())
        self.position = 0
        self.nbPushed = 0

    def push(self, state, action, newState, reward) :
        # Circular buffer : once full, the oldest transition is overwritten.
        self.states[self.position] = state
        self.actions[self.position] = action
        if newState is not None :
            self.newStates[self.position] = newState
        self.noNewStates[self.position] = newState is None
        self.rewards[self.position] = reward
        self.position = (self.position + 1) % self.capacity
        self.nbPushed += 1

    def sample(self, batchSize) :
        # Returns a contiguous window at a random offset, not batchSize
        # independently drawn transitions.
        start = random.randint(0, len(self)-batchSize)
        end = start+batchSize
        return self.states[start:end], self.actions[start:end], self.newStates[start:end], self.noNewStates[start:end], self.rewards[start:end]

    def __len__(self):
        return min(self.nbPushed, self.capacity)
################################################################################

################################################################################
def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle) :
    sample = random.random()
    if sample < probaRandom :
        # Exploration : a transition drawn uniformly at random.
        return ts[random.randrange(len(ts))]
    elif sample < probaRandom+probaOracle :
        # Oracle guidance : the appliable transition with the lowest oracle
        # score. A key function is used so that ties never fall back to
        # comparing transition objects directly.
        candidates = sorted([trans for trans in ts if trans.appliable(config)], key=lambda trans : trans.getOracleScore(config, missingLinks))
        return candidates[0] if len(candidates) > 0 else None
    else :
        # Exploitation : the transition with the highest predicted Q-value.
        with torch.no_grad() :
            output = network(torch.stack([state]))
            predIndex = int(torch.argmax(output))
            return ts[predIndex]
################################################################################

################################################################################
def optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma) :
    if len(memory) < batchSize :
        return 0.0

    states, actions, nextStates, noNextStates, rewards = memory.sample(batchSize)

    # Q(s, a) of the actions actually taken, from the online network.
    predictedQ = policy_net(states).gather(1, actions)
    # max_a' Q(s', a') from the target network, detached from the graph.
    nextQ = target_net(nextStates).max(1)[0].detach().unsqueeze(1)
    # Terminal transitions have no successor state : their bootstrap term is 0.
    nextQ[noNextStates] = 0.0
    expectedReward = gamma*nextQ + rewards

    loss = F.smooth_l1_loss(predictedQ, expectedReward)
    optimizer.zero_grad()
    loss.backward()
    # Clip gradients elementwise to [-1, 1] to stabilize training.
    for param in policy_net.parameters() :
        if param.grad is not None :
            param.grad.data.clamp_(-1, 1)
    optimizer.step()

    return float(loss)
################################################################################

################################################################################
def rewarding(appliable, config, action, missingLinks, funcname):
    # Dispatches to one of the reward functions below (rewardA, rewardB, ...)
    # selected by the funcname suffix.
    return globals()["reward"+funcname](appliable, config, action, missingLinks)
################################################################################
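################################################################################
# A minimal sketch of the dispatch above, using stand-in objects; DummyAction
# and DummyConfig are illustrative assumptions, not part of the pipeline :
#
#     class DummyAction :
#         name = "SHIFT"
#         def getOracleScore(self, config, missingLinks) :
#             return 0
#     class DummyConfig :
#         historyPop = []
#     rewarding(True, DummyConfig(), DummyAction(), None, "B")  # rewardB -> 1.0
################################################################################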
################################################################################
def rewardA(appliable, config, action, missingLinks):
    if appliable:
        if "BACK" not in action.name :
            reward = -1.0*action.getOracleScore(config, missingLinks)
        else :
            back = int(action.name.split()[-1])
            # Offsets (from the top of the pop history) of the erroneous
            # actions that this BACK would undo.
            error_in_pop = [i for i in range(1,back) if config.historyPop[-i][3] < 0]
            last_error = error_in_pop[-1] if len(error_in_pop) > 0 else 0
            reward = last_error - back
    else:
        reward = -3.0
    return reward
################################################################################

################################################################################
def rewardB(appliable, config, action, missingLinks):
    if appliable:
        if "BACK" not in action.name :
            reward = 1.0 - action.getOracleScore(config, missingLinks)
        else :
            back = int(action.name.split()[-1])
            error_in_pop = [i for i in range(1,back) if config.historyPop[-i][3] < 0]
            last_error = error_in_pop[-1] if len(error_in_pop) > 0 else 0
            reward = last_error - back
    else:
        reward = -3.0
    return reward
################################################################################

################################################################################
def rewardC(appliable, config, action, missingLinks):
    if appliable:
        if "BACK" not in action.name :
            reward = -action.getOracleScore(config, missingLinks)
        else :
            back = int(action.name.split()[-1])
            # Cancel out the rewards accumulated by the popped actions.
            canceledRewards = [h[3] for h in config.historyPop[-back:]]
            reward = -sum(canceledRewards)
    else:
        reward = -3.0
    return reward
################################################################################

################################################################################
def rewardD(appliable, config, action, missingLinks):
    if appliable:
        if "BACK" not in action.name :
            reward = -action.getOracleScore(config, missingLinks)
        else :
            back = int(action.name.split()[-1])
            # Cancel out the rewards of the popped actions, with an extra
            # penalty of 1 for taking the BACK action itself.
            canceledRewards = [h[3] for h in config.historyPop[-back:]]
            reward = -sum(canceledRewards) - 1
    else:
        reward = -3.0
    return reward
################################################################################

################################################################################
def rewardE(appliable, config, action, missingLinks):
    if appliable:
        if "BACK" not in action.name :
            reward = -action.getOracleScore(config, missingLinks)
        else :
            # Flat penalty for any BACK action.
            reward = -0.5
    else:
        reward = -3.0
    return reward
################################################################################
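################################################################################
# A minimal, hypothetical smoke test wiring ReplayMemory and optimizeModel
# together with a toy embedding network, to make the expected shapes and
# dtypes explicit. Every name below (vocabSize, makeToyNet, ...) is an
# illustrative assumption, not part of the training pipeline.
if __name__ == "__main__" :
    import torch.nn as nn

    stateSize, nbActions, vocabSize = 8, 4, 100

    def makeToyNet() :
        # Embeds the long-typed state indices, then maps them to one Q-value
        # per action.
        return nn.Sequential(
            nn.Embedding(vocabSize, 16),
            nn.Flatten(),
            nn.Linear(stateSize*16, nbActions),
        ).to(getDevice())

    policyNet = makeToyNet()
    targetNet = makeToyNet()
    targetNet.load_state_dict(policyNet.state_dict())
    targetNet.eval()
    optimizer = torch.optim.Adam(policyNet.parameters(), lr=1e-3)

    memory = ReplayMemory(64, stateSize)
    for _ in range(64) :
        state = torch.randint(vocabSize, (stateSize,), device=getDevice())
        newState = torch.randint(vocabSize, (stateSize,), device=getDevice())
        memory.push(state, random.randrange(nbActions), newState, random.random())

    loss = optimizeModel(16, policyNet, targetNet, memory, optimizer, gamma=0.99)
    print("smoke-test loss :", loss)
################################################################################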