Commit f336968b authored by Franck Dary

States working for RL

parent 67435b57
@@ -7,7 +7,9 @@ from Util import getDevice
 ################################################################################
 class ReplayMemory() :
-  def __init__(self, capacity, stateSize, nbStates) :
+  def __init__(self, capacity, stateSize, fromState, toState) :
+    self.fromState = fromState
+    self.toState = toState
     self.capacity = capacity
     self.states = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
     self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
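The constructor change above is the core of the commit: each replay buffer is now tagged with the pair of parser states it stores transitions for. Below is a minimal sketch of that idea in isolation; only the constructor fields mirror the diff, while the push/sample bodies and the circular-overwrite policy are assumptions, and the real class's next-state bookkeeping (newStates, noNextStates) is omitted for brevity.

import random
import torch

class ReplayMemorySketch :
  """Minimal circular buffer tagged with the (fromState, toState) pair it serves."""
  def __init__(self, capacity, stateSize, fromState, toState) :
    self.fromState = fromState   # parser state the transition was taken from
    self.toState = toState       # parser state the transition led to
    self.capacity = capacity
    self.states = torch.zeros(capacity, stateSize, dtype=torch.long)
    self.actions = torch.zeros(capacity, 1, dtype=torch.long)
    self.rewards = torch.zeros(capacity, 1)
    self.nbPushed = 0            # total insertions; optimizeModel compares this to batchSize

  def push(self, state, action, reward) :
    index = self.nbPushed % self.capacity   # overwrite the oldest slot once full
    self.states[index] = state
    self.actions[index] = action
    self.rewards[index] = reward
    self.nbPushed += 1

  def sample(self, batchSize) :
    # Caller must ensure batchSize <= number of stored items (the diff checks nbPushed).
    size = min(self.nbPushed, self.capacity)
    indices = random.sample(range(size), batchSize)
    return self.states[indices], self.actions[indices], self.rewards[indices]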
@@ -37,7 +39,7 @@ class ReplayMemory() :
 ################################################################################

 ################################################################################
-def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle) :
+def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle, fromState) :
   sample = random.random()
   if sample < probaRandom :
     return ts[random.randrange(len(ts))]
@@ -46,6 +48,7 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra
     return candidates[0][1] if len(candidates) > 0 else None
   else :
     with torch.no_grad() :
+      network.setState(fromState)
       output = network(torch.stack([state]))
       predIndex = int(torch.argmax(output))
       return ts[predIndex]
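selectAction now takes fromState and calls network.setState(fromState) before the greedy forward pass, so the Q-values come from the output that matches the parser's current state. The toy model below illustrates one plausible reading of that setState contract (the real network's internals are not part of this diff): one shared body with one output head per parser state.

import torch
import torch.nn as nn

class MultiStateNetSketch(nn.Module) :
  """One shared body, one output head per parser state (assumed design)."""
  def __init__(self, inputSize, nbActionsPerState) :
    super().__init__()
    self.body = nn.Linear(inputSize, 64)
    self.heads = nn.ModuleList([nn.Linear(64, n) for n in nbActionsPerState])
    self.currentState = 0

  def setState(self, state) :
    self.currentState = state   # select which head the next forward() uses

  def forward(self, x) :
    return self.heads[self.currentState](torch.relu(self.body(x)))

net = MultiStateNetSketch(10, [4, 7])   # two parser states with 4 and 7 transitions
net.setState(1)                         # route predictions through the 7-action head
scores = net(torch.randn(1, 10))        # shape (1, 7)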
@@ -53,10 +56,16 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra
 ################################################################################
 def optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma) :
-  if len(memory) < batchSize :
-    return 0.0
+  totalLoss = 0.0
+  for fromState in range(len(memory)) :
+    for toState in range(len(memory[fromState])) :
+      if memory[fromState][toState].nbPushed < batchSize :
+        continue
-  states, actions, nextStates, noNextStates, rewards = memory.sample(batchSize)
+      states, actions, nextStates, noNextStates, rewards = memory[fromState][toState].sample(batchSize)
+      policy_net.setState(fromState)
+      target_net.setState(toState)
       predictedQ = policy_net(states).gather(1, actions)
       nextQ = target_net(nextStates).max(1)[0].detach().unsqueeze(0)
@@ -73,7 +82,9 @@ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma) :
         param.grad.data.clamp_(-1, 1)
       optimizer.step()
-  return float(loss)
+      totalLoss += float(loss)
+  return totalLoss
 ################################################################################

 ################################################################################
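optimizeModel no longer works on a single buffer: it visits every (fromState, toState) cell, points policy_net at the source state and target_net at the destination state, and sums the per-cell losses into totalLoss. Each cell's update is a standard DQN step; the sketch below isolates one such step. The loss function and the expectedQ line are not visible in this diff, so smooth L1 and the usual Bellman target are assumptions; the gradient clamp and optimizer.step() mirror the context lines above.

import torch
import torch.nn.functional as F

def onePairUpdate(policy_net, target_net, batch, optimizer, gamma) :
  states, actions, nextStates, rewards = batch
  predictedQ = policy_net(states).gather(1, actions)   # Q(s, a) of the taken actions
  nextQ = target_net(nextStates).max(1)[0].detach()    # max_a' Q_target(s', a')
  expectedQ = rewards + gamma * nextQ.unsqueeze(1)     # Bellman target: r + gamma * max
  loss = F.smooth_l1_loss(predictedQ, expectedQ)       # assumed loss function
  optimizer.zero_grad()
  loss.backward()
  for param in policy_net.parameters() :
    if param.grad is not None :                        # heads unused this step have no grad
      param.grad.data.clamp_(-1, 1)                    # gradient clipping, as in the diff
  optimizer.step()
  return float(loss)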
@@ -235,9 +235,12 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
     while True :
       missingLinks = getMissingLinks(sentence)
       transitionSet = transitionSets[sentence.state]
+      fromState = sentence.state
+      toState = sentence.state
       if debug :
         sentence.printForDebug(sys.stderr)
-      action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle)
+      action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle, fromState)
       if action is None :
         break
@@ -253,14 +256,14 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
       newState = None
       if appliable :
         applyTransition(strategy, sentence, action, reward_)
+        toState = sentence.state
         newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
       else:
         count+=1

       if memory is None :
-        memory = ReplayMemory(5000, state.numel())
+        memory = [[ReplayMemory(5000, state.numel(), f, t) for t in range(len(transitionSets))] for f in range(len(transitionSets))]
-      memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
+      memory[fromState][toState].push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
       state = newState
       if i % batchSize == 0 :
         totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma)
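With these changes, the training loop lazily allocates one buffer per ordered pair of parser states: N states give an N x N matrix, and a transition chosen in fromState that leaves the parser in toState lands in memory[fromState][toState]. A small illustration of that layout, reusing ReplayMemorySketch from the first sketch above (sizes are illustrative; the real code uses len(transitionSets) and state.numel()):

import torch

nbStates = 3    # stands in for len(transitionSets)
stateSize = 10  # stands in for state.numel()

# One independent buffer per ordered (fromState, toState) pair, as in the diff.
memory = [[ReplayMemorySketch(5000, stateSize, f, t) for t in range(nbStates)]
          for f in range(nbStates)]

fromState, toState = 0, 2   # recorded before and after applyTransition in the loop
memory[fromState][toState].push(torch.zeros(stateSize, dtype=torch.long),
                                torch.LongTensor([1]), 0.5)
assert memory[fromState][toState].nbPushed == 1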