import random
import torch
import torch.nn.functional as F

################################################################################
class ReplayMemory(object):
  """Fixed-capacity circular buffer of transitions for experience replay.

  Once `capacity` transitions have been stored, each new push overwrites
  the oldest entry.
  """

  def __init__(self, capacity):
    self.capacity = capacity
    self.memory = []
    self.position = 0

  def push(self, transition):
    """Saves a transition."""
    if len(self.memory) == self.capacity:
      # Buffer is full: overwrite the oldest slot in place.
      self.memory[self.position] = transition
    else:
      # Still filling up: the write cursor always points just past the end.
      self.memory.append(transition)
    self.position = (self.position + 1) % self.capacity

  def sample(self, batch_size):
    """Returns `batch_size` transitions drawn without replacement."""
    return random.sample(self.memory, batch_size)

  def __len__(self):
    return len(self.memory)
################################################################################
################################################################################
def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle) :
  """Selects the next transition to apply to `config`.

  With probability `probaRandom`, returns a uniformly random applicable
  transition; with probability `probaOracle`, returns the applicable
  transition with the lowest oracle score; otherwise returns the applicable
  transition with the highest network Q-value. Returns None when no
  transition in `ts` is applicable.
  """
  sample = random.random()
  if sample < probaRandom :
    candidates = [trans for trans in ts if trans.appliable(config)]
    return random.choice(candidates) if candidates else None
  elif sample < probaRandom+probaOracle :
    # Pair each score with the transition's index so ties are broken by
    # index instead of comparing transition objects, which may not define
    # an ordering (the previous sort could raise TypeError on ties).
    scored = [(trans.getOracleScore(config, missingLinks), index)
              for index, trans in enumerate(ts) if trans.appliable(config)]
    return ts[min(scored)[1]] if scored else None
  else :
    with torch.no_grad() :
      output = network(torch.stack([state]))
      # Compare raw float scores. The previous implementation sorted
      # "%.2f"-formatted strings, under which "9.50" outranks "20.00",
      # picking the wrong argmax for scores >= 10 or negative scores.
      scored = [(float(output[0][index]), index)
                for index in range(len(ts)) if ts[index].appliable(config)]
      return ts[max(scored)[1]] if scored else None

################################################################################

################################################################################
def optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma=0.999) :
  """Performs one DQN optimization step on `policy_net`.

  Samples `batchSize` (state, action, nextState, reward) transitions from
  `memory`, computes the smooth L1 loss between the Q-values predicted by
  `policy_net` for the taken actions and the bootstrapped targets
  gamma * max_a Q_target(nextState, a) + reward, then applies one
  optimizer step. Returns the loss as a float, or 0.0 when `memory` does
  not yet hold `batchSize` transitions.

  `gamma` is the discount factor (default 0.999, the value previously
  hard-coded here).
  """
  if len(memory) < batchSize :
    return 0.0

  batch = memory.sample(batchSize)
  states = torch.stack([b[0] for b in batch])
  actions = torch.stack([b[1] for b in batch])
  next_states = torch.stack([b[2] for b in batch])
  rewards = torch.stack([b[3] for b in batch])

  # Q-values of the actions actually taken, shape (batchSize, 1).
  predictedQ = policy_net(states).gather(1, actions)

  # Bootstrap the target from the frozen target network. no_grad avoids
  # building a useless autograd graph through target_net; unsqueeze(1)
  # replaces the previous unsqueeze(0) + transpose to get (batchSize, 1).
  with torch.no_grad() :
    nextQ = target_net(next_states).max(1)[0].unsqueeze(1)
    expectedReward = gamma*nextQ + rewards

  loss = F.smooth_l1_loss(predictedQ, expectedReward)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  return float(loss)
################################################################################