diff --git a/Decode.py b/Decode.py
index 1f48a2986fc1faa76ded503801283ba6b486f9c3..c2e8df021b9c65536923cd5bc38fcf45d440ed10 100644
--- a/Decode.py
+++ b/Decode.py
@@ -57,15 +57,15 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
   with torch.no_grad():
     while moved :
       features = extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
-      output = torch.nn.functional.softmax(network(features), dim=1)
-      scores = sorted([["%.2f"%float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
+      output = network(features)
+      scores = sorted([[float(output[0][index]), ts[index].appliable(config), ts[index].name] for index in range(len(ts))])[::-1]
       candidates = [[cand[0],cand[2]] for cand in scores if cand[1]]
       if len(candidates) == 0 :
         break
       candidate = candidates[0][1]
       if debug :
         config.printForDebug(sys.stderr)
-        print(" ".join(["%s%s:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+("-"*80)+"\n", file=sys.stderr)
+        print(" ".join(["%s%.2f:%s"%("*" if score[1] else " ", score[0], score[2]) for score in scores])+"\n"+("-"*80)+"\n", file=sys.stderr)
       moved = applyTransition(ts, strat, config, candidate)
 
   EOS.apply(config)
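Note on the Decode.py change: dropping the softmax cannot change which transition is picked, because softmax is strictly increasing in each score, so sorting the raw outputs as floats selects the same top candidate as sorting the probabilities; only the debug printout changes (raw scores, now formatted at print time). A minimal sketch with made-up scores, independent of the project's network, illustrating the equivalence:

# Minimal sketch (made-up scores): ranking raw outputs and ranking their
# softmax gives the same order, since softmax is strictly increasing.
import torch

raw = torch.tensor([[1.3, -0.2, 4.1, 0.7]])
probs = torch.nn.functional.softmax(raw, dim=1)
assert torch.equal(raw.argsort(dim=1, descending=True),
                   probs.argsort(dim=1, descending=True))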
diff --git a/Rl.py b/Rl.py
index 59bdf4d8855d3ac1e8f425093730663278c0fd36..d1f63d61c3bcde60c48326cbae66c61f0f297c5b 100644
--- a/Rl.py
+++ b/Rl.py
@@ -12,13 +12,16 @@ class ReplayMemory() :
     self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
     self.actions = torch.zeros(capacity, 1, dtype=torch.long, device=getDevice())
     self.rewards = torch.zeros(capacity, 1, device=getDevice())
+    self.noNewStates = torch.zeros(capacity, dtype=torch.bool, device=getDevice())
     self.position = 0
     self.nbPushed = 0
 
   def push(self, state, action, newState, reward) :
     self.states[self.position] = state
     self.actions[self.position] = action
-    self.newStates[self.position] = newState
+    if newState is not None :
+      self.newStates[self.position] = newState
+    self.noNewStates[self.position] = newState is None
     self.rewards[self.position] = reward 
     self.position = (self.position + 1) % self.capacity
     self.nbPushed += 1
@@ -26,7 +29,7 @@ class ReplayMemory() :
   def sample(self, batchSize) :
     start = random.randint(0, len(self)-batchSize)
     end = start+batchSize
-    return self.states[start:end], self.actions[start:end], self.newStates[start:end], self.rewards[start:end]
+    return self.states[start:end], self.actions[start:end], self.newStates[start:end], self.noNewStates[start:end], self.rewards[start:end]
 
   def __len__(self):
     return min(self.nbPushed, self.capacity)
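The replay memory now remembers which stored transitions had no successor state (terminal transitions) and hands that boolean mask back alongside each sampled batch. A stripped-down sketch of the bookkeeping, with illustrative shapes standing in for the project's feature tensors:

# Stripped-down sketch (illustrative shapes, not the project's features): a
# dummy row is kept when newState is None and a boolean mask flags it.
import torch

capacity, stateSize = 4, 3
newStates = torch.zeros(capacity, stateSize, dtype=torch.long)
noNewStates = torch.zeros(capacity, dtype=torch.bool)

def push(position, newState):
    if newState is not None:
        newStates[position] = newState
    noNewStates[position] = newState is None

push(0, torch.tensor([5, 7, 9]))
push(1, None)                  # terminal transition: zero row stays, mask is set
print(noNewStates[:2])         # tensor([False,  True])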
@@ -36,30 +39,29 @@ class ReplayMemory() :
 def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOracle) :
   sample = random.random()
   if sample < probaRandom :
-    candidates = [trans for trans in ts if trans.appliable(config)]
-    return candidates[random.randrange(len(candidates))] if len(candidates) > 0 else None
+    return ts[random.randrange(len(ts))]
   elif sample < probaRandom+probaOracle :
     candidates = sorted([[trans.getOracleScore(config, missingLinks), trans] for trans in ts if trans.appliable(config)])
     return candidates[0][1] if len(candidates) > 0 else None
   else :
     with torch.no_grad() :
       output = network(torch.stack([state]))
-      candidates = sorted([[ts[index].appliable(config), "%.2f"%float(output[0][index]), ts[index]] for index in range(len(ts))])[::-1]
-      candidates = [cand[2] for cand in candidates if cand[0]]
-      return candidates[0] if len(candidates) > 0 else None
+      predIndex = int(torch.argmax(output))
+      return ts[predIndex]
 ################################################################################
 
 ################################################################################
 def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
-  gamma = 0.999
+  gamma = 0.9
   if len(memory) < batchSize :
     return 0.0
 
-  states, actions, nextStates, rewards = memory.sample(batchSize)
+  states, actions, nextStates, noNextStates, rewards = memory.sample(batchSize)
 
   predictedQ = policy_net(states).gather(1, actions)
   nextQ = target_net(nextStates).max(1)[0].detach().unsqueeze(0)
   nextQ = torch.transpose(nextQ, 0, 1)
+  nextQ[noNextStates] = 0.0
 
   expectedReward = gamma*nextQ + rewards
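With the mask available, optimizeModel applies the standard DQN target: y = r when the transition was terminal (no successor state), and y = r + gamma * max_a' Q_target(s', a') otherwise. The greedy branch of selectAction now simply takes the argmax over all transitions, legal or not, which pairs with the illegal-action penalty introduced in Train.py below. A self-contained sketch of the masked target with toy linear networks and random data; only the masking step mirrors the code above, and the Huber loss at the end is a stand-in for whatever criterion the project uses:

# Toy sketch of the DQN target with terminal masking (random data, illustrative
# networks); only the masking step mirrors optimizeModel above.
import torch

batchSize, stateSize, nbActions, gamma = 8, 5, 4, 0.9
policy_net = torch.nn.Linear(stateSize, nbActions)
target_net = torch.nn.Linear(stateSize, nbActions)

states = torch.randn(batchSize, stateSize)
nextStates = torch.randn(batchSize, stateSize)
actions = torch.randint(nbActions, (batchSize, 1))
rewards = torch.randn(batchSize, 1)
noNextStates = torch.rand(batchSize) < 0.25            # some terminal transitions

predictedQ = policy_net(states).gather(1, actions)     # Q(s, a) of the taken actions
with torch.no_grad():
    nextQ = target_net(nextStates).max(1)[0].unsqueeze(1)
    nextQ[noNextStates] = 0.0                          # no bootstrapping past a terminal state
expectedReward = gamma * nextQ + rewards               # y = r + gamma * max_a' Q_target(s', a')

loss = torch.nn.functional.smooth_l1_loss(predictedQ, expectedReward)  # stand-in criterion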
 
diff --git a/Train.py b/Train.py
index 6a28f41564d2b9eb214d18d7e41d97b041a44466..b586dbcdacad70bff17f20d4fe7d7ed2b34d9799 100644
--- a/Train.py
+++ b/Train.py
@@ -140,13 +140,21 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
   bestLoss = None
   bestScore = None
 
+  sentences = copy.deepcopy(sentencesOriginal)
+  nbExByEpoch = sum(map(len,sentences))
+  sentIndex = 0
+
   for epoch in range(1,nbIter+1) :
     i = 0
     totalLoss = 0.0
-    sentences = copy.deepcopy(sentencesOriginal)
-    for sentIndex in range(len(sentences)) :
+    while True :
+      if sentIndex >= len(sentences) :
+        sentences = copy.deepcopy(sentencesOriginal)
+        random.shuffle(sentences)
+        sentIndex = 0
+
       if not silent :
-        print("Curent epoch %6.2f%%"%(100.0*sentIndex/len(sentences)), end="\r", file=sys.stderr)
+        print("Curent epoch %6.2f%%"%(100.0*i/nbExByEpoch), end="\r", file=sys.stderr)
       sentence = sentences[sentIndex]
       sentence.moveWordIndex(0)
       state = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())
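Train.py now measures an epoch in processed examples rather than in one pass over the corpus: the working copy of the sentences is reshuffled and restarted whenever it runs out, and the epoch ends once nbExByEpoch examples have been handled (the check sits further down in the loop). An illustrative sketch of that pattern with a toy corpus; the counter here advances per sentence rather than per transition, purely to keep the example short:

# Illustrative sketch: cycle over a reshuffled copy of the corpus and end an
# "epoch" after a fixed budget of examples instead of after one full pass.
import copy
import random

sentencesOriginal = [["w%d" % j for j in range(n)] for n in (3, 5, 2)]   # toy corpus
nbExByEpoch = sum(map(len, sentencesOriginal))

def cycleSentences(corpus):
    while True:
        sentences = copy.deepcopy(corpus)
        random.shuffle(sentences)
        for sentence in sentences:
            yield sentence

stream = cycleSentences(sentencesOriginal)
i = 0
while i < nbExByEpoch:         # one epoch = nbExByEpoch examples
    sentence = next(stream)
    i += len(sentence)         # Train.py increments per transition instead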
@@ -168,14 +176,22 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
         if action is None :
           break
 
-        reward = -1.0*action.getOracleScore(sentence, missingLinks)
+        appliable = action.appliable(sentence)
+
+        # Fixed penalty (negative reward) when the chosen action is not appliable
+        reward = -3.0
+        if appliable :
+          reward = -1.0*action.getOracleScore(sentence, missingLinks)
+
         reward = torch.FloatTensor([reward]).to(getDevice())
 
-        applyTransition(transitionSet, strategy, sentence, action.name)
-        newState = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())
+        newState = None
+        if appliable :
+          applyTransition(transitionSet, strategy, sentence, action.name)
+          newState = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())
 
         if memory is None :
-          memory = ReplayMemory(1000, state.numel())
+          memory = ReplayMemory(5000, state.numel())
         memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
         state = newState
         if i % batchSize == 0 :
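Illegal actions are no longer filtered out before selection; they receive a fixed penalty and the transition is stored with newState set to None, i.e. as terminal, so the episode stops there and the Q target does not bootstrap past it. A small sketch of the reward rule in isolation; stepReward is a hypothetical helper, while the -3.0 constant and the negated oracle cost come from the hunk above:

# Sketch of the reward rule above; stepReward is a hypothetical helper, its
# arguments stand in for action.appliable(...) and action.getOracleScore(...).
def stepReward(appliable, oracleCost):
    if not appliable:
        return -3.0              # fixed penalty, transition treated as terminal
    return -1.0 * oracleCost     # legal action: fewer broken gold links is better

assert stepReward(False, 0) == -3.0
assert stepReward(True, 2) == -2.0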
@@ -185,6 +201,12 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
             target_net.eval()
             policy_net.train()
         i += 1
+
+        if state is None :
+          break
+      if i >= nbExByEpoch :
+        break
+      sentIndex += 1
     bestLoss, bestScore = evalModelAndSave(debug, policy_net, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter)
 ################################################################################
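For context around the last hunk: after each optimization step the surrounding code puts the target network back into eval mode and the policy network into train mode, which fits the usual periodic hard update of a DQN target network. A minimal sketch of that pattern, assuming the elided code copies the policy weights into the target; the sync interval K and both networks are illustrative, since the copy itself sits outside the visible hunk:

# Minimal sketch of a periodic hard update of the target network (illustrative
# networks and interval; assumes the elided Train.py code copies the weights).
import torch

policy_net = torch.nn.Linear(10, 4)
target_net = torch.nn.Linear(10, 4)
K = 64                                                    # illustrative sync interval

for i in range(1, 257):
    # ... one transition is processed and the policy network is optimized here ...
    if i % K == 0:
        target_net.load_state_dict(policy_net.state_dict())  # hard update
        target_net.eval()         # target network is only used for inference
        policy_net.train()        # keep the policy network in training mode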