Commit 18082daf authored by Franck Dary

Improved Replay Memory

parent bc508cdb
@@ -3,24 +3,31 @@ import torch
 import torch.nn.functional as F
 ################################################################################
-class ReplayMemory(object):
-  def __init__(self, capacity):
+class ReplayMemory() :
+  def __init__(self, capacity, stateSize) :
     self.capacity = capacity
-    self.memory = []
+    self.states = torch.zeros(capacity, stateSize, dtype=torch.long)
+    self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long)
+    self.actions = torch.zeros(capacity, 1, dtype=torch.long)
+    self.rewards = torch.zeros(capacity, 1)
     self.position = 0
+    self.nbPushed = 0
-  def push(self, transition):
-    """Saves a transition."""
-    if len(self.memory) < self.capacity:
-      self.memory.append(None)
-    self.memory[self.position] = transition
+  def push(self, state, action, newState, reward) :
+    self.states[self.position] = state
+    self.actions[self.position] = action
+    self.newStates[self.position] = newState
+    self.rewards[self.position] = reward
     self.position = (self.position + 1) % self.capacity
+    self.nbPushed += 1
-  def sample(self, batch_size):
-    return random.sample(self.memory, batch_size)
+  def sample(self, batchSize) :
+    start = random.randint(0, len(self)-batchSize)
+    end = start+batchSize
+    return self.states[start:end], self.actions[start:end], self.newStates[start:end], self.rewards[start:end]
   def __len__(self):
-    return len(self.memory)
+    return min(self.nbPushed, self.capacity)
 ################################################################################
 ################################################################################
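For context on the new class: it preallocates one tensor per transition field, overwrites slots circularly via `position`, and `sample` now returns slices of those tensors instead of a Python list of tuples. A minimal usage sketch, assuming the class is importable as shown (the import path, sizes, and values are illustrative, not taken from this repository):

import torch
from ReplayMemory import ReplayMemory  # hypothetical module path

mem = ReplayMemory(1000, 13)                        # capacity 1000, 13 features per state
for step in range(2000):
    state = torch.randint(0, 100, (13,))            # fake long-valued feature vector
    newState = torch.randint(0, 100, (13,))
    action = torch.LongTensor([step % 4])           # fake action index
    reward = torch.FloatTensor([1.0])
    mem.push(state, action, newState, reward)       # oldest slot is overwritten once full

states, actions, newStates, rewards = mem.sample(32)  # four aligned tensors of length 32

Note that `sample` draws one contiguous window starting at a random offset, so a batch contains consecutive (hence correlated) transitions, unlike the uniform `random.sample` of the previous list-based version.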
@@ -47,14 +54,10 @@ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
   if len(memory) < batchSize :
     return 0.0
-  batch = memory.sample(batchSize)
-  states = torch.stack([b[0] for b in batch])
-  actions = torch.stack([b[1] for b in batch])
-  next_states = torch.stack([b[2] for b in batch])
-  rewards = torch.stack([b[3] for b in batch])
+  states, actions, nextStates, rewards = memory.sample(batchSize)
   predictedQ = policy_net(states).gather(1, actions)
-  nextQ = target_net(next_states).max(1)[0].unsqueeze(0)
+  nextQ = target_net(nextStates).max(1)[0].unsqueeze(0)
   nextQ = torch.transpose(nextQ, 0, 1)
   expectedReward = gamma*nextQ + rewards
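With the tensor-based memory, the per-field `torch.stack` calls collapse into a single tuple unpacking, and the rest of the function computes the usual one-step DQN target gamma * max_a' Q_target(s', a') + r. A standalone sketch of the shape handling around `gather`, `max`, and `transpose` (dummy tensors and sizes, not the project's networks):

import torch

batchSize, nbActions = 4, 3
gamma = 0.99
qValues = torch.rand(batchSize, nbActions)            # stands in for policy_net(states)
actions = torch.randint(0, nbActions, (batchSize, 1)) # stored action indices
rewards = torch.rand(batchSize, 1)

predictedQ = qValues.gather(1, actions)               # Q(s, a) of the taken actions -> (batchSize, 1)
nextQ = torch.rand(batchSize, nbActions).max(1)[0]    # max_a' Q_target(s', a')       -> (batchSize,)
nextQ = nextQ.unsqueeze(0)                            # -> (1, batchSize)
nextQ = torch.transpose(nextQ, 0, 1)                  # -> (batchSize, 1), aligned with rewards
expectedReward = gamma*nextQ + rewards                # one-step TD target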
@@ -128,20 +128,15 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
 ################################################################################
 def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, silent=False) :
-  memory = ReplayMemory(1000)
+  memory = None
   dicts = Dicts()
   dicts.readConllu(filename, ["FORM", "UPOS"])
   dicts.save(modelDir + "/dicts.json")
-  policy_net = Networks.BaseNet(dicts, 13, len(transitionSet))
-  target_net = Networks.BaseNet(dicts, 13, len(transitionSet))
-  target_net.load_state_dict(policy_net.state_dict())
-  target_net.eval()
-  policy_net.train()
+  policy_net = None
+  target_net = None
+  optimizer = None
-  print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
-  optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
   bestLoss = None
   bestScore = None
@@ -155,6 +150,16 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
     sentence = sentences[sentIndex]
     sentence.moveWordIndex(0)
     state = Features.extractFeaturesPosExtended(dicts, sentence)
+    if policy_net is None :
+      policy_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet))
+      target_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet))
+      target_net.load_state_dict(policy_net.state_dict())
+      target_net.eval()
+      policy_net.train()
+      optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
+      print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
     while True :
       missingLinks = getMissingLinks(sentence)
       if debug :
@@ -169,7 +174,9 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
       applyTransition(transitionSet, strategy, sentence, action.name)
       newState = Features.extractFeaturesPosExtended(dicts, sentence)
-      memory.push((state, torch.LongTensor([transitionSet.index(action)]), newState, reward))
+      if memory is None :
+        memory = ReplayMemory(1000, state.numel())
+      memory.push(state, torch.LongTensor([transitionSet.index(action)]), newState, reward)
       state = newState
       if i % batchSize == 0 :
         totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
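The last three hunks replace the eagerly built networks, optimizer, and memory with None placeholders that are filled in lazily: everything is now sized from `state.numel()` of the first extracted feature vector instead of the hard-coded 13. A condensed sketch of that pattern, with a stand-in network since `Networks.BaseNet` is not reproduced here:

import torch
from ReplayMemory import ReplayMemory  # hypothetical module path, see the first hunk

policy_net, target_net, optimizer, memory = None, None, None, None

def ensureInitialized(state, nbActions):
    # Illustrative helper: build everything on first use, sized from the observed state.
    global policy_net, target_net, optimizer, memory
    if policy_net is not None:
        return
    inputSize = state.numel()
    policy_net = torch.nn.Linear(inputSize, nbActions)   # stand-in for Networks.BaseNet
    target_net = torch.nn.Linear(inputSize, nbActions)
    target_net.load_state_dict(policy_net.state_dict())  # target starts as a copy of the policy
    target_net.eval()
    policy_net.train()
    optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
    memory = ReplayMemory(1000, inputSize)                # replay buffer sized the same way

This defers all construction until the feature extractor has fixed the input width, at the cost of a None check inside the training loop.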