diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp
index 439cbcc6d32abfeeff17fd090f15bf34a3b3f438..35f1942979d4a0797ec1ed57d7d96abd8784ccdb 100644
--- a/torch_modules/src/ConfigDataset.cpp
+++ b/torch_modules/src/ConfigDataset.cpp
@@ -1,4 +1,5 @@
 #include "ConfigDataset.hpp"
+#include "NeuralNetwork.hpp"
 
 ConfigDataset::ConfigDataset(const std::vector<torch::Tensor> & contexts, const std::vector<torch::Tensor> & classes)
 {
@@ -24,6 +25,6 @@ torch::optional<size_t> ConfigDataset::size() const
 
 torch::data::Example<> ConfigDataset::get(size_t index)
 {
-  return {data.narrow(0, index*(contextSize+1), contextSize), data.narrow(0, index*(contextSize+1)+contextSize, 1)};
+  return {data.narrow(0, index*(contextSize+1), contextSize).to(NeuralNetworkImpl::device), data.narrow(0, index*(contextSize+1)+contextSize, 1).to(NeuralNetworkImpl::device)};
 }
 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 6ebf18e8ab57e80ec084000bf23b5b3c5bfa9638..ccbad4e1fb0e88e60ccb9bc1fabef5f50b77b0aa 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -54,14 +54,14 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch:
     {
       context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
       for (auto & element : context)
-        contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device));
+        contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone());
     } catch(std::exception & e)
     {
       util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
     }
 
     int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
-    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
+    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
     gold[0] = goldIndex;
 
     for (auto & element : context)