diff --git a/Rl.py b/Rl.py
index fec4da33f276a0bcf82a9f3219903c6616a089fe..8382e315862758b99f174cbeb16e55601daa53cc 100644
--- a/Rl.py
+++ b/Rl.py
@@ -3,24 +3,31 @@ import torch
 import torch.nn.functional as F
 
 ################################################################################
-class ReplayMemory(object):
-  def __init__(self, capacity):
+class ReplayMemory() :
+  def __init__(self, capacity, stateSize) :
     self.capacity = capacity
-    self.memory = []
+    self.states = torch.zeros(capacity, stateSize, dtype=torch.long)
+    self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long)
+    self.actions = torch.zeros(capacity, 1, dtype=torch.long)
+    self.rewards = torch.zeros(capacity, 1)
     self.position = 0
+    self.nbPushed = 0
 
-  def push(self, transition):
-    """Saves a transition."""
-    if len(self.memory) < self.capacity:
-      self.memory.append(None)
-    self.memory[self.position] = transition
+  def push(self, state, action, newState, reward) :
+    self.states[self.position] = state
+    self.actions[self.position] = action
+    self.newStates[self.position] = newState
+    self.rewards[self.position] = reward
     self.position = (self.position + 1) % self.capacity
+    self.nbPushed += 1
 
-  def sample(self, batch_size):
-    return random.sample(self.memory, batch_size)
+  def sample(self, batchSize) :
+    start = random.randint(0, len(self)-batchSize)
+    end = start+batchSize
+    return self.states[start:end], self.actions[start:end], self.newStates[start:end], self.rewards[start:end]
 
   def __len__(self):
-    return len(self.memory)
+    return min(self.nbPushed, self.capacity)
 
 ################################################################################
 ################################################################################
@@ -47,14 +54,10 @@ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
   if len(memory) < batchSize :
     return 0.0
 
-  batch = memory.sample(batchSize)
-  states = torch.stack([b[0] for b in batch])
-  actions = torch.stack([b[1] for b in batch])
-  next_states = torch.stack([b[2] for b in batch])
-  rewards = torch.stack([b[3] for b in batch])
+  states, actions, nextStates, rewards = memory.sample(batchSize)
 
   predictedQ = policy_net(states).gather(1, actions)
-  nextQ = target_net(next_states).max(1)[0].unsqueeze(0)
+  nextQ = target_net(nextStates).max(1)[0].unsqueeze(0)
   nextQ = torch.transpose(nextQ, 0, 1)
 
   expectedReward = gamma*nextQ + rewards
diff --git a/Train.py b/Train.py
index b61ecd39aafd4133f6ff3c013a55991af36f3e3f..838a71aac6309fd6d2e377370f8e3df7eaaefe96 100644
--- a/Train.py
+++ b/Train.py
@@ -128,20 +128,15 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
 
 ################################################################################
 def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, silent=False) :
-  memory = ReplayMemory(1000)
+  memory = None
   dicts = Dicts()
   dicts.readConllu(filename, ["FORM", "UPOS"])
   dicts.save(modelDir + "/dicts.json")
 
-  policy_net = Networks.BaseNet(dicts, 13, len(transitionSet))
-  target_net = Networks.BaseNet(dicts, 13, len(transitionSet))
-  target_net.load_state_dict(policy_net.state_dict())
-  target_net.eval()
-  policy_net.train()
+  policy_net = None
+  target_net = None
+  optimizer = None
 
-  print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
-
-  optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
 
   bestLoss = None
   bestScore = None
@@ -155,6 +150,16 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
     sentence = sentences[sentIndex]
     sentence.moveWordIndex(0)
     state = Features.extractFeaturesPosExtended(dicts, sentence)
+
+    if policy_net is None :
+      policy_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet))
+      target_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet))
+      target_net.load_state_dict(policy_net.state_dict())
+      target_net.eval()
+      policy_net.train()
+      optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
+      print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
+
     while True :
       missingLinks = getMissingLinks(sentence)
       if debug :
@@ -169,7 +174,9 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
       applyTransition(transitionSet, strategy, sentence, action.name)
      newState = Features.extractFeaturesPosExtended(dicts, sentence)
 
-      memory.push((state, torch.LongTensor([transitionSet.index(action)]), newState, reward))
+      if memory is None :
+        memory = ReplayMemory(1000, state.numel())
+      memory.push(state, torch.LongTensor([transitionSet.index(action)]), newState, reward)
       state = newState
       if i % batchSize == 0 :
         totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
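
Usage sketch (illustrative, not part of the patch): the snippet below exercises the new tensor-backed ReplayMemory in isolation, assuming Rl.py is importable. The stateSize of 13 and the random transitions are placeholders; in Train.py the state vectors come from Features.extractFeaturesPosExtended and the buffer is sized lazily from state.numel().

import torch

from Rl import ReplayMemory

stateSize = 13                                    # placeholder feature-vector length
memory = ReplayMemory(1000, stateSize)

for step in range(32) :
  state = torch.randint(0, 100, (stateSize,))     # fake feature indices
  newState = torch.randint(0, 100, (stateSize,))
  action = torch.LongTensor([step % 4])           # fake transition index
  reward = 1.0                                    # fake reward
  memory.push(state, action, newState, reward)

# sample() returns a contiguous slice of batchSize transitions, so it is only
# valid once at least batchSize transitions have been pushed.
states, actions, newStates, rewards = memory.sample(16)
print(states.shape, actions.shape, newStates.shape, rewards.shape)
# torch.Size([16, 13]) torch.Size([16, 1]) torch.Size([16, 13]) torch.Size([16, 1])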