diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index c685fa200e6185810b524e7698f1baa6113868f4..704b3ebe13c0ca83f7636e19d3f867f87d584996 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -229,12 +229,16 @@ std::string Classifier::getLastFilename() const
 
 void Classifier::saveBest()
 {
+  getNN()->to(torch::kCPU);
   torch::save(getNN(), getBestFilename());
+  getNN()->to(NeuralNetworkImpl::device);
 }
 
 void Classifier::saveLast()
 {
+  getNN()->to(torch::kCPU);
   torch::save(getNN(), getLastFilename());
+  getNN()->to(NeuralNetworkImpl::device);
   saveOptimizer();
 }
 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index f161778e464b3f447fa9897ab60899266fc5b68a..32932eb1f6d404387253fdea1a2d6e37c44aaa30 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -296,7 +296,7 @@ void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem:
 
   int nbClasses = classes[0].size(0);
 
-  auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
+  auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1).to(torch::kCPU);
   auto filename = fmt::format("{}-{}_{}-{}.{}.{}.tensor", state, nbClasses, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle);
   torch::save(tensorToSave, dir/filename);
   lastSavedIndex = currentExampleIndex;