diff --git a/Rl.py b/Rl.py
index fec4da33f276a0bcf82a9f3219903c6616a089fe..8382e315862758b99f174cbeb16e55601daa53cc 100644
--- a/Rl.py
+++ b/Rl.py
@@ -3,24 +3,34 @@ import torch
 import torch.nn.functional as F
 
 ################################################################################
-class ReplayMemory(object):
-  def __init__(self, capacity):
+class ReplayMemory :
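+  """Fixed-size circular buffer of transitions, stored as preallocated tensors."""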
+  def __init__(self, capacity, stateSize) :
     self.capacity = capacity
-    self.memory = []
+    self.states = torch.zeros(capacity, stateSize, dtype=torch.long)
+    self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long)
+    self.actions = torch.zeros(capacity, 1, dtype=torch.long)
+    self.rewards = torch.zeros(capacity, 1)
     self.position = 0
+    self.nbPushed = 0
 
-  def push(self, transition):
-    """Saves a transition."""
-    if len(self.memory) < self.capacity:
-      self.memory.append(None)
-    self.memory[self.position] = transition
+  def push(self, state, action, newState, reward) :
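+    """Saves a transition at the current write position, overwriting the oldest one once the buffer is full."""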
+    self.states[self.position] = state
+    self.actions[self.position] = action
+    self.newStates[self.position] = newState
+    self.rewards[self.position] = reward
     self.position = (self.position + 1) % self.capacity
+    self.nbPushed += 1
 
-  def sample(self, batch_size):
-    return random.sample(self.memory, batch_size)
+  def sample(self, batchSize) :
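+    """Returns batchSize consecutive transitions (states, actions, newStates, rewards) starting at a random offset."""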
+    start = random.randint(0, len(self)-batchSize)
+    end = start+batchSize
+    return self.states[start:end], self.actions[start:end], self.newStates[start:end], self.rewards[start:end]
 
   def __len__(self):
-    return len(self.memory)
+    return min(self.nbPushed, self.capacity)
 ################################################################################
 
 ################################################################################
@@ -47,14 +57,11 @@ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
   if len(memory) < batchSize :
     return 0.0
   
-  batch = memory.sample(batchSize)
-  states = torch.stack([b[0] for b in batch])
-  actions = torch.stack([b[1] for b in batch])
-  next_states = torch.stack([b[2] for b in batch])
-  rewards = torch.stack([b[3] for b in batch])
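+  # Draw a batch of transitions, already stacked as tensors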
+  states, actions, nextStates, rewards = memory.sample(batchSize)
 
   predictedQ = policy_net(states).gather(1, actions)
-  nextQ = target_net(next_states).max(1)[0].unsqueeze(0)
+  nextQ = target_net(nextStates).max(1)[0].unsqueeze(0)
   nextQ = torch.transpose(nextQ, 0, 1)
 
   expectedReward = gamma*nextQ + rewards
diff --git a/Train.py b/Train.py
index b61ecd39aafd4133f6ff3c013a55991af36f3e3f..838a71aac6309fd6d2e377370f8e3df7eaaefe96 100644
--- a/Train.py
+++ b/Train.py
@@ -128,20 +128,16 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
 ################################################################################
 def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transitionSet, strategy, sentencesOriginal, silent=False) :
 
-  memory = ReplayMemory(1000)
+  memory = None
   dicts = Dicts()
   dicts.readConllu(filename, ["FORM", "UPOS"])
   dicts.save(modelDir + "/dicts.json")
 
-  policy_net = Networks.BaseNet(dicts, 13, len(transitionSet))
-  target_net = Networks.BaseNet(dicts, 13, len(transitionSet))
-  target_net.load_state_dict(policy_net.state_dict())
-  target_net.eval()
-  policy_net.train()
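+  # Deferred initialization: the networks need the feature vector size, which is only known after the first state extraction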
+  policy_net = None
+  target_net = None
+  optimizer = None
 
-  print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
-
-  optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
   bestLoss = None
   bestScore = None
 
@@ -155,6 +151,17 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
       sentence = sentences[sentIndex]
       sentence.moveWordIndex(0)
       state = Features.extractFeaturesPosExtended(dicts, sentence)
+
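+      # First state extracted: build the networks and the optimizer from its size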
+      if policy_net is None :
+        policy_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet))
+        target_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet))
+        target_net.load_state_dict(policy_net.state_dict())
+        target_net.eval()
+        policy_net.train()
+        optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0001)
+        print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
+
       while True :
         missingLinks = getMissingLinks(sentence)
         if debug :
@@ -169,7 +176,10 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
         applyTransition(transitionSet, strategy, sentence, action.name)
         newState = Features.extractFeaturesPosExtended(dicts, sentence)
 
-        memory.push((state, torch.LongTensor([transitionSet.index(action)]), newState, reward))
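+        # Allocate the replay memory on first use, sized to the feature vector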
+        if memory is None :
+          memory = ReplayMemory(1000, state.numel())
+        memory.push(state, torch.LongTensor([transitionSet.index(action)]), newState, reward)
         state = newState
         if i % batchSize == 0 :
           totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer)