Commit 18082daf authored by Franck Dary

Improved Replay Memory

parent bc508cdb
@@ -3,24 +3,31 @@ import torch
 import torch.nn.functional as F
 ################################################################################
-class ReplayMemory(object):
-  def __init__(self, capacity):
+class ReplayMemory() :
+  def __init__(self, capacity, stateSize) :
     self.capacity = capacity
-    self.memory = []
+    self.states = torch.zeros(capacity, stateSize, dtype=torch.long)
+    self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long)
+    self.actions = torch.zeros(capacity, 1, dtype=torch.long)
+    self.rewards = torch.zeros(capacity, 1)
     self.position = 0
+    self.nbPushed = 0
 
-  def push(self, transition):
-    """Saves a transition."""
-    if len(self.memory) < self.capacity:
-      self.memory.append(None)
-    self.memory[self.position] = transition
+  def push(self, state, action, newState, reward) :
+    self.states[self.position] = state
+    self.actions[self.position] = action
+    self.newStates[self.position] = newState
+    self.rewards[self.position] = reward
     self.position = (self.position + 1) % self.capacity
+    self.nbPushed += 1
 
-  def sample(self, batch_size):
-    return random.sample(self.memory, batch_size)
+  def sample(self, batchSize) :
+    start = random.randint(0, len(self)-batchSize)
+    end = start+batchSize
+    return self.states[start:end], self.actions[start:end], self.newStates[start:end], self.rewards[start:end]
 
   def __len__(self):
-    return len(self.memory)
+    return min(self.nbPushed, self.capacity)
 ################################################################################
 ################################################################################
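The new ReplayMemory replaces a Python list of transition tuples with preallocated tensors that are written in a circular fashion and sampled as a contiguous slice. A minimal, self-contained sketch of that ring-buffer pattern (illustrative names and sizes, not the repository's code):

import random
import torch

capacity, stateSize, batchSize = 8, 4, 3
states = torch.zeros(capacity, stateSize, dtype=torch.long)   # preallocated once
rewards = torch.zeros(capacity, 1)
position, nbPushed = 0, 0

for step in range(20) :                          # push more transitions than capacity
  states[position] = torch.randint(0, 10, (stateSize,))
  rewards[position] = float(step)
  position = (position + 1) % capacity           # wrap around: the oldest slot is overwritten
  nbPushed += 1

filled = min(nbPushed, capacity)                 # same logic as the new __len__
start = random.randint(0, filled - batchSize)    # contiguous window, as in the new sample()
print(states[start:start+batchSize].shape, rewards[start:start+batchSize].shape)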
@@ -47,14 +54,10 @@ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
   if len(memory) < batchSize :
     return 0.0
 
-  batch = memory.sample(batchSize)
-  states = torch.stack([b[0] for b in batch])
-  actions = torch.stack([b[1] for b in batch])
-  next_states = torch.stack([b[2] for b in batch])
-  rewards = torch.stack([b[3] for b in batch])
+  states, actions, nextStates, rewards = memory.sample(batchSize)
 
   predictedQ = policy_net(states).gather(1, actions)
-  nextQ = target_net(next_states).max(1)[0].unsqueeze(0)
+  nextQ = target_net(nextStates).max(1)[0].unsqueeze(0)
   nextQ = torch.transpose(nextQ, 0, 1)
 
   expectedReward = gamma*nextQ + rewards
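After the change, optimizeModel unpacks the four batch tensors returned by sample() directly. A small sketch of the shape bookkeeping in the lines above, with random tensors standing in for policy_net and target_net: gather picks the Q-value of the action actually taken, max(1)[0] gives the greedy next-state value, and the unsqueeze/transpose pair turns it back into a (batchSize, 1) column so it can be added to the rewards.

import torch

batchSize, nbActions, gamma = 4, 5, 0.99
policyOut = torch.randn(batchSize, nbActions)    # stands in for policy_net(states)
targetOut = torch.randn(batchSize, nbActions)    # stands in for target_net(nextStates)
actions = torch.randint(0, nbActions, (batchSize, 1))
rewards = torch.randn(batchSize, 1)

predictedQ = policyOut.gather(1, actions)        # (batchSize, 1): Q of the chosen action
nextQ = targetOut.max(1)[0].unsqueeze(0)         # (1, batchSize): best next-state value
nextQ = torch.transpose(nextQ, 0, 1)             # back to (batchSize, 1)
expectedReward = gamma*nextQ + rewards           # Bellman target
print(predictedQ.shape, expectedReward.shape)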
@@ -128,20 +128,15 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
 ################################################################################
 def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, silent=False) :
-  memory = ReplayMemory(1000)
+  memory = None
   dicts = Dicts()
   dicts.readConllu(filename, ["FORM", "UPOS"])
   dicts.save(modelDir + "/dicts.json")
-  policy_net = Networks.BaseNet(dicts, 13, len(transitionSet))
-  target_net = Networks.BaseNet(dicts, 13, len(transitionSet))
-  target_net.load_state_dict(policy_net.state_dict())
-  target_net.eval()
-  policy_net.train()
-  print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
-  optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
+  policy_net = None
+  target_net = None
+  optimizer = None
 
   bestLoss = None
   bestScore = None
@@ -155,6 +150,16 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, silent=False) :
     sentence = sentences[sentIndex]
     sentence.moveWordIndex(0)
     state = Features.extractFeaturesPosExtended(dicts, sentence)
+
+    if policy_net is None :
+      policy_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet))
+      target_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet))
+      target_net.load_state_dict(policy_net.state_dict())
+      target_net.eval()
+      policy_net.train()
+      optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
+      print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
+
     while True :
       missingLinks = getMissingLinks(sentence)
       if debug :
@@ -169,7 +174,9 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, silent=False) :
       applyTransition(transitionSet, strategy, sentence, action.name)
       newState = Features.extractFeaturesPosExtended(dicts, sentence)
-      memory.push((state, torch.LongTensor([transitionSet.index(action)]), newState, reward))
+      if memory is None :
+        memory = ReplayMemory(1000, state.numel())
+      memory.push(state, torch.LongTensor([transitionSet.index(action)]), newState, reward)
       state = newState
 
       if i % batchSize == 0 :
         totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
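The last two hunks drop the hard-coded feature size (13) in favour of lazy initialization: the networks, the optimizer and the replay memory are created only once the first feature vector is available, so their sizes come from state.numel(). A toy sketch of that pattern, with a stand-in module instead of Networks.BaseNet:

import torch

class TinyNet(torch.nn.Module) :                 # stand-in for Networks.BaseNet
  def __init__(self, inputSize, nbOutputs) :
    super().__init__()
    self.lin = torch.nn.Linear(inputSize, nbOutputs)
  def forward(self, x) :
    return self.lin(x.float())

policy_net, optimizer = None, None
for _ in range(3) :
  state = torch.randint(0, 10, (13,))            # stands in for extractFeaturesPosExtended
  if policy_net is None :                        # input size is known only now
    policy_net = TinyNet(state.numel(), 4)
    optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
  q = policy_net(state.unsqueeze(0))             # (1, nbActions)
print(q.shape)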