diff --git a/Rl.py b/Rl.py
index cfd63b171f7d9b443f7d7341408b1e50451e2652..fec4da33f276a0bcf82a9f3219903c6616a089fe 100644
--- a/Rl.py
+++ b/Rl.py
@@ -1,39 +1,69 @@
 import random
 import torch
+import torch.nn.functional as F
 
 ################################################################################
 class ReplayMemory(object):
+  def __init__(self, capacity):
+    self.capacity = capacity
+    self.memory = []
+    self.position = 0
 
-    def __init__(self, capacity):
-        self.capacity = capacity
-        self.memory = []
-        self.position = 0
+  def push(self, transition):
+    """Saves a transition."""
+    if len(self.memory) < self.capacity:
+      self.memory.append(None)
+    self.memory[self.position] = transition
+    self.position = (self.position + 1) % self.capacity
 
-    def push(self, transition):
-        """Saves a transition."""
-        if len(self.memory) < self.capacity:
-            self.memory.append(None)
-        self.memory[self.position] = transition
-        self.position = (self.position + 1) % self.capacity
+  def sample(self, batch_size):
+    return random.sample(self.memory, batch_size)
 
-    def sample(self, batch_size):
-        return random.sample(self.memory, batch_size)
+  def __len__(self):
+    return len(self.memory)
+################################################################################
 
-    def __len__(self):
-        return len(self.memory)
+################################################################################
+def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle) :
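+  """Picks the next transition for the current configuration: with probability
+  probaRandom a random applicable transition, with probability probaOracle the
+  applicable transition with the lowest oracle cost, and otherwise the applicable
+  transition with the highest Q-value predicted by the network."""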
+  sample = random.random()
+  if sample < probaRandom :
+    candidates = [trans for trans in ts if trans.appliable(config)]
+    return candidates[random.randrange(len(candidates))] if len(candidates) > 0 else None
+  elif sample < probaRandom+probaOracle :
+    candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)])
+    return candidates[0][1] if len(candidates) > 0 else None
+  else :
+    with torch.no_grad() :
+      output = network(torch.stack([state]))
+      candidates = sorted([[ts[index].appliable(config), float(output[0][index]), ts[index]] for index in range(len(ts))])[::-1]
+      candidates = [cand[2] for cand in candidates if cand[0]]
+      return candidates[0] if len(candidates) > 0 else None
 
 ################################################################################
 
 ################################################################################
+def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
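+  """Performs one DQN optimization step: samples a batch of transitions from the
+  replay memory, computes Q(s, a) with the policy network and the bootstrapped
+  target reward + gamma * max Q_target(s', a'), then minimizes the smooth L1 loss."""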
+  gamma = 0.999
+  if len(memory) < batchSize :
+    return 0.0
+  
+  batch = memory.sample(batchSize)
+  states = torch.stack([b[0] for b in batch])
+  actions = torch.stack([b[1] for b in batch])
+  next_states = torch.stack([b[2] for b in batch])
+  rewards = torch.stack([b[3] for b in batch])
 
-def selectAction(network, state, ts):
-    sample = random.random()
-    if sample > .2:
-        with torch.no_grad():
-            return ts[max(torch.nn.functional.softmax(network(state), dim=1))].name
+  predictedQ = policy_net(states).gather(1, actions)
+  nextQ = target_net(next_states).max(1)[0].unsqueeze(0).detach()
+  nextQ = torch.transpose(nextQ, 0, 1)
 
-    else:
-        return ts[random.randrange(len(ts))].name
+  expectedReward = gamma*nextQ + rewards
 
+  loss = F.smooth_l1_loss(predictedQ, expectedReward)
+  optimizer.zero_grad()
+  loss.backward()
+
+  optimizer.step()
+  return float(loss)
+################################################################################
 
-################################################################################
\ No newline at end of file
diff --git a/Train.py b/Train.py
index 4b30067a523620c7f67d0d66f3f5bdf4158aeba4..689dca8e5bf44d0afd473cf3942213f92156c947 100644
--- a/Train.py
+++ b/Train.py
@@ -1,12 +1,13 @@
 import sys
 import random
 import torch
+import copy
 
 from Transition import Transition, getMissingLinks, applyTransition
 import Features
 from Dicts import Dicts
 from Util import timeStamp
-from Rl import ReplayMemory, selectAction
+from Rl import ReplayMemory, selectAction, optimizeModel
 import Networks
 import Decode
 import Config
@@ -115,7 +116,9 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran
     print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), iter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
 ################################################################################
 
-def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
+################################################################################
+def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, silent=False) :
+
   memory = ReplayMemory(1000)
   dicts = Dicts()
   dicts.readConllu(filename, ["FORM", "UPOS"])
@@ -125,28 +128,61 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
   target_net = Networks.BaseNet(dicts, 13, len(transitionSet))
   target_net.load_state_dict(policy_net.state_dict())
   target_net.eval()
+  policy_net.train()
 
   optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
-  lossFct = torch.nn.CrossEntropyLoss()
   bestLoss = None
   bestScore = None
 
-  for i_episode in range(nbIter):
-    sentence = sentences[i_episode%len(sentences)]
-    state = Features.extractFeaturesPosExtended(dicts, sentence)
-    notDone = True
-    while notDone:
-      action = selectAction(policy_net, state, transitionSet)
-      print(action, file=sys.stderr)
-      notDone = applyTransition(transitionSet, strategy, sentence, action)
-      reward = getReward(state, newState)
-      reward = torch.tensor([reward])
-
-      if notDone:
+  for epoch in range(nbIter) :
+    i = 0
+    totalLoss = 0.0
+    sentences = copy.deepcopy(sentencesOriginal)
+    for sentIndex in range(len(sentences)) :
+      if not silent :
+        print("Curent epoch %6.2f%%"%(100.0*sentIndex/len(sentences)), end="\r", file=sys.stderr)
+      sentence = sentences[sentIndex]
+      sentence.moveWordIndex(0)
+      state = Features.extractFeaturesPosExtended(dicts, sentence)
+      while True :
+        missingLinks = getMissingLinks(sentence)
+        if debug :
+          sentence.printForDebug(sys.stderr)
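+        # epsilon-greedy style exploration: 30% random action, 15% oracle action,
+        # otherwise greedy with respect to the policy network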
+        action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom=0.3, probaOracle=0.15)
+        if action is None :
+          break
+
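+        # the reward is the negative of the chosen transition's oracle cost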
+        reward = -1.0*action.getOracleScore(sentence, missingLinks)
+        reward = torch.FloatTensor([reward])
+
+        applyTransition(transitionSet, strategy, sentence, action.name)
         newState = Features.extractFeaturesPosExtended(dicts, sentence)
-      else:
-        newState = None
 
-      memory.push((state, action, newState, reward))
-      state = newState
-      optimizeModel()
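+        # store (state, action index, next state, reward) in the replay memory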
+        memory.push((state, torch.LongTensor([transitionSet.index(action)]), newState, reward))
+        state = newState
+        if i % batchSize == 0 :
+          totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
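+          # periodically refresh the target network from the policy network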
+          if i % (2*batchSize) == 0 :
+            target_net.load_state_dict(policy_net.state_dict())
+            target_net.eval()
+            policy_net.train()
+        i += 1
+    # End of epoch: compute the dev score and save the model if it improved
+    devScore = ""
+    saved = True if bestLoss is None else totalLoss < bestLoss
+    bestLoss = totalLoss if bestLoss is None else min(bestLoss, totalLoss)
+    if devFile is not None :
+      outFilename = modelDir+"/predicted_dev.conllu"
+      Decode.decodeMode(debug, devFile, "model", modelDir, policy_net, dicts, open(outFilename, "w"))
+      res = evaluate(load_conllu(open(devFile, "r")), load_conllu(open(outFilename, "r")), [])
+      UAS = res["UAS"][0].f1
+      score = UAS
+      saved = True if bestScore is None else score > bestScore
+      bestScore = score if bestScore is None else max(bestScore, score)
+      devScore = ", Dev : UAS=%.2f"%(UAS)
+    if saved :
+      torch.save(policy_net, modelDir+"/network.pt")
+    print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), epoch, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
+
+################################################################################
+
diff --git a/Transition.py b/Transition.py
index 618a9af419c5d41c3e815d46f5ec075c66ba0233..af1324aa88940a4aa438de51168f34bc2268e99f 100644
--- a/Transition.py
+++ b/Transition.py
@@ -12,6 +12,9 @@ class Transition :
       exit(1)
     self.name = name
 
+  def __lt__(self, other) :
+    return self.name < other.name
+
   def apply(self, config) :
     if self.name == "RIGHT" :
       applyRight(config)