Commit f336968b authored by Franck Dary

States working for RL

parent 67435b57
@@ -7,7 +7,9 @@ from Util import getDevice
 ################################################################################
 class ReplayMemory() :
-  def __init__(self, capacity, stateSize, nbStates) :
+  def __init__(self, capacity, stateSize, fromState, toState) :
+    self.fromState = fromState
+    self.toState = toState
     self.capacity = capacity
     self.states = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
     self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
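Note: push(), sample() and the nbPushed counter that the later hunks rely on sit outside this hunk. The sketch below shows how such a buffer is assumed to behave; the method bodies and the noNextStates mask are illustrative assumptions, not the repository's actual code.

# Hedged sketch of the rest of ReplayMemory, as the hunks below assume it works.
import random
import torch

class ReplayMemorySketch :
  def __init__(self, capacity, stateSize, fromState, toState) :
    self.fromState = fromState  # parser state the stored transitions start from
    self.toState = toState      # parser state they end up in
    self.capacity = capacity
    self.states = torch.zeros(capacity, stateSize, dtype=torch.long)
    self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long)
    self.noNextStates = torch.zeros(capacity, dtype=torch.bool)
    self.actions = torch.zeros(capacity, 1, dtype=torch.long)
    self.rewards = torch.zeros(capacity, 1)
    self.position = 0
    self.nbPushed = 0  # total transitions ever pushed, checked against batchSize

  def push(self, state, action, newState, reward) :
    self.states[self.position] = state
    self.actions[self.position] = action
    self.noNextStates[self.position] = newState is None
    if newState is not None :
      self.newStates[self.position] = newState
    self.rewards[self.position] = reward
    self.position = (self.position + 1) % self.capacity
    self.nbPushed += 1

  def sample(self, batchSize) :
    indices = random.sample(range(min(self.nbPushed, self.capacity)), batchSize)
    return (self.states[indices], self.actions[indices], self.newStates[indices],
            self.noNextStates[indices], self.rewards[indices])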
@@ -37,7 +39,7 @@ class ReplayMemory() :
 ################################################################################
 
 ################################################################################
-def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle) :
+def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle, fromState) :
   sample = random.random()
   if sample < probaRandom :
     return ts[random.randrange(len(ts))]
@@ -46,6 +48,7 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra
     return candidates[0][1] if len(candidates) > 0 else None
   else :
     with torch.no_grad() :
+      network.setState(fromState)
       output = network(torch.stack([state]))
       predIndex = int(torch.argmax(output))
       return ts[predIndex]
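network.setState(fromState) is called here before the forward pass, and optimizeModel() below calls setState on both the policy and the target nets. The network class itself is not part of this commit; the sketch below shows one plausible reading, where setState() simply selects which per-state output layer forward() uses. Class name, sizes and layer layout are illustrative assumptions, not the repository's code.

# Hedged sketch of a multi-head network whose head is chosen by setState().
import torch
import torch.nn as nn

class MultiStateNetSketch(nn.Module) :
  def __init__(self, vocabSize, stateSize, hiddenSize, outputSizes) :
    super().__init__()
    self.emb = nn.Embedding(vocabSize, 64)
    self.hidden = nn.Linear(stateSize * 64, hiddenSize)
    # One output layer per parser state, since each state has its own transition set.
    self.outputs = nn.ModuleList([nn.Linear(hiddenSize, n) for n in outputSizes])
    self.state = 0

  def setState(self, state) :
    # Remember which parser state we are scoring transitions for.
    self.state = state

  def forward(self, x) :
    h = torch.relu(self.hidden(self.emb(x).view(x.size(0), -1)))
    return self.outputs[self.state](h)  # scores over that state's transition set

With one head per state, the argmax taken in selectAction() always ranges over exactly the transition set that is legal in the current parser state.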
@@ -53,10 +56,16 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra
 ################################################################################
 def optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma) :
-  if len(memory) < batchSize :
-    return 0.0
+  totalLoss = 0.0
+  for fromState in range(len(memory)) :
+    for toState in range(len(memory[fromState])) :
+      if memory[fromState][toState].nbPushed < batchSize :
+        continue
 
-  states, actions, nextStates, noNextStates, rewards = memory.sample(batchSize)
+      states, actions, nextStates, noNextStates, rewards = memory[fromState][toState].sample(batchSize)
+      policy_net.setState(fromState)
+      target_net.setState(toState)
 
       predictedQ = policy_net(states).gather(1, actions)
       nextQ = target_net(nextStates).max(1)[0].detach().unsqueeze(0)
@@ -73,7 +82,9 @@ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma) :
          param.grad.data.clamp_(-1, 1)
       optimizer.step()
-  return float(loss)
+      totalLoss += float(loss)
+
+  return totalLoss
 ################################################################################
 
 ################################################################################
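The lines between nextQ and the gradient clamp (the Q-learning target and the loss) are collapsed in this view. For readers less familiar with DQN, the toy snippet below illustrates the role of gamma, nextQ and a terminal mask in that kind of update; names, shapes and values are made up for the example and are not taken from the elided code.

# Toy DQN target computation (generic illustration, not the commit's elided lines).
import torch
import torch.nn.functional as F

batchSize, gamma = 4, 0.99
predictedQ = torch.randn(batchSize, 1)                  # Q(s, a) from the policy net
nextQ = torch.randn(batchSize, 1)                       # max_a' Q(s', a') from the target net
rewards = torch.randn(batchSize, 1)
noNextStates = torch.tensor([[0.], [0.], [1.], [0.]])   # 1.0 where the episode ended

# Bootstrap term is zeroed on terminal transitions; the loss pulls Q(s, a)
# towards r + gamma * max_a' Q(s', a').
expectedQ = rewards + gamma * nextQ * (1 - noNextStates)
loss = F.smooth_l1_loss(predictedQ, expectedQ)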
...
@@ -235,9 +235,12 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
     while True :
       missingLinks = getMissingLinks(sentence)
       transitionSet = transitionSets[sentence.state]
+      fromState = sentence.state
+      toState = sentence.state
       if debug :
         sentence.printForDebug(sys.stderr)
-      action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle)
+      action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle, fromState)
       if action is None :
         break
@@ -253,14 +256,14 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
       newState = None
       if appliable :
         applyTransition(strategy, sentence, action, reward_)
+        toState = sentence.state
         newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
       else:
         count+=1
 
       if memory is None :
-        memory = ReplayMemory(5000, state.numel())
-      memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
+        memory = [[ReplayMemory(5000, state.numel(), f, t) for t in range(len(transitionSets))] for f in range(len(transitionSets))]
+      memory[fromState][toState].push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
 
       state = newState
       if i % batchSize == 0 :
         totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma)
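Taken together with optimizeModel() above, the replay buffer is now a grid keyed by the parser state before and after each transition. A minimal sketch of that wiring is below; nbStates and stateSize are placeholder values, not numbers from the commit (in the code they come from len(transitionSets) and state.numel()).

# Hedged sketch of the new memory layout and routing, with placeholder sizes.
nbStates, stateSize = 3, 80
memory = [[ReplayMemory(5000, stateSize, f, t) for t in range(nbStates)]
          for f in range(nbStates)]
# memory[fromState][toState] only ever receives transitions observed while moving
# from parser state fromState to parser state toState, so each update in
# optimizeModel() trains the fromState policy head against the toState target head
# on a homogeneous batch. When an action is not appliable, toState stays equal to
# fromState (and newState stays None), so the experience still lands in a valid cell.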
...