# Rl.py
import sys
import random
import torch
import torch.nn.functional as F
import numpy as np
from Util import getDevice
################################################################################
class ReplayMemory() :
    # Fixed-size cyclic buffer of (state, action, newState, reward) transitions
    # for one (fromState, toState) pair of network states.
    def __init__(self, capacity, stateSize, fromState, toState) :
        self.fromState = fromState
        self.toState = toState
        self.capacity = capacity
        self.states = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
        self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
        self.actions = torch.zeros(capacity, 1, dtype=torch.long, device=getDevice())
        self.rewards = torch.zeros(capacity, 1, device=getDevice())
        self.noNewStates = torch.zeros(capacity, dtype=torch.bool, device=getDevice())
        self.position = 0
        self.nbPushed = 0
    def push(self, state, action, newState, reward) :
        # Store a transition at the current position, overwriting the oldest one
        # once capacity is reached. A None newState marks a terminal transition.
        self.states[self.position] = state
        self.actions[self.position] = action
        if newState is not None :
            self.newStates[self.position] = newState
        self.noNewStates[self.position] = newState is None
        self.rewards[self.position] = reward
        self.position = (self.position + 1) % self.capacity
        self.nbPushed += 1
    def sample(self, batchSize) :
        # Return a contiguous block of batchSize transitions starting at a random offset.
        start = random.randint(0, len(self)-batchSize)
        end = start+batchSize
        return self.states[start:end], self.actions[start:end], self.newStates[start:end], self.noNewStates[start:end], self.rewards[start:end]
    def __len__(self):
        return min(self.nbPushed, self.capacity)
################################################################################
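################################################################################
# Illustrative sketch, not part of the original code: one way ReplayMemory
# could be filled and sampled. The capacity, state size, action indices and
# reward values below are arbitrary placeholders.
def _exampleReplayMemoryUsage() :
    memory = ReplayMemory(capacity=1000, stateSize=8, fromState=0, toState=0)
    state = torch.zeros(8, dtype=torch.long, device=getDevice())
    newState = torch.ones(8, dtype=torch.long, device=getDevice())
    memory.push(state, 3, newState, -1.0)
    memory.push(newState, 0, None, 0.5) # terminal transition: no successor state
    if len(memory) >= 2 :
        states, actions, nextStates, noNextStates, rewards = memory.sample(2)
        return states, actions, nextStates, noNextStates, rewards
################################################################################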
################################################################################
def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle, fromState) :
    # Epsilon-greedy action selection mixed with an oracle policy: with probability
    # probaRandom pick a random transition, with probability probaOracle pick the
    # appliable transition with the lowest oracle score, otherwise follow the network.
    sample = random.random()
    if sample < probaRandom :
        return ts[random.randrange(len(ts))]
    elif sample < probaRandom+probaOracle :
        candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)])
        return candidates[0][1] if len(candidates) > 0 else None
    else :
        with torch.no_grad() :
            network.setState(fromState)
            output = network(torch.stack([state]))
            predIndex = int(torch.argmax(output))
            return ts[predIndex]
################################################################################
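################################################################################
# Illustrative sketch, not part of the original code: a possible call site for
# selectAction during exploration. `ts` is the list of candidate Transition
# objects and `config`/`missingLinks` come from the surrounding parser; the
# exploration probabilities are placeholders for whatever schedule is used.
def _exampleSelectAction(network, state, ts, config, missingLinks) :
    action = selectAction(network, state, ts, config, missingLinks,
                          probaRandom=0.3, probaOracle=0.2, fromState=0)
    # The random and network branches may return a transition that is not
    # appliable in the current configuration; here it is simply filtered out,
    # though training code may instead penalize it (cf. the reward functions below).
    if action is None or not action.appliable(config) :
        return None
    return action
################################################################################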
################################################################################
def optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma) :
    #lossFct = torch.nn.MSELoss()
    #lossFct = torch.nn.L1Loss()
    lossFct = torch.nn.SmoothL1Loss()
    totalLoss = 0.0
    # memory is a matrix of ReplayMemory objects indexed by (fromState, toState);
    # each buffer that holds at least batchSize transitions contributes one DQN update.
    for fromState in range(len(memory)) :
        for toState in range(len(memory[fromState])) :
            if memory[fromState][toState].nbPushed < batchSize :
                continue
            states, actions, nextStates, noNextStates, rewards = memory[fromState][toState].sample(batchSize)
            policy_net.setState(fromState)
            target_net.setState(toState)
            # Q(s,a) predicted by the policy network for the stored actions.
            predictedQ = policy_net(states).gather(1, actions)
            # max_a' Q(s',a') from the target network, zeroed for terminal transitions.
            nextQ = target_net(nextStates).max(1)[0].detach().unsqueeze(0)
            nextQ = torch.transpose(nextQ, 0, 1)
            nextQ[noNextStates] = 0.0
            expectedReward = gamma*nextQ + rewards
            loss = lossFct(predictedQ, expectedReward)
            optimizer.zero_grad()
            loss.backward()
            # Clip gradients element-wise to [-1, 1] for stability.
            for param in policy_net.parameters() :
                if param.grad is not None :
                    param.grad.data.clamp_(-1, 1)
            optimizer.step()
            totalLoss += float(loss)
    return totalLoss
################################################################################
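################################################################################
# Illustrative sketch, not part of the original code: the usual DQN pattern
# around optimizeModel, with a periodic copy of the policy weights into the
# target network. `memory` is assumed to be a matrix of ReplayMemory objects
# indexed by (fromState, toState), as optimizeModel expects; the batch size,
# gamma and update period are placeholders.
def _exampleOptimizationStep(policy_net, target_net, memory, optimizer, step) :
    loss = optimizeModel(32, policy_net, target_net, memory, optimizer, 0.99)
    if step % 100 == 0 :
        target_net.load_state_dict(policy_net.state_dict())
    return loss
################################################################################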
################################################################################
def rewarding(appliable, config, action, missingLinks, funcname):
    # Dispatch to the reward function whose name is "reward"+funcname (e.g. "E", "G", "A").
    return globals()["reward"+funcname](appliable, config, action, missingLinks)
################################################################################
forbiddenReward = 1.5 # magnitude of the penalty (-forbiddenReward) for selecting a non-appliable action
################################################################################
def rewardE(appliable, config, action, missingLinks):
    # Non-BACK actions are penalized by their oracle score (their cost according
    # to the oracle); BACK gets a small constant bonus.
    if appliable:
        if action.name != "BACK" :
            reward = -action.getOracleScore(config, missingLinks)
        else :
            reward = 0.5
    else:
        reward = -forbiddenReward
    return reward
################################################################################
################################################################################
def rewardG(appliable, config, action, missingLinks):
    if appliable:
        if action.name != "BACK" :
            reward = -action.getOracleScore(config, missingLinks)
        else :
            # BACK cancels the last `back` popped transitions: reward it by
            # log(1 + |sum of canceled rewards|) when those rewards sum to a
            # negative value (mistakes were undone), penalize it by -1 otherwise.
            back = action.size
            canceledRewards = [h[3] for h in config.historyPop[-back:]]
            reward = np.log(1-sum(canceledRewards)) if -sum(canceledRewards) > 0 else -1
    else:
        reward = -forbiddenReward
    return reward
################################################################################
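################################################################################
# Worked example for the BACK branch of rewardG (not in the original code):
# if BACK cancels the last back=2 transitions and their stored rewards were
# -1 and -2, then sum(canceledRewards) = -3 and the BACK reward is
# log(1-(-3)) = log(4) ≈ 1.386; if the canceled rewards were 0 and 0.5,
# the sum is not negative and the BACK reward is -1.
################################################################################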
################################################################################
def rewardA(appliable, config, action, missingLinks):
    if appliable:
        if action.name != "BACK" :
            reward = -action.getOracleScore(config, missingLinks)
        else :
            # Walk the pop history backwards, collecting the rewards of the
            # transitions canceled by this BACK, until action.size NOBACK
            # actions have been encountered.
            canceledRewards = []
            found = 0
            for i in range(len(config.historyPop))[::-1] :
                if config.historyPop[i][0].name == "NOBACK" :
                    found += 1
                    if found == action.size :
                        break
                else :
                    canceledRewards.append(config.historyPop[i][3])
            reward = np.log(1-sum(canceledRewards)) if -sum(canceledRewards) > 0 else -1
    else:
        reward = -forbiddenReward
    return reward
################################################################################
################################################################################
def rewardB(appliable, config, action, missingLinks):
    # Same as rewardA, except the reward is doubled while the configuration
    # still has undone transitions (config.nbUndone != 0).
    if appliable:
        if action.name != "BACK" :
            reward = -action.getOracleScore(config, missingLinks)
        else :
            canceledRewards = []
            found = 0
            for i in range(len(config.historyPop))[::-1] :
                if config.historyPop[i][0].name == "NOBACK" :
                    found += 1
                    if found == action.size :
                        break
                else :
                    canceledRewards.append(config.historyPop[i][3])
            reward = np.log(1-sum(canceledRewards)) if -sum(canceledRewards) > 0 else -1
    else:
        reward = -forbiddenReward
    return (1.0 if config.nbUndone == 0 else 2.0)*reward
################################################################################
################################################################################
def rewardA2(appliable, config, action, missingLinks):
    # Same as rewardA, except the positive BACK reward is halved.
    if appliable:
        if action.name != "BACK" :
            reward = -action.getOracleScore(config, missingLinks)
        else :
            canceledRewards = []
            found = 0
            for i in range(len(config.historyPop))[::-1] :
                if config.historyPop[i][0].name == "NOBACK" :
                    found += 1
                    if found == action.size :
                        break
                else :
                    canceledRewards.append(config.historyPop[i][3])
            reward = np.log(1-sum(canceledRewards)) / 2 if -sum(canceledRewards) > 0 else -1
    else:
        reward = -forbiddenReward
    return reward
################################################################################
################################################################################
def reward3G(appliable, config, action, missingLinks):
    # rewardG scaled by a factor of 3.
    if appliable:
        if action.name != "BACK" :
            reward = -action.getOracleScore(config, missingLinks)
        else :
            back = action.size
            canceledRewards = [h[3] for h in config.historyPop[-back:]]
            reward = np.log(1-sum(canceledRewards)) if -sum(canceledRewards) > 0 else -1
    else:
        reward = -forbiddenReward
    return 3*reward
################################################################################
################################################################################
def reward10G(appliable, config, action, missingLinks):
    # rewardG scaled by a factor of 10.
    if appliable:
        if action.name != "BACK" :
            reward = -action.getOracleScore(config, missingLinks)
        else :
            back = action.size
            canceledRewards = [h[3] for h in config.historyPop[-back:]]
            reward = np.log(1-sum(canceledRewards)) if -sum(canceledRewards) > 0 else -1
    else:
        reward = -forbiddenReward
    return 10*reward
################################################################################
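################################################################################
# Illustrative sketch, not part of the original code: rewarding() dispatches on
# the function name suffix, so selecting the reward strategy from a
# configuration string might look like this. The strategy name "G" and the
# surrounding arguments are placeholders.
def _exampleRewarding(appliable, config, action, missingLinks) :
    return rewarding(appliable, config, action, missingLinks, "G")
################################################################################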