Skip to content
Snippets Groups Projects
Rl.py 2.55 KiB
Newer Older
  • Learn to ignore specific revisions
  • Maxime Petit's avatar
    Maxime Petit committed
    import random
    import torch
    
    Franck Dary's avatar
    Franck Dary committed
    import torch.nn.functional as F
    
    Maxime Petit's avatar
    Maxime Petit committed
    
    ################################################################################
    class ReplayMemory:
      """Fixed-size ring buffer of transitions used for experience replay.

      Transitions are appended until ``capacity`` is reached; after that each
      new push overwrites the slot at ``position``, which wraps around modulo
      ``capacity`` so the oldest stored entry is replaced first.
      """

      def __init__(self, capacity):
        # Maximum number of transitions kept in the buffer.
        self.capacity = capacity
        # Backing list; grows one entry per push until it holds ``capacity`` items.
        self.memory = []
        # Index of the slot the next push writes to.
        self.position = 0

      def push(self, transition):
        """Store ``transition``, overwriting the oldest entry once the buffer is full."""
        slot = self.position
        stillFilling = len(self.memory) < self.capacity
        if stillFilling:
          # Grow the list by a placeholder so that ``slot`` is a valid index.
          self.memory.append(None)
        self.memory[slot] = transition
        self.position = (slot + 1) % self.capacity

      def sample(self, batch_size):
        """Return ``batch_size`` stored transitions drawn uniformly without replacement.

        ``random.sample`` raises ValueError if fewer than ``batch_size`` are stored.
        """
        return random.sample(self.memory, batch_size)

      def __len__(self):
        """Return how many transitions are currently stored."""
        return len(self.memory)
    ################################################################################
    
    Maxime Petit's avatar
    Maxime Petit committed
    
    
    Franck Dary's avatar
    Franck Dary committed
    ################################################################################
    def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle) :
      """Choose the next transition with a mixed random / oracle / network policy.

      One uniform draw ``sample`` in [0, 1) selects the strategy:
        * ``sample < probaRandom``: a uniformly random appliable transition;
        * ``sample < probaRandom + probaOracle``: the appliable transition with the
          lowest ``getOracleScore(config, missingLinks)`` (ties keep ``ts`` order);
        * otherwise: the appliable transition with the highest network score
          (ties keep ``ts`` order).

      Parameters:
        network -- callable applied to ``state`` with a batch dimension added;
                   ``output[0][i]`` is read as the score of ``ts[i]`` (assumes one
                   output per transition -- confirm against the model definition).
        state -- tensor for the current configuration.
        ts -- sequence of transitions exposing ``appliable(config)`` and
              ``getOracleScore(config, missingLinks)``.
        config, missingLinks -- forwarded to the transition methods.
        probaRandom, probaOracle -- probabilities of the random and oracle strategies.

      Returns the chosen transition, or None when no transition is appliable.
      """
      sample = random.random()
      if sample < probaRandom :
        candidates = [trans for trans in ts if trans.appliable(config)]
        return candidates[random.randrange(len(candidates))] if len(candidates) > 0 else None
      elif sample < probaRandom+probaOracle :
        candidates = [trans for trans in ts if trans.appliable(config)]
        if len(candidates) == 0 :
          return None
        # Lowest oracle score wins. Selecting by key (rather than sorting
        # [score, trans] pairs) never compares transition objects on score ties.
        return min(candidates, key=lambda trans : trans.getOracleScore(config, missingLinks))
      else :
        with torch.no_grad() :
          output = network(torch.stack([state]))
        scores = output[0]
        candidates = [index for index in range(len(ts)) if ts[index].appliable(config)]
        if len(candidates) == 0 :
          return None
        # Compare scores numerically: comparing "%.2f"-formatted strings would rank
        # "9.00" above "10.00" and "-1.00" above "-0.50".
        best = max(candidates, key=lambda index : float(scores[index]))
        return ts[best]
    
    Maxime Petit's avatar
    Maxime Petit committed
    
    ################################################################################
    
    ################################################################################
    
    Franck Dary's avatar
    Franck Dary committed
    def optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma=0.999) :
      """Run one Q-learning optimisation step on a random minibatch from ``memory``.

      Each sampled transition is indexed as ``(state, action, next_state, reward)``.
      The value ``policy_net`` predicts for the taken action is regressed, with a
      smooth L1 loss, towards ``reward + gamma * max over target_net(next_state)``.

      Parameters:
        batchSize -- number of transitions to sample.
        policy_net -- network being trained; its output is indexed with
                      ``gather(1, actions)``.
        target_net -- network providing the bootstrap target; evaluated without
                      gradient tracking, so it is never updated here.
        memory -- object supporting ``len()`` and ``sample(batchSize)``.
        optimizer -- optimizer stepped after backpropagating the loss.
        gamma -- discount factor applied to the bootstrap target (default 0.999).

      Returns the loss as a float, or 0.0 without any update when ``memory`` holds
      fewer than ``batchSize`` transitions.

      NOTE(review): there is no terminal-state mask, so every next_state is
      bootstrapped. Stacked actions (long) and rewards are assumed to have shape
      (batch, 1) -- confirm against the code that builds transitions; a (batch,)
      reward tensor would silently broadcast to (batch, batch).
      """
      if len(memory) < batchSize :
        return 0.0

      batch = memory.sample(batchSize)
      states = torch.stack([b[0] for b in batch])
      actions = torch.stack([b[1] for b in batch])
      next_states = torch.stack([b[2] for b in batch])
      rewards = torch.stack([b[3] for b in batch])

      predictedQ = policy_net(states).gather(1, actions)
      # The bootstrap target must be a constant: without no_grad, backward() would
      # also flow into target_net (accumulating never-cleared grads there, or
      # corrupting the update when the same network is passed twice).
      with torch.no_grad() :
        nextQ = target_net(next_states).max(1)[0].unsqueeze(1)
      expectedReward = gamma*nextQ + rewards

      loss = F.smooth_l1_loss(predictedQ, expectedReward)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      return float(loss)
    ################################################################################
    
    Maxime Petit's avatar
    Maxime Petit committed