diff --git a/Decode.py b/Decode.py
index af8fc687cad2236c42cdf36a7f724732e694e0ec..300572a1a799e50bd5a525d413f1e6a5c3f0a585 100644
--- a/Decode.py
+++ b/Decode.py
@@ -3,6 +3,7 @@ import sys
 from Transition import Transition, getMissingLinks, applyTransition
 from Features import extractFeatures
 from Dicts import Dicts
+from Util import getDevice
 import Config
 import torch
 
@@ -48,9 +49,14 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
   config.moveWordIndex(0)
   moved = True
   network.eval()
+
+  currentDevice = network.currentDevice()
+  decodeDevice = getDevice()
+  network.to(decodeDevice)
+
   with torch.no_grad():
     while moved :
-      features = extractFeatures(dicts, config).unsqueeze(0)
+      features = extractFeatures(dicts, config).unsqueeze(0).to(decodeDevice)
       output = torch.nn.functional.softmax(network(features), dim=1)
       candidates = sorted([[ts[index].appliable(config), "%.2f"%float(output[0][index]), ts[index].name] for index in range(len(ts))])[::-1]
       candidates = [cand[2] for cand in candidates if cand[0]]
@@ -63,6 +69,8 @@ def decodeModel(ts, strat, config, network, dicts, debug) :
       moved = applyTransition(ts, strat, config, candidate)
 
   EOS.apply(config)
+
+  network.to(currentDevice)
 ################################################################################
 
 ################################################################################
diff --git a/Networks.py b/Networks.py
index 2ab865bf7b1dd6f9909a54f8734be0f150f8926c..d1beadec23625290e556e5d3329378d868b60dd3 100644
--- a/Networks.py
+++ b/Networks.py
@@ -1,3 +1,4 @@
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -5,18 +6,31 @@ import torch.nn.functional as F
 class BaseNet(nn.Module):
   def __init__(self, dicts, inputSize, outputSize) :
     super().__init__()
+    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
+
     self.embSize = 64
     self.inputSize = inputSize
     self.outputSize = outputSize
-    self.embeddings = {name : nn.Embedding(len(dicts.dicts[name]), self.embSize) for name in dicts.dicts.keys()}
+    for name in dicts.dicts :
+      self.add_module("emb_"+name, nn.Embedding(len(dicts.dicts[name]), self.embSize))
     self.fc1 = nn.Linear(inputSize * self.embSize, 1600)
     self.fc2 = nn.Linear(1600, outputSize)
     self.dropout = nn.Dropout(0.3)
 
+    self.apply(self.initWeights)
+
   def forward(self, x) :
-    x = self.dropout(self.embeddings["UPOS"](x).view(x.size(0), -1))
+    x = self.dropout(getattr(self, "emb_"+"UPOS")(x).view(x.size(0), -1))
     x = F.relu(self.dropout(self.fc1(x)))
     x = self.fc2(x)
     return x
+
+  def currentDevice(self) :
+    return self.dummyParam.device
+
+  def initWeights(self,m) :
+    if type(m) == nn.Linear:
+      torch.nn.init.xavier_uniform_(m.weight)
+      m.bias.data.fill_(0.01)
 
 ################################################################################
diff --git a/Rl.py b/Rl.py
index 8382e315862758b99f174cbeb16e55601daa53cc..59bdf4d8855d3ac1e8f425093730663278c0fd36 100644
--- a/Rl.py
+++ b/Rl.py
@@ -1,15 +1,17 @@
+import sys
 import random
 import torch
 import torch.nn.functional as F
+from Util import getDevice
 
 ################################################################################
 class ReplayMemory() :
   def __init__(self, capacity, stateSize) :
     self.capacity = capacity
-    self.states = torch.zeros(capacity, stateSize, dtype=torch.long)
-    self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long)
-    self.actions = torch.zeros(capacity, 1, dtype=torch.long)
-    self.rewards = torch.zeros(capacity, 1)
+    self.states = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
+    self.newStates = torch.zeros(capacity, stateSize, dtype=torch.long, device=getDevice())
+    self.actions = torch.zeros(capacity, 1, dtype=torch.long, device=getDevice())
+    self.rewards = torch.zeros(capacity, 1, device=getDevice())
     self.position = 0
     self.nbPushed = 0
 
@@ -45,7 +47,6 @@ def selectAction(network, state, ts, config, missingLinks, probaRandom, probaOra
     candidates = sorted([[ts[index].appliable(config), "%.2f"%float(output[0][index]), ts[index]] for index in range(len(ts))])[::-1]
     candidates = [cand[2] for cand in candidates if cand[0]]
     return candidates[0] if len(candidates) > 0 else None
-
 ################################################################################
 
 ################################################################################
@@ -53,11 +54,11 @@ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
   gamma = 0.999
   if len(memory) < batchSize :
     return 0.0
-
+
   states, actions, nextStates, rewards = memory.sample(batchSize)
 
   predictedQ = policy_net(states).gather(1, actions)
-  nextQ = target_net(nextStates).max(1)[0].unsqueeze(0)
+  nextQ = target_net(nextStates).max(1)[0].detach().unsqueeze(0)
   nextQ = torch.transpose(nextQ, 0, 1)
   expectedReward = gamma*nextQ + rewards
 
@@ -65,8 +66,11 @@ def optimizeModel(batchSize, policy_net, target_net, memory, optimizer) :
   loss = F.smooth_l1_loss(predictedQ, expectedReward)
   optimizer.zero_grad()
   loss.backward()
-
+  for param in policy_net.parameters() :
+    if param.grad is not None :
+      param.grad.data.clamp_(-1, 1)
   optimizer.step()
+
   return float(loss)
 
 ################################################################################
diff --git a/Train.py b/Train.py
index 838a71aac6309fd6d2e377370f8e3df7eaaefe96..6a28f41564d2b9eb214d18d7e41d97b041a44466 100644
--- a/Train.py
+++ b/Train.py
@@ -6,7 +6,7 @@ import copy
 from Transition import Transition, getMissingLinks, applyTransition
 import Features
 from Dicts import Dicts
-from Util import timeStamp, prettyInt, numParameters
+from Util import timeStamp, prettyInt, numParameters, getDevice
 from Rl import ReplayMemory, selectAction, optimizeModel
 import Networks
 import Decode
@@ -93,7 +93,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
   print("%s : Extracted %s examples"%(timeStamp(), prettyInt(len(examples), 3)), file=sys.stderr)
   examples = torch.stack(examples)
 
-  network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet))
+  network = Networks.BaseNet(dicts, examples[0].size(0)-1, len(transitionSet)).to(getDevice())
   print("%s : Model has %s parameters"%(timeStamp(), prettyInt((numParameters(network)), 3)), file=sys.stderr)
   optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
   lossFct = torch.nn.CrossEntropyLoss()
@@ -107,7 +107,7 @@ def trainModelOracle(debug, modelDir, filename, nbEpochs, batchSize, devFile, tr
     printInterval = 2000
     advancement = 0
     for batchIndex in range(0,examples.size(0)-batchSize,batchSize) :
-      batch = examples[batchIndex:batchIndex+batchSize]
+      batch = examples[batchIndex:batchIndex+batchSize].to(getDevice())
       targets = batch[:,:1].view(-1)
       inputs = batch[:,1:]
       nbEx += targets.size(0)
@@ -149,11 +149,11 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
       print("Curent epoch %6.2f%%"%(100.0*sentIndex/len(sentences)), end="\r", file=sys.stderr)
       sentence = sentences[sentIndex]
       sentence.moveWordIndex(0)
-      state = Features.extractFeaturesPosExtended(dicts, sentence)
+      state = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())
       if policy_net is None :
-        policy_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet))
-        target_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet))
+        policy_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet)).to(getDevice())
+        target_net = Networks.BaseNet(dicts, state.numel(), len(transitionSet)).to(getDevice())
         target_net.load_state_dict(policy_net.state_dict())
         target_net.eval()
         policy_net.train()
 
@@ -169,14 +169,14 @@ def trainModelRl(debug, modelDir, filename, nbIter, batchSize, devFile, transiti
           break
 
         reward = -1.0*action.getOracleScore(sentence, missingLinks)
-        reward = torch.FloatTensor([reward])
+        reward = torch.FloatTensor([reward]).to(getDevice())
 
         applyTransition(transitionSet, strategy, sentence, action.name)
-        newState = Features.extractFeaturesPosExtended(dicts, sentence)
+        newState = Features.extractFeaturesPosExtended(dicts, sentence).to(getDevice())
 
         if memory is None :
           memory = ReplayMemory(1000, state.numel())
-        memory.push(state, torch.LongTensor([transitionSet.index(action)]), newState, reward)
+        memory.push(state, torch.LongTensor([transitionSet.index(action)]).to(getDevice()), newState, reward)
         state = newState
         if i % batchSize == 0 :
          totalLoss += optimizeModel(batchSize, policy_net, target_net, memory, optimizer)
diff --git a/Util.py b/Util.py
index 362bb6d00257192db14e78202b165d96aaa40286..137867f9cd8a431ca92be3d02f9fc2f457d753ad 100644
--- a/Util.py
+++ b/Util.py
@@ -1,5 +1,19 @@
 from datetime import datetime
 
+favoriteDevice = None
+
+################################################################################
+def setDevice(device) :
+  global favoriteDevice
+  favoriteDevice = device
+################################################################################
+
+################################################################################
+def getDevice() :
+  global favoriteDevice
+  return favoriteDevice
+################################################################################
+
 ################################################################################
 def timeStamp() :
   return "[%s]"%datetime.now().strftime("%H:%M:%S")
diff --git a/main.py b/main.py
index ad8098b9e2c2e6b75cc35b3017a3a000ebdbdb76..43758e05b0e320c5bd9ee125e4489bbbe35bb71a 100755
--- a/main.py
+++ b/main.py
@@ -6,6 +6,7 @@ import argparse
 import random
 import torch
 
+import Util
 import Train
 import Decode
 
@@ -35,6 +36,9 @@ if __name__ == "__main__" :
   args = parser.parse_args()
 
   os.makedirs(args.model, exist_ok=True)
+
+  Util.setDevice(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+  print("Using device : %s"%Util.getDevice())
 
   random.seed(args.seed)
   torch.manual_seed(args.seed)
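
# Illustrative sketch, not part of the patch above: the zero-element dummyParam
# added to BaseNet is a common PyTorch idiom for letting a module report which
# device its weights currently live on. The empty parameter is registered like
# any other, so nn.Module.to() moves it together with the real weights and its
# .device attribute always reflects the module's placement. The TinyNet class
# below is hypothetical and exists only to demonstrate the idiom.
import torch
import torch.nn as nn

class TinyNet(nn.Module) :
  def __init__(self) :
    super().__init__()
    # Carries no weights and no gradient, but always sits on the module's device.
    self.dummyParam = nn.Parameter(torch.empty(0), requires_grad=False)
    self.fc = nn.Linear(4, 2)

  def currentDevice(self) :
    return self.dummyParam.device

net = TinyNet()
print(net.currentDevice())        # cpu
if torch.cuda.is_available() :
  net.to("cuda")
  print(net.currentDevice())      # cuda:0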