diff --git a/Rl.py b/Rl.py
index c3f199ed519f7a08ab62702e7351f63725b92e81..82272834b67a533537329e5bd9f34ea2a0e91ce1 100644
--- a/Rl.py
+++ b/Rl.py
@@ -1,8 +1,10 @@
+import os
 import sys
 import random
 import torch
 import torch.nn.functional as F
 import numpy as np
+import json
 from Util import getDevice
 
 ################################################################################
@@ -34,6 +36,31 @@ class ReplayMemory() :
     end = start+batchSize
     return self.states[start:end], self.actions[start:end], self.newStates[start:end], self.noNewStates[start:end], self.rewards[start:end]
 
+  def save(self, baseDir) :
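+    # Persist the replay buffer tensors and counters so a later run can resume.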
+    baseName = "memory_%s_%s"%(self.fromState, self.toState)
+    torch.save(self.states, "%s/%s_states.pt"%(baseDir, baseName))
+    torch.save(self.newStates, "%s/%s_newStates.pt"%(baseDir, baseName))
+    torch.save(self.actions, "%s/%s_actions.pt"%(baseDir, baseName))
+    torch.save(self.rewards, "%s/%s_rewards.pt"%(baseDir, baseName))
+    torch.save(self.noNewStates, "%s/%s_noNewStates.pt"%(baseDir, baseName))
+    with open("%s/%s.json"%(baseDir, baseName), "w") as f :
+      json.dump([self.capacity, self.position, self.nbPushed], f)
+
+  def load(self, baseDir) :
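+    # Restore a previously saved replay buffer; keep the fresh empty one when
+    # no checkpoint exists for this (fromState, toState) pair.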
+    baseName = "memory_%s_%s"%(self.fromState, self.toState)
+    if not os.path.isfile("%s/%s.json"%(baseDir, baseName)) :
+      return
+    self.states = torch.load("%s/%s_states.pt"%(baseDir, baseName), map_location=getDevice())
+    self.newStates = torch.load("%s/%s_newStates.pt"%(baseDir, baseName), map_location=getDevice())
+    self.actions = torch.load("%s/%s_actions.pt"%(baseDir, baseName), map_location=getDevice())
+    self.rewards = torch.load("%s/%s_rewards.pt"%(baseDir, baseName), map_location=getDevice())
+    self.noNewStates = torch.load("%s/%s_noNewStates.pt"%(baseDir, baseName), map_location=getDevice())
+    with open("%s/%s.json"%(baseDir, baseName), "r") as f :
+      self.capacity, self.position, self.nbPushed = json.load(f)
+
   def __len__(self):
     return min(self.nbPushed, self.capacity)
 ################################################################################
diff --git a/Train.py b/Train.py
index fd242bb4271ec457d7465c29a48daafd1732e58f..666c925e834940cb460ff054048b90f9532a32e8 100644
--- a/Train.py
+++ b/Train.py
@@ -1,8 +1,10 @@
+import os
 import sys
 import random
 import torch
 import copy
 import math
+import json
 
 from Transition import Transition, getMissingLinks, applyTransition
 import Features
@@ -189,21 +191,32 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
 
   memory = None
   dicts = Dicts()
-  dicts.readConllu(filename, ["FORM","UPOS","LETTER","LEXICON"], 2, pretrained)
   transitionNames = {}
   for ts in transitionSets :
     for t in ts :
       transitionNames[str(t)] = (len(transitionNames), 0)
   transitionNames[dicts.nullToken] = (len(transitionNames), 0)
-  dicts.addDict("HISTORY", transitionNames)
-  dicts.save(modelDir + "/dicts.json")
-
-  policy_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained).to(getDevice())
-  target_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained).to(getDevice())
+  if os.path.isfile(modelDir+"/dicts.json") :
+    dicts.load(modelDir+"/dicts.json")
+  else :
+    dicts.readConllu(filename, ["FORM","UPOS","LETTER","LEXICON"], 2, pretrained)
+    dicts.addDict("HISTORY", transitionNames)
+    dicts.save(modelDir + "/dicts.json")
+
+  if os.path.isfile(modelDir+"/lastNetwork.pt") :
+    policy_net = torch.load(modelDir+"/lastNetwork.pt", map_location=getDevice())
+    target_net = copy.deepcopy(policy_net)
+  else :
+    policy_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained).to(getDevice())
+    target_net = Networks.createNetwork(networkName, dicts, [len(transitionSet) for transitionSet in transitionSets], incremental, pretrained).to(getDevice())
   target_net.load_state_dict(policy_net.state_dict())
   target_net.eval()
   policy_net.train()
   optimizer = torch.optim.Adam(policy_net.parameters(), lr=lr)
+  if os.path.isfile(modelDir+"/optimizer.pt") :
+    optimizer.load_state_dict(torch.load(modelDir+"/optimizer.pt"))
   print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(policy_net)), 3)), file=sys.stderr)
 
   bestLoss = None
@@ -213,7 +226,14 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
   nbExByEpoch = sum(map(len,sentences))
   sentIndex = 0
 
-  for epoch in range(1,nbIter+1) :
+  startingEpoch = 1
+  if os.path.isfile(modelDir+"/epoch.json") :
+    # Resume from the epoch after the last completed one, restoring the best
+    # loss and score reached so far.
+    with open(modelDir+"/epoch.json", "r") as f :
+      lastEpoch, bestLoss, bestScore = json.load(f)
+    startingEpoch = lastEpoch+1
+  for epoch in range(startingEpoch,nbIter+1) :
     i = 0
     totalLoss = 0.0
     while True :
@@ -242,7 +262,6 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
         probaRandom = list_probas[fromState][0]
         probaOracle = list_probas[fromState][1]
         
-
         if debug :
           sentence.printForDebug(sys.stderr)
         action = selectAction(policy_net, state, transitionSet, sentence, missingLinks, probaRandom, probaOracle, fromState)
@@ -268,6 +287,10 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
 
         if memory is None :
           memory = [[ReplayMemory(5000, state.numel(), f, t) for t in range(len(transitionSets))] for f in range(len(transitionSets))]
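+          # Load any replay memories persisted by a previous (interrupted) run.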
+          for fr in memory :
+            for mem in fr :
+              mem.load(modelDir)
         memory[fromState][toState].push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
         state = newState
         if i % batchSize == 0 :
@@ -284,5 +307,13 @@ def trainModelRl(debug, networkName, modelDir, filename, nbIter, batchSize, devF
         break
       sentIndex += 1
     bestLoss, bestScore = evalModelAndSave(debug, policy_net, transitionSets, strategy, dicts, modelDir, devFile, bestLoss, totalLoss, bestScore, epoch, nbIter, incremental, rewardFunc, predicted)
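+    # Checkpoint everything needed to resume after an interruption.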
+    torch.save(optimizer.state_dict(), modelDir+"/optimizer.pt")
+    torch.save(policy_net, modelDir+"/lastNetwork.pt")
+    for fr in memory :
+      for mem in fr :
+        mem.save(modelDir)
+    with open(modelDir+"/epoch.json", "w") as f :
+      json.dump([epoch,bestLoss,bestScore], f)
 ################################################################################