From a4487b916588ab7a630197d4df519b05bc8035ee Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 7 Jul 2020 11:47:55 +0200 Subject: [PATCH] Sending torch object to CPU before saving them to disk --- reading_machine/src/Classifier.cpp | 4 ++++ trainer/src/Trainer.cpp | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index c685fa2..704b3eb 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -229,12 +229,16 @@ std::string Classifier::getLastFilename() const void Classifier::saveBest() { + getNN()->to(torch::kCPU); torch::save(getNN(), getBestFilename()); + getNN()->to(NeuralNetworkImpl::device); } void Classifier::saveLast() { + getNN()->to(torch::kCPU); torch::save(getNN(), getLastFilename()); + getNN()->to(NeuralNetworkImpl::device); saveOptimizer(); } diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index f161778..32932eb 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -296,7 +296,7 @@ void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem: int nbClasses = classes[0].size(0); - auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1); + auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1).to(torch::kCPU); auto filename = fmt::format("{}-{}_{}-{}.{}.{}.tensor", state, nbClasses, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle); torch::save(tensorToSave, dir/filename); lastSavedIndex = currentExampleIndex; -- GitLab