import random

import torch
import torch.nn.functional as F

################################################################################
class ReplayMemory(object):
    """Fixed-size circular buffer of past transitions, sampled uniformly at random."""
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, transition):
        """Saves a transition, overwriting the oldest one once capacity is reached."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
################################################################################
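################################################################################
# Minimal usage sketch (assumed shapes, not part of the original code): each
# transition is pushed as the (state, action, next_state, reward) tuple that
# optimizeModel later unpacks; action and reward are stored as 1-element
# tensors so that torch.stack yields [batch, 1] shaped tensors for gather().
memory = ReplayMemory(capacity=1000)
for _ in range(32):
    state = torch.randn(10)                       # hypothetical feature vector
    action = torch.tensor([0], dtype=torch.long)  # index of the chosen transition
    next_state = torch.randn(10)
    reward = torch.tensor([1.0])
    memory.push((state, action, next_state, reward))
batch = memory.sample(16)
################################################################################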
################################################################################
def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle) :
    sample = random.random()
    if sample < probaRandom :
        # Exploration: uniformly random appliable transition.
        candidates = [trans for trans in ts if trans.appliable(config)]
        return random.choice(candidates) if len(candidates) > 0 else None
    elif sample < probaRandom + probaOracle :
        # Guided exploration: appliable transition with the best (lowest) oracle score.
        candidates = [trans for trans in ts if trans.appliable(config)]
        return min(candidates, key=lambda trans : trans.getOracleScore(config, missingLinks)) if len(candidates) > 0 else None
    else :
        # Exploitation: appliable transition with the highest predicted Q value.
        with torch.no_grad() :
            output = network(torch.stack([state]))
            candidates = sorted([(float(output[0][index]), index) for index in range(len(ts)) if ts[index].appliable(config)], reverse=True)
            return ts[candidates[0][1]] if len(candidates) > 0 else None
################################################################################
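################################################################################
# Minimal scheduling sketch (assumption, not part of the original code): the
# probabilities passed to selectAction can be decayed over training so that
# random and oracle-guided choices fade out in favour of pure exploitation.
# The constants and the episode counter below are hypothetical.
def explorationSchedule(episode, nbEpisodes) :
    frac = max(0.0, 1.0 - episode / nbEpisodes)
    probaRandom = 0.05 + 0.25 * frac   # hypothetical: 0.30 -> 0.05
    probaOracle = 0.60 * frac          # hypothetical: 0.60 -> 0.00
    return probaRandom, probaOracle
################################################################################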
################################################################################
def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
    gamma = 0.999
    if len(memory) < batchSize :
        return 0.0

    # Unpack a minibatch of (state, action, next_state, reward) transitions.
    batch = memory.sample(batchSize)
    states = torch.stack([b[0] for b in batch])
    actions = torch.stack([b[1] for b in batch])
    next_states = torch.stack([b[2] for b in batch])
    rewards = torch.stack([b[3] for b in batch])

    # Q(s, a) predicted by the policy network for the actions actually taken.
    predictedQ = policy_net(states).gather(1, actions)
    # Bellman target r + gamma * max_a' Q_target(s', a'), kept out of the autograd graph.
    with torch.no_grad() :
        nextQ = target_net(next_states).max(1)[0].unsqueeze(1)
        expectedReward = rewards + gamma * nextQ

    loss = F.smooth_l1_loss(predictedQ, expectedReward)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return float(loss)
################################################################################
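################################################################################
# Minimal wiring sketch (assumption, not part of the original code): how
# policy_net, target_net, the replay memory and optimizeModel typically fit
# together in DQN-style training. 'policyNetClass', the sync interval and the
# learning rate are hypothetical.
import copy

def buildLearner(policyNetClass) :
    policy_net = policyNetClass()
    target_net = copy.deepcopy(policy_net)  # same architecture, same initial weights
    target_net.eval()                       # the target network is never trained directly
    optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
    memory = ReplayMemory(5000)
    return policy_net, target_net, optimizer, memory

def maybeSyncTarget(step, policy_net, target_net, syncEvery=100) :
    # Periodically copy the policy weights into the frozen target network.
    if step % syncEvery == 0 :
        target_net.load_state_dict(policy_net.state_dict())
################################################################################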