diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index c685fa200e6185810b524e7698f1baa6113868f4..704b3ebe13c0ca83f7636e19d3f867f87d584996 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -229,12 +229,16 @@ std::string Classifier::getLastFilename() const void Classifier::saveBest() { + getNN()->to(torch::kCPU); torch::save(getNN(), getBestFilename()); + getNN()->to(NeuralNetworkImpl::device); } void Classifier::saveLast() { + getNN()->to(torch::kCPU); torch::save(getNN(), getLastFilename()); + getNN()->to(NeuralNetworkImpl::device); saveOptimizer(); } diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index f161778e464b3f447fa9897ab60899266fc5b68a..32932eb1f6d404387253fdea1a2d6e37c44aaa30 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -296,7 +296,7 @@ void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem: int nbClasses = classes[0].size(0); - auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1); + auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1).to(torch::kCPU); auto filename = fmt::format("{}-{}_{}-{}.{}.{}.tensor", state, nbClasses, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle); torch::save(tensorToSave, dir/filename); lastSavedIndex = currentExampleIndex;