diff --git a/Rl.py b/Rl.py
index 4767c30b89865d2dcc87506daa169f70b6f54173..e091df525f4e5805aae75bd9d437c5fff0b7bbe0 100644
--- a/Rl.py
+++ b/Rl.py
@@ -7,7 +7,9 @@ from Util import getDevice
 
 ################################################################################
 class ReplayMemory() :
-  def __init__(self, capacity, stateSize, nbStates) :
+  def __init__(self, capacity, stateSize, fromState, toState) :
+    self.fromState = fromState
+    self.toState = toState
     self.capacity = capacity
     self.states = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
     self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
@@ -37,7 +39,7 @@ class ReplayMemory() :
 ################################################################################
 
 ################################################################################
-def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle) :
+def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle, fromState) :
   sample = random.random()
   if sample < probaRandom :
     return ts[random.randrange(len(ts))]
@@ -46,6 +48,8 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra
     return candidates[0][1] if len(candidates) > 0 else None
   else :
     with torch.no_grad() :
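+      # Configure the network for the current state (fromState) before scoring the transitions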
+      network.setState(fromState)
       output = network(torch.stack([state]))
       predIndex = int(torch.argmax(output))
       return ts[predIndex]
@@ -53,27 +57,36 @@
 
 ################################################################################
 def optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma) :
-  if len(memory) < batchSize :
-    return 0.0
-
-  states, actions, nextStates, noNextStates, rewards = memory.sample(batchSize)
-
-  predictedQ = policy_net(states).gather(1, actions)
-  nextQ = target_net(nextStates).max(1)[0].detach().unsqueeze(0)
-  nextQ = torch.transpose(nextQ, 0, 1)
-  nextQ[noNextStates] = 0.0
-
-  expectedReward = gamma*nextQ + rewards
-
-  loss = F.smooth_l1_loss(predictedQ, expectedReward)
-  optimizer.zero_grad()
-  loss.backward()
-  for param in policy_net.parameters() :
-    if param.grad is not None :
-      param.grad.data.clamp_(-1, 1)
-  optimizer.step()
-
-  return float(loss)
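+  # One replay buffer per (fromState, toState) pair: optimize each buffer that already holds a full batch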
+  totalLoss = 0.0
+  for fromState in range(len(memory)) :
+    for toState in range(len(memory[fromState])) :
+      if len(memory[fromState][toState]) < batchSize :
+        continue
+
+      states, actions, nextStates, noNextStates, rewards = memory[fromState][toState].sample(batchSize)
+
+      policy_net.setState(fromState)
+      target_net.setState(toState)
+
+      predictedQ = policy_net(states).gather(1, actions)
+      nextQ = target_net(nextStates).max(1)[0].detach().unsqueeze(0)
+      nextQ = torch.transpose(nextQ, 0, 1)
+      nextQ[noNextStates] = 0.0
+
+      expectedReward = gamma*nextQ + rewards
+
+      loss = F.smooth_l1_loss(predictedQ, expectedReward)
+      optimizer.zero_grad()
+      loss.backward()
+      for param in policy_net.parameters() :
+        if param.grad is not None :
+          param.grad.data.clamp_(-1, 1)
+      optimizer.step()
+
+      totalLoss += float(loss)
+
+  return totalLoss
 ################################################################################
 
 ################################################################################
diff --git a/Train.py b/Train.py
index de709df3a709fecb7917723722fac1a50fc48466..c6b57bf4ab85bfebebf16b33a5e1a29ae1eab2ec 100644
--- a/Train.py
+++ b/Train.py
@@ -235,9 +235,13 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
       while True :
         missingLinks = getMissingLinks(sentence)
         transitionSet = transitionSets[sentence.state]
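+        # Record sentence.state before the transition; toState is updated below once the action is applied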
+        fromState = sentence.state
+        toState = sentence.state
+
         if debug :
           sentence.printForDebug(sys.stderr)
-        action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle)
+        action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle, fromState)
 
         if action is None :
           break
@@ -253,14 +257,15 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
         newState = None
         if appliable :
           applyTransition(strategy, sentence, action, reward_)
+          toState = sentence.state
           newState = policy_net.extractFeatures(dicts, sentence).to(getDevice())
-
         else:
           count+=1
 
         if memory is None :
-          memory = ReplayMemory(5000, state.numel())
-        memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
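+          # Lazily create one ReplayMemory per (fromState, toState) pair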
+          memory = [[ReplayMemory(5000, state.numel(), f, t) for t in range(len(transitionSets))] for f in range(len(transitionSets))]
+        memory[fromState][toState].push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
         state = newState
         if i % batchSize == 0 :
           totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer, gamma)