diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp
index 64dda232cccdadb8031433f5c5562f79e6e8db06..22e715f505d66ed9d41c32894d10d3d8395f405f 100644
--- a/decoder/src/MacaonDecode.cpp
+++ b/decoder/src/MacaonDecode.cpp
@@ -90,7 +90,7 @@ int MacaonDecode::main()
   if (modelPaths.empty())
     util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, "")));
 
-  fmt::print(stderr, "Decoding using device : {}\n", NeuralNetworkImpl::device.str());
+  fmt::print(stderr, "Decoding using device : {}\n", NeuralNetworkImpl::getDevice().str());
 
   try
   {
@@ -127,8 +127,8 @@ int MacaonDecode::main()
 
     if (configs.size() > 1)
     {
-      NeuralNetworkImpl::device = torch::kCPU;
-      machine.to(NeuralNetworkImpl::device);
+      NeuralNetworkImpl::setDevice(torch::kCPU);
+      machine.to(NeuralNetworkImpl::getDevice());
       std::for_each(std::execution::par, configs.begin(), configs.end(),
         [&decoder, debug, printAdvancement, beamSize, beamThreshold](BaseConfig & config)
         {
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index a2361c8ae69f36dfb6a9d60659a59a90c52530a5..a9de8bbaabc1c3f20a2987e65cfb8094aaf470fc 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -83,20 +83,19 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
 
   if (!train)
   {
-    torch::load(getNN(), getBestFilename(), NeuralNetworkImpl::device);
+    torch::load(getNN(), getBestFilename(), NeuralNetworkImpl::getDevice());
     getNN()->registerEmbeddings();
-    getNN()->to(NeuralNetworkImpl::device);
+    getNN()->to(NeuralNetworkImpl::getDevice());
   }
   else if (std::filesystem::exists(getLastFilename()))
   {
-    torch::load(getNN(), getLastFilename(), NeuralNetworkImpl::device);
-    getNN()->to(NeuralNetworkImpl::device);
+    torch::load(getNN(), getLastFilename(), NeuralNetworkImpl::getDevice());
     resetOptimizer();
     loadOptimizer();
   }
   else
   {
-    getNN()->to(NeuralNetworkImpl::device);
+    getNN()->to(NeuralNetworkImpl::getDevice());
   }
 }
 
@@ -183,7 +182,7 @@ void Classifier::loadOptimizer()
 {
   auto optimizerPath = std::filesystem::path(fmt::format("{}/{}_optimizer.pt", path.string(), name));
   if (std::filesystem::exists(optimizerPath))
-    torch::load(*optimizer, optimizerPath, NeuralNetworkImpl::device);
+    torch::load(*optimizer, optimizerPath, NeuralNetworkImpl::getDevice());
 }
 
 void Classifier::saveOptimizer()
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index d96f2647cbb09650f4106148953f294e4092a65b..6e2319b6663e0611ec3dcc8b480b0f79e2fe5c25 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -8,7 +8,9 @@
 
 class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
 {
-  public :
+  private :
 
   static torch::Device device;
 
+  public :
+
@@ -24,6 +26,8 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
   virtual void removeRareDictElements(float rarityThreshold) = 0;
 
   static torch::Device getPreferredDevice();
+  static torch::Device getDevice();
+  static void setDevice(torch::Device device);
   static float entropy(torch::Tensor probabilities);
 };
 TORCH_MODULE(NeuralNetwork);
diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp
index 91695ad8749e4c23afed51445c2ba3c83441a4ed..e8f40befef18e5b44173ad4aec91d3d3565e3171 100644
--- a/torch_modules/src/ConfigDataset.cpp
+++ b/torch_modules/src/ConfigDataset.cpp
@@ -93,8 +93,8 @@ void ConfigDataset::Holder::reset()
   loadedTensorIndex = 0;
   nextIndexToGive = 0;
   nbGiven = 0;
-  torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
-  loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device)));
+  torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::getDevice());
+  loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::getDevice())));
 }
 
 c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize)
@@ -107,8 +107,8 @@ c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset
     if (loadedTensorIndex >= (int)files.size())
       return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
     nextIndexToGive = 0;
-    torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
-    loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device)));
+    torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::getDevice());
+    loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::getDevice())));
   }
 
   int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive);
