From a4487b916588ab7a630197d4df519b05bc8035ee Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 7 Jul 2020 11:47:55 +0200
Subject: [PATCH] Sending torch object to CPU before saving them to disk

---
 reading_machine/src/Classifier.cpp | 4 ++++
 trainer/src/Trainer.cpp            | 2 +-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index c685fa2..704b3eb 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -229,12 +229,16 @@ std::string Classifier::getLastFilename() const
 
 void Classifier::saveBest()
 {
+  getNN()->to(torch::kCPU);
   torch::save(getNN(), getBestFilename());
+  getNN()->to(NeuralNetworkImpl::device);
 }
 
 void Classifier::saveLast()
 {
+  getNN()->to(torch::kCPU);
   torch::save(getNN(), getLastFilename());
+  getNN()->to(NeuralNetworkImpl::device);
   saveOptimizer();
 }
 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index f161778..32932eb 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -296,7 +296,7 @@ void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem:
 
   int nbClasses = classes[0].size(0);
 
-  auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
+  auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1).to(torch::kCPU);
   auto filename = fmt::format("{}-{}_{}-{}.{}.{}.tensor", state, nbClasses, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle);
   torch::save(tensorToSave, dir/filename);
   lastSavedIndex = currentExampleIndex;
-- 
GitLab