diff --git a/.gitignore b/.gitignore
index 58c6276ce78646b945bc5243cfdc1c613d6ef6e7..ee244d0958c12bf65a1bc110f40bccdf553aa6f7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 __pycache__
 bin/*
+.idea
diff --git a/Rl.py b/Rl.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfd63b171f7d9b443f7d7341408b1e50451e2652
--- /dev/null
+++ b/Rl.py
@@ -0,0 +1,39 @@
+import random
+import torch
+
+################################################################################
+class ReplayMemory(object):
+
+    def __init__(self, capacity):
+        self.capacity = capacity
+        self.memory = []
+        self.position = 0
+
+    def push(self, transition):
+        """Saves a transition."""
+        if len(self.memory) < self.capacity:
+            self.memory.append(None)
+        self.memory[self.position] = transition
+        self.position = (self.position + 1) % self.capacity
+
+    def sample(self, batch_size):
+        """Samples a random batch of batch_size transitions."""
+        return random.sample(self.memory, batch_size)
+
+    def __len__(self):
+        return len(self.memory)
+
+################################################################################
+
+################################################################################
+
+def selectAction(network, state, ts):
+    """Epsilon-greedy choice of the next transition (exploration rate 0.2)."""
+    sample = random.random()
+    if sample > .2:
+        # Exploitation : take the transition the network scores highest
+        with torch.no_grad():
+            return ts[torch.argmax(torch.nn.functional.softmax(network(state), dim=1)).item()].name
+    else:
+        # Exploration : take a uniformly random transition
+        return ts[random.randrange(len(ts))].name
+
+
+################################################################################
\ No newline at end of file
diff --git a/Train.py b/Train.py
index 6d192121d5bac7ba72b1526935a3371c32b9c78a..4b30067a523620c7f67d0d66f3f5bdf4158aeba4 100644
--- a/Train.py
+++ b/Train.py
@@ -6,6 +6,7 @@ from Transition import Transition, getMissingLinks, applyTransition
 import Features
 from Dicts import Dicts
 from Util import timeStamp
+from Rl import ReplayMemory, selectAction
 import Networks
 import Decode
 import Config
@@ -23,6 +24,10 @@ def trainMode(debug, filename, type, modelDir, nbIter, batchSize, devFile, silen
     trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
     return
 
+  if type == "rl":
+    trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent)
+    return
+
   print("ERROR : unknown type '%s'"%type, file=sys.stderr)
   exit(1)
 ################################################################################
@@ -110,3 +115,38 @@ def trainModelOracle(debug, modelDir, filename, nbIter, batchSize, devFile, tran
     print("%s : Epoch %d, loss=%.2f%s %s"%(timeStamp(), iter, totalLoss, devScore, "SAVED" if saved else ""), file=sys.stderr)
 ################################################################################
 
+def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentences, silent=False) :
+  memory = ReplayMemory(1000)
+  dicts = Dicts()
+  dicts.readConllu(filename, ["FORM", "UPOS"])
+  dicts.save(modelDir + "/dicts.json")
+
+  policy_net = Networks.BaseNet(dicts, 13, len(transitionSet))
+  target_net = Networks.BaseNet(dicts, 13, len(transitionSet))
+  target_net.load_state_dict(policy_net.state_dict())
+  target_net.eval()
+
+  optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
+  lossFct = torch.nn.CrossEntropyLoss()
+  bestLoss = None
+  bestScore = None
+
+  for i_episode in range(nbIter):
+    sentence = sentences[i_episode%len(sentences)]
+    state = Features.extractFeaturesPosExtended(dicts, sentence)
+    notDone = True
+    while notDone:
+      action = selectAction(policy_net, state, transitionSet)
+      print(action, file=sys.stderr)
+      notDone = applyTransition(transitionSet, strategy, sentence, action)
+      # Extract the successor state first (None once the episode is over),
+      # then compute the reward of the transition that was just applied
+      if notDone:
+        newState = Features.extractFeaturesPosExtended(dicts, sentence)
+      else:
+        newState = None
+
+      reward = getReward(state, newState)
+      reward = torch.tensor([reward])
+
+      memory.push((state, action, newState, reward))
+      state = newState
+      optimizeModel()
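+################################################################################
+
+# NOTE : getReward() and optimizeModel() are used by the loop above but are not
+# defined in this patch, and optimizeModel() is called there without arguments.
+# The two helpers below are only a hedged, illustrative sketch of a standard
+# one-step DQN update, written against the tuples pushed into ReplayMemory
+# ((state, actionName, newState, reward)) and assuming each extracted state is
+# a (1, n) tensor. The signatures, reward values and hyperparameters are
+# assumptions, not part of the original code; a real reward would more likely
+# be based on getMissingLinks, which is already imported above.
+
+def getReward(state, newState):
+  # Placeholder shaping (assumption) : constant step cost, larger penalty when
+  # the episode ends, so that shorter derivations obtain a higher return
+  return -1.0 if newState is not None else -5.0
+################################################################################
+
+def optimizeModel(memory, policyNet, targetNet, optimizer, transitionSet, batchSize, gamma=0.99):
+  # Sample a batch of transitions once the replay memory is large enough
+  if len(memory) < batchSize:
+    return
+  batch = memory.sample(batchSize)
+  names = [t.name for t in transitionSet]
+
+  states = torch.cat([s for s, _, _, _ in batch])
+  actions = torch.tensor([[names.index(a)] for _, a, _, _ in batch])
+  rewards = torch.cat([r for _, _, _, r in batch])
+  nonFinalMask = torch.tensor([ns is not None for _, _, ns, _ in batch])
+  nonFinalNext = [ns for _, _, ns, _ in batch if ns is not None]
+
+  # Q(s, a) of the actions that were actually taken
+  qValues = policyNet(states).gather(1, actions).squeeze(1)
+
+  # max_a' Q_target(s', a') from the target network, 0 for terminal states
+  nextValues = torch.zeros(len(batch))
+  if len(nonFinalNext) > 0:
+    with torch.no_grad():
+      nextValues[nonFinalMask] = targetNet(torch.cat(nonFinalNext)).max(1)[0]
+
+  # One-step TD target and Huber loss on the policy network
+  expected = rewards + gamma * nextValues
+  loss = torch.nn.functional.smooth_l1_loss(qValues, expected)
+
+  optimizer.zero_grad()
+  loss.backward()
+  optimizer.step()
+################################################################################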