diff --git a/torch_modules/src/CustomHingeLoss.cpp b/torch_modules/src/CustomHingeLoss.cpp
index ec44b6c5bdba5aac28d314a11bfb06175f1b1f7b..e2d0b510cae18d9093181d2551b6d4a78f8b0078 100644
--- a/torch_modules/src/CustomHingeLoss.cpp
+++ b/torch_modules/src/CustomHingeLoss.cpp
@@ -3,7 +3,7 @@
 
 torch::Tensor CustomHingeLoss::operator()(torch::Tensor prediction, torch::Tensor gold)
 {
-  torch::Tensor loss = torch::zeros(1).to(NeuralNetworkImpl::device);
+  torch::Tensor loss = torch::zeros(1, NeuralNetworkImpl::getDevice());
 
   for (unsigned int i = 0; i < prediction.size(0); i++)
   {
diff --git a/torch_modules/src/LossFunction.cpp b/torch_modules/src/LossFunction.cpp
index 2f8f1be9dff754f407c5d3e4bdf5fa14e6cf825a..bb5724932bb3750188e5c92409bb50341cb2b7ee 100644
--- a/torch_modules/src/LossFunction.cpp
+++ b/torch_modules/src/LossFunction.cpp
@@ -51,13 +51,13 @@ torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::
 
   if (index == 0 or index == 2 or index == 4)
   {
-    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
+    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
     gold[0] = goldIndexes.at(0);
     return gold;
   }
   if (index == 1 or index == 3)
   {
-    auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
+    auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
     for (auto goldIndex : goldIndexes)
       gold[goldIndex] = 1;
     return gold;
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index 1c39f186db7bdc62e6e30c3771974dacefa8af63..066b79dbec2c0ba8f09048d5160dea9814ef436e 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -99,7 +99,7 @@ torch::Tensor ModularNetworkImpl::forward(torch::Tensor input, const std::string
 
 torch::Tensor ModularNetworkImpl::extractContext(Config & config)
 {
-  torch::Tensor context = torch::zeros({totalInputSize}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
+  torch::Tensor context = torch::zeros({totalInputSize}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
   for (auto & mod : modules)
     mod->addToContext(context, config);
   return context;
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index c85c1602dc028670a88ed1dfacbee3c78e0896a0..fe3727dc18827225b17bbf6db9133b7dd67ce54e 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -18,3 +18,15 @@ torch::Device NeuralNetworkImpl::getPreferredDevice()
   return torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
 }
 
+// Returns the process-wide device on which networks are created and loaded.
+torch::Device NeuralNetworkImpl::getDevice()
+{
+  return device;
+}
+
+// Changes the process-wide device; callers are responsible for moving
+// already-constructed modules (e.g. machine.to(getDevice())) afterwards.
+void NeuralNetworkImpl::setDevice(torch::Device device)
+{
+  NeuralNetworkImpl::device = device;
+}
+
diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp
index b05d1aa00b26677500bc7f8a2acb59ea12a2cbcd..d27ffe93ed41185fb62c922278362b25f9aa84e3 100644
--- a/torch_modules/src/RandomNetwork.cpp
+++ b/torch_modules/src/RandomNetwork.cpp
@@ -10,7 +10,7 @@ torch::Tensor RandomNetworkImpl::forward(torch::Tensor input, const std::string
   if (input.dim() == 1)
     input = input.unsqueeze(0);
 
-  return torch::randn({input.size(0), (long)nbOutputsPerState[state]}, torch::TensorOptions().device(device).requires_grad(true));
+  return torch::randn({input.size(0), (long)nbOutputsPerState[state]}, torch::TensorOptions().device(NeuralNetworkImpl::getDevice()).requires_grad(true));
 }
 
 torch::Tensor RandomNetworkImpl::extractContext(Config &)
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 79d608a25c5fe7ffa51cbe03a6b8dcbf6787706d..940945180494e89df1a69ed49713eec621ffdf2f 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -161,7 +161,7 @@ int MacaonTrain::main()
     std::fclose(file);
   }
 
-  fmt::print(stderr, "[{}] Training using device : {}\n", util::getTime(), NeuralNetworkImpl::device.str());
+  fmt::print(stderr, "[{}] Training using device : {}\n", util::getTime(), NeuralNetworkImpl::getDevice().str());
 
   try
   {
@@ -325,15 +325,15 @@ int MacaonTrain::main()
 
       if (devConfigs.size() > 1)
       {
-        NeuralNetworkImpl::device = torch::kCPU;
-        machine.to(NeuralNetworkImpl::device);
+        NeuralNetworkImpl::setDevice(torch::kCPU);
+        machine.to(NeuralNetworkImpl::getDevice());
         std::for_each(std::execution::par, devConfigs.begin(), devConfigs.end(),
           [&decoder, debug, printAdvancement](BaseConfig & devConfig)
           {
             decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);
           });
-        NeuralNetworkImpl::device = NeuralNetworkImpl::getPreferredDevice();
-        machine.to(NeuralNetworkImpl::device);
+        NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice());
+        machine.to(NeuralNetworkImpl::getDevice());
       }
       else
       {
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 628386c55e4867d57bb7f9787e56f87a5132ca6d..2237a07e099e038023f1ecb66fff45f7a19c83f1 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -50,8 +50,8 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
 
   std::atomic<int> totalNbExamples = 0;
 
-  NeuralNetworkImpl::device = torch::kCPU;
-  machine.to(NeuralNetworkImpl::device);
+  NeuralNetworkImpl::setDevice(torch::kCPU);
+  machine.to(NeuralNetworkImpl::getDevice());
   std::for_each(std::execution::par, configs.begin(), configs.end(),
     [this, maxNbExamplesPerFile, &examplesPerState, &totalNbExamples, debug, dynamicOracle, explorationThreshold, dir, epoch, &examplesMutex](SubConfig & config)
     {
@@ -191,8 +191,8 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
   for (auto & it : examplesPerState)
     it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle);
 
-  NeuralNetworkImpl::device = NeuralNetworkImpl::getPreferredDevice();
-  machine.to(NeuralNetworkImpl::device);
+  NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice());
+  machine.to(NeuralNetworkImpl::getDevice());
 
   std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
   if (!f)