import sys
import random
import torch
import torch.nn.functional as F
from Util import getDevice
################################################################################
class ReplayMemory() :
  def __init__(self, capacity, stateSize) :
    self.capacity = capacity
    self.states = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
    self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
    self.actions = torch.zeros(capacity, 1, dtype=torch.long, device=getDevice())
    self.rewards = torch.zeros(capacity, 1, device=getDevice())
    self.noNewStates = torch.zeros(capacity, dtype=torch.bool, device=getDevice())
    # Write cursor and push counter, needed by push() and __len__().
    self.position = 0
    self.nbPushed = 0

  def push(self, state, action, newState, reward) :
    self.states[self.position] = state
    self.actions[self.position] = action
    if newState is not None :
      self.newStates[self.position] = newState
    self.noNewStates[self.position] = newState is None
    self.rewards[self.position] = reward
    # Overwrite the oldest transition once the buffer is full.
    self.position = (self.position + 1) % self.capacity
    self.nbPushed += 1

  def __len__(self) :
    return min(self.nbPushed, self.capacity)

  def sample(self, batchSize) :
    # Return a contiguous window of batchSize stored transitions.
    start = random.randint(0, len(self)-batchSize)
    end = start+batchSize
    return self.states[start:end], self.actions[start:end], self.newStates[start:end], self.noNewStates[start:end], self.rewards[start:end]
################################################################################
################################################################################
def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle) :
  sample = random.random()
  if sample < probaRandom :
    # Exploration: pick a random appliable transition.
    candidates = [trans for trans in ts if trans.appliable(config)]
    return random.choice(candidates) if len(candidates) > 0 else None
  elif sample < probaRandom+probaOracle :
    # Guided exploration: pick the appliable transition with the lowest oracle cost.
    candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)], key=lambda c : c[0])
    return candidates[0][1] if len(candidates) > 0 else None
  else :
    # Exploitation: pick the transition with the highest predicted Q-value.
    with torch.no_grad() :
      output = network(torch.stack([state]))
      predIndex = int(torch.argmax(output))
      return ts[predIndex]
################################################################################
################################################################################
def optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma=0.99) :
  # gamma: discount factor (the default here is an assumed value).
  if len(memory) < batchSize :
    return 0.0

  states, actions, nextStates, noNextStates, rewards = memory.sample(batchSize)

  # Q-values of the actions actually taken, according to the policy network.
  predictedQ = policy_net(states).gather(1, actions)
  # Bootstrapped value of the next states, from the frozen target network.
  nextQ = target_net(nextStates).max(1)[0].detach().unsqueeze(1)
  nextQ[noNextStates] = 0.0
  expectedReward = rewards + gamma*nextQ

  loss = F.smooth_l1_loss(predictedQ, expectedReward)
  optimizer.zero_grad()
  loss.backward()
  # Clamp gradients to stabilize training.
  for param in policy_net.parameters() :
    if param.grad is not None :
      param.grad.data.clamp_(-1, 1)
  optimizer.step()

  return float(loss)
################################################################################
################################################################################
def rewarding(appliable, config, action, missingLinks) :
  if appliable :
    if "BACK" not in action.name :
      # Penalize the action by its oracle cost.
      reward = -1.0*action.getOracleScore(config, missingLinks)
    else :
      # For BACK actions, penalize undoing more steps than needed to reach the last erroneous action.
      back = int(action.name.split()[-1])
      error_in_pop = [i for i in range(1,back) if config.historyPop[-i][3] < 0]
      last_error = error_in_pop[-1] if len(error_in_pop) > 0 else 0
      reward = last_error - back
  else :
    # Fixed penalty for selecting a non-appliable action.
    reward = -3.0
  return reward
################################################################################
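################################################################################
# Minimal usage sketch (an assumption, not part of the original training
# pipeline): wires ReplayMemory and optimizeModel together with a toy Q-network
# on random data. The real networks, transition set and Config objects come
# from the rest of the project; ToyQNet and the sizes below are hypothetical.
if __name__ == "__main__" :
  class ToyQNet(torch.nn.Module) :
    def __init__(self, stateSize, nbActions) :
      super().__init__()
      self.lin = torch.nn.Linear(stateSize, nbActions)
    def forward(self, x) :
      # Stored states are integer indices; cast to float for the linear layer.
      return self.lin(x.float())

  stateSize, nbActions, capacity, batchSize = 4, 3, 32, 8
  policy_net = ToyQNet(stateSize, nbActions).to(getDevice())
  target_net = ToyQNet(stateSize, nbActions).to(getDevice())
  target_net.load_state_dict(policy_net.state_dict())
  optimizer = torch.optim.Adam(policy_net.parameters(), lr=1e-3)

  memory = ReplayMemory(capacity, stateSize)
  for _ in range(capacity) :
    state = torch.randint(0, 10, (stateSize,), device=getDevice())
    newState = torch.randint(0, 10, (stateSize,), device=getDevice())
    memory.push(state, random.randint(0, nbActions-1), newState, random.random())

  print("smooth L1 loss :", optimizeModel(batchSize, policy_net, target_net, memory, optimizer))
################################################################################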