import random
import torch
import torch.nn.functional as F

# getDevice() is used below to place tensors on the training device; it is
# assumed to be the device-selection helper provided elsewhere in this project
# (e.g. its Util module).
from Util import getDevice

################################################################################
class ReplayMemory() :
  # Fixed-size circular buffer of (state, action, newState, reward) transitions,
  # stored as preallocated tensors on the training device.
  def __init__(self, capacity, stateSize) :
    self.capacity = capacity
    self.states = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
    self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
    self.actions = torch.zeros(capacity, 1, dtype=torch.long, device=getDevice())
    self.rewards = torch.zeros(capacity, 1, device=getDevice())
    self.position = 0
    self.nbPushed = 0

  def push(self, state, action, newState, reward) :
    # Store one transition, overwriting the oldest one once capacity is reached.
    self.states[self.position] = state
    self.actions[self.position] = action
    self.newStates[self.position] = newState
    self.rewards[self.position] = reward
    self.position = (self.position + 1) % self.capacity
    self.nbPushed += 1

  def sample(self, batchSize) :
    # Return a contiguous block of batchSize stored transitions, starting at a
    # random offset.
    start = random.randint(0, len(self)-batchSize)
    end = start+batchSize
    return self.states[start:end], self.actions[start:end], self.newStates[start:end], self.rewards[start:end]

  def __len__(self) :
    return min(self.nbPushed, self.capacity)
################################################################################
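
# Hedged usage sketch (an illustration, not part of the original file): how the
# replay memory above is typically filled and sampled. The state size, capacity
# and dummy transition values are arbitrary assumptions.
def _replayMemoryExample() :
  stateSize = 4
  memory = ReplayMemory(capacity=100, stateSize=stateSize)
  for i in range(32) :
    state = torch.randint(10, (stateSize,), device=getDevice())
    newState = torch.randint(10, (stateSize,), device=getDevice())
    memory.push(state, i % 3, newState, 1.0)
  # Four aligned tensors of shapes (16, stateSize), (16, 1), (16, stateSize), (16, 1).
  return memory.sample(16)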
    
################################################################################
def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle) :
  # Epsilon-greedy style action selection over the transition set ts.
  sample = random.random()
  if sample < probaRandom :
    # With probability probaRandom, explore: pick a random applicable transition.
    candidates = [trans for trans in ts if trans.appliable(config)]
    return random.choice(candidates) if len(candidates) > 0 else None
  elif sample < probaRandom+probaOracle :
    # With probability probaOracle, follow the oracle: pick the applicable
    # transition with the lowest oracle cost.
    candidates = [trans for trans in ts if trans.appliable(config)]
    if len(candidates) == 0 :
      return None
    return min(candidates, key=lambda trans : trans.getOracleScore(config, missingLinks))
  else :
    # Otherwise act greedily: pick the applicable transition with the highest
    # Q-value predicted by the network.
    with torch.no_grad() :
      output = network(torch.stack([state]))
      scored = sorted([(float(output[0][index]), index) for index in range(len(ts))], reverse=True)
      candidates = [ts[index] for _, index in scored if ts[index].appliable(config)]
      return candidates[0] if len(candidates) > 0 else None
    
################################################################################
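
# Hedged sketch (an assumption, not from the original file): one simple way to
# produce the probaRandom / probaOracle exploration probabilities that
# selectAction expects, linearly annealed over training. All constants are
# illustrative.
def _explorationProbabilities(episode, nbEpisodes) :
  progress = min(1.0, episode / max(1, nbEpisodes-1))
  probaRandom = 0.6 + (0.1-0.6)*progress   # random exploration: 0.6 -> 0.1
  probaOracle = 0.3 + (0.0-0.3)*progress   # oracle guidance:    0.3 -> 0.0
  return probaRandom, probaOracle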
    
################################################################################
def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
  # One DQN optimization step: sample a batch of transitions, compare the
  # Q-values predicted by policy_net with one-step bootstrapped targets from
  # target_net, and update policy_net with gradients clipped to [-1, 1].
  gamma = 0.999
  if len(memory) < batchSize :
    return 0.0

  states, actions, nextStates, rewards = memory.sample(batchSize)

  # Q-values of the actions actually taken, shape (batchSize, 1).
  predictedQ = policy_net(states).gather(1, actions)

  # Best next-state Q-value according to the target network, reshaped to
  # (batchSize, 1) so it aligns with rewards.
  nextQ = target_net(nextStates).max(1)[0].detach().unsqueeze(0)
  nextQ = torch.transpose(nextQ, 0, 1)

  expectedReward = gamma*nextQ + rewards

  loss = F.smooth_l1_loss(predictedQ, expectedReward)
  optimizer.zero_grad()
  loss.backward()

  # Clip gradients element-wise to stabilize training.
  for param in policy_net.parameters() :
    if param.grad is not None :
      param.grad.data.clamp_(-1, 1)

  optimizer.step()

  return float(loss)
################################################################################
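
# Hedged end-to-end sketch (an illustration, not part of the original file):
# wiring ReplayMemory and optimizeModel together with a toy Q-network and a
# periodic copy of policy_net into target_net. The network architecture, sizes
# and schedule below are arbitrary assumptions.
import torch.nn as nn

class _ToyQNet(nn.Module) :
  def __init__(self, vocabSize, stateSize, nbActions) :
    super().__init__()
    self.emb = nn.Embedding(vocabSize, 8)
    self.out = nn.Linear(8*stateSize, nbActions)
  def forward(self, x) :
    return self.out(self.emb(x).view(x.size(0), -1))

def _trainingLoopExample() :
  vocabSize, stateSize, nbActions = 50, 4, 3
  policy_net = _ToyQNet(vocabSize, stateSize, nbActions).to(getDevice())
  target_net = _ToyQNet(vocabSize, stateSize, nbActions).to(getDevice())
  target_net.load_state_dict(policy_net.state_dict())
  optimizer = torch.optim.Adam(policy_net.parameters(), lr=1e-3)
  memory = ReplayMemory(128, stateSize)
  loss = 0.0
  for step in range(64) :
    state = torch.randint(vocabSize, (stateSize,), device=getDevice())
    newState = torch.randint(vocabSize, (stateSize,), device=getDevice())
    memory.push(state, random.randrange(nbActions), newState, random.random())
    loss = optimizeModel(16, policy_net, target_net, memory, optimizer)
    if (step+1) % 32 == 0 :  # periodically refresh the target network
      target_net.load_state_dict(policy_net.state_dict())
  return loss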
    