import random
import torch
import torch.nn.functional as F

################################################################################
class ReplayMemory(object):
  """Fixed-size circular buffer of transitions (state, action, nextState, reward)."""

  def __init__(self, capacity):
    self.capacity = capacity
    self.memory = []
    self.position = 0

  def push(self, transition):
    """Saves a transition, overwriting the oldest one once capacity is reached."""
    if len(self.memory) < self.capacity:
      self.memory.append(None)
    self.memory[self.position] = transition
    self.position = (self.position + 1) % self.capacity

  def sample(self, batch_size):
    return random.sample(self.memory, batch_size)

  def __len__(self):
    return len(self.memory)
################################################################################

################################################################################
def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle) :
  """Epsilon-greedy action selection over the transition set ts.

  With probability probaRandom a random appliable transition is returned,
  with probability probaOracle the appliable transition preferred by the oracle,
  otherwise the appliable transition with the highest predicted Q value.
  """
  sample = random.random()
  if sample < probaRandom :
    # Exploration: pick uniformly among the appliable transitions.
    candidates = [trans for trans in ts if trans.appliable(config)]
    return candidates[random.randrange(len(candidates))] if len(candidates) > 0 else None
  elif sample < probaRandom+probaOracle :
    # Oracle guidance: keep the appliable transition with the lowest oracle cost.
    # Sorting on the score only avoids comparing transition objects on ties.
    candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)], key=lambda c : c[0])
    return candidates[0][1] if len(candidates) > 0 else None
  else :
    # Exploitation: return the appliable transition with the highest network score.
    # Scores are compared as floats (comparing "%.2f" strings would misorder them).
    with torch.no_grad() :
      output = network(torch.stack([state]))
      candidates = sorted([[float(output[0][index]), ts[index]] for index in range(len(ts)) if ts[index].appliable(config)], key=lambda c : c[0], reverse=True)
      return candidates[0][1] if len(candidates) > 0 else None
################################################################################

################################################################################
def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
  """Performs one DQN optimization step on a batch sampled from the replay memory."""
  gamma = 0.999
  if len(memory) < batchSize :
    return 0.0

  batch = memory.sample(batchSize)
  states = torch.stack([b[0] for b in batch])
  actions = torch.stack([b[1] for b in batch])
  next_states = torch.stack([b[2] for b in batch])
  rewards = torch.stack([b[3] for b in batch])

  # Q values predicted by the policy network for the actions actually taken.
  predictedQ = policy_net(states).gather(1, actions)
  # Bootstrapped values from the target network; detached so no gradient flows
  # back into target_net, and reshaped into a (batchSize, 1) column.
  nextQ = target_net(next_states).max(1)[0].detach().unsqueeze(1)

  expectedReward = gamma*nextQ + rewards

  loss = F.smooth_l1_loss(predictedQ, expectedReward)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  return float(loss)
################################################################################
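
################################################################################
# Minimal usage sketch (not part of the original module): it wires ReplayMemory
# and optimizeModel together with a toy linear policy/target network, mainly to
# show the expected shapes of the stored transitions. stateSize, nbActions and
# the random dummy transitions below are assumptions made for illustration.
if __name__ == "__main__" :
  import torch.nn as nn
  import torch.optim as optim

  stateSize, nbActions = 8, 4
  policy_net = nn.Linear(stateSize, nbActions)
  target_net = nn.Linear(stateSize, nbActions)
  target_net.load_state_dict(policy_net.state_dict())
  optimizer = optim.Adam(policy_net.parameters(), lr=1e-3)

  memory = ReplayMemory(100)
  for _ in range(64) :
    state = torch.randn(stateSize)                         # float state vector
    action = torch.tensor([random.randrange(nbActions)])   # action index, shape (1,)
    nextState = torch.randn(stateSize)
    reward = torch.tensor([random.random()])               # scalar reward, shape (1,)
    memory.push((state, action, nextState, reward))

  print("loss :", optimizeModel(32, policy_net, target_net, memory, optimizer))
################################################################################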