diff --git a/Rl.py b/Rl.py
index 4767c30b89865d2dcc87506daa169f70b6f54173..e091df525f4e5805aae75bd9d437c5fff0b7bbe0 100644
--- a/Rl.py
+++ b/Rl.py
@@ -7,7 +7,9 @@ from Util import getDevice
 
 ################################################################################
 class ReplayMemory() :
-  def __init__(self, capacity, stateSize, nbStates) :
+  def __init__(self, capacity, stateSize, fromState, toState) :
+    self.fromState = fromState
+    self.toState = toState
     self.capacity = capacity
     self.states = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
     self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
@@ -37,7 +39,7 @@ class ReplayMemory() :
 ################################################################################
 
 ################################################################################
-def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle) :
+def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle, fromState) :
   sample = random.random()
   if sample < probaRandom :
     return ts[random.randrange(len(ts))]
@@ -46,6 +48,7 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra
     return candidates[0][1] if len(candidates) > 0 else None
   else :
     with torch.no_grad() :
+      network.setState(fromState)
       output = network(torch.stack([state]))
       predIndex = int(torch.argmax(output))
       return ts[predIndex]
@@ -53,27 +56,35 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra
 ################################################################################
 
 ################################################################################
 def optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma) :
-  if len(memory) < batchSize :
-    return 0.0
-
-  states, actions, nextStates, noNextStates, rewards = memory.sample(batchSize)
-
-  predictedQ = policy_net(states).gather(1, actions)
-  nextQ = target_net(nextStates).max(1)[0].detach().unsqueeze(0)
-  nextQ = torch.transpose(nextQ, 0, 1)
-  nextQ[noNextStates] = 0.0
-
-  expectedReward = gamma*nextQ + rewards
-
-  loss = F.smooth_l1_loss(predictedQ, expectedReward)
-  optimizer.zero_grad()
-  loss.backward()
-  for param in policy_net.parameters() :
-    if param.grad is not None :
-      param.grad.data.clamp_(-1, 1)
-  optimizer.step()
-
-  return float(loss)
+  totalLoss = 0.0
+  for fromState in range(len(memory)) :
+    for toState in range(len(memory[fromState])) :
+      if memory[fromState][toState].nbPushed < batchSize :
+        continue
+
+      states, actions, nextStates, noNextStates, rewards = memory[fromState][toState].sample(batchSize)
+
+      policy_net.setState(fromState)
+      target_net.setState(toState)
+
+      predictedQ = policy_net(states).gather(1, actions)
+      nextQ = target_net(nextStates).max(1)[0].detach().unsqueeze(0)
+      nextQ = torch.transpose(nextQ, 0, 1)
+      nextQ[noNextStates] = 0.0
+
+      expectedReward = gamma*nextQ + rewards
+
+      loss = F.smooth_l1_loss(predictedQ, expectedReward)
+      optimizer.zero_grad()
+      loss.backward()
+      for param in policy_net.parameters() :
+        if param.grad is not None :
+          param.grad.data.clamp_(-1, 1)
+      optimizer.step()
+
+      totalLoss += float(loss)
+
+  return totalLoss
 ################################################################################
 
 ################################################################################
diff --git a/Train.py b/Train.py
index de709df3a709fecb7917723722fac1a50fc48466..c6b57bf4ab85bfebebf16b33a5e1a29ae1eab2ec 100644
--- a/Train.py
+++ b/Train.py
@@ -235,9 +235,12 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
     while True :
       missingLinks = getMissingLinks(sentence)
       transitionSet = transitionSets[sentence.state]
+      fromState = sentence.state
+      toState = sentence.state
+
       if debug :
         sentence.printForDebug(sys.stderr)
-      action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle)
+      action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle, fromState)
       if action is None :
         break
 
@@ -253,14 +256,14 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
       newState = None
       if appliable :
         applyTransition(strategy, sentence, action, reward_)
+        toState = sentence.state
         newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
-
       else:
         count+=1
 
       if memory is None :
-        memory = ReplayMemory(5000, state.numel())
-      memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
+        memory = [[ReplayMemory(5000, state.numel(), f, t) for t in range(len(transitionSets))] for f in range(len(transitionSets))]
+      memory[fromState][toState].push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
       state = newState
       if i % batchSize == 0 :
         totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma)
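
Note: the patch assumes policy_net and target_net expose a setState(k) method that selects a per-parser-state output head before the next forward pass; that method is not part of this diff. Below is a minimal runnable sketch of the idea, with a toy multi-head network and the per-(fromState, toState) replay grid. MultiStateNet, its layout, and the plain-list buffers are illustrative assumptions, not the project's real classes:

import torch

# Toy stand-in for the multi-head policy/target networks implied by the diff:
# setState(k) picks which per-parser-state head forward() will use.
class MultiStateNet(torch.nn.Module) :
  def __init__(self, stateSize, nbActions, nbStates) :
    super().__init__()
    self.heads = torch.nn.ModuleList(
      torch.nn.Linear(stateSize, nbActions) for _ in range(nbStates))
    self.current = 0

  def setState(self, state) :
    self.current = state

  def forward(self, x) :
    return self.heads[self.current](x)

nbStates, stateSize, nbActions = 2, 4, 3
net = MultiStateNet(stateSize, nbActions, nbStates)

# One buffer per (fromState, toState) pair, as in the new Train.py code: a
# transition taken in fromState that leaves the parser in toState goes into
# memory[fromState][toState], so every sampled batch is homogeneous.
memory = [[[] for t in range(nbStates)] for f in range(nbStates)]

fromState, toState = 0, 1
s, a, s2, r = torch.rand(stateSize), 2, torch.rand(stateSize), 1.0
memory[fromState][toState].append((s, a, s2, r))

# The double loop from the new optimizeModel: switch to the head matching each
# buffer before evaluating Q(s, a) (policy side) and max Q(s', .) (target side).
for f in range(nbStates) :
  for t in range(nbStates) :
    for s, a, s2, r in memory[f][t] :
      net.setState(f)
      q = net(s.unsqueeze(0))[0, a]
      net.setState(t)
      nextQ = net(s2.unsqueeze(0)).max()
      print(f, t, float(q), float(nextQ))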
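
For reference, the Q-learning update inside the new double loop is unchanged from the old optimizeModel. A dummy-tensor walkthrough of its shapes (data here is random, purely illustrative): smooth-L1 between Q(s, a) and r + gamma * max_a' Q(s', a'), with terminal transitions (noNextStates) contributing no bootstrap term:

import torch
import torch.nn.functional as F

gamma, batchSize, nbActions = 0.99, 4, 3
policyOut = torch.rand(batchSize, nbActions, requires_grad=True)  # policy_net(states)
actions = torch.randint(nbActions, (batchSize, 1))
predictedQ = policyOut.gather(1, actions)                         # Q(s, a), shape (4, 1)

targetOut = torch.rand(batchSize, nbActions)                      # target_net(nextStates)
nextQ = targetOut.max(1)[0].detach().unsqueeze(0)                 # shape (1, 4)
nextQ = torch.transpose(nextQ, 0, 1)                              # shape (4, 1)
noNextStates = torch.tensor([False, True, False, False])
nextQ[noNextStates] = 0.0                                         # terminal: target is just r

rewards = torch.rand(batchSize, 1)
loss = F.smooth_l1_loss(predictedQ, gamma*nextQ + rewards)
loss.backward()

The unsqueeze(0) followed by torch.transpose(nextQ, 0, 1) is equivalent to a single unsqueeze(1); the sketch keeps the diff's form so the shapes line up with the patched code.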