Skip to content
Snippets Groups Projects
Commit a4487b91 authored by Franck Dary's avatar Franck Dary
Browse files

Sending torch object to CPU before saving them to disk

parent 51f41f87
No related branches found
No related tags found
No related merge requests found
...@@ -229,12 +229,16 @@ std::string Classifier::getLastFilename() const ...@@ -229,12 +229,16 @@ std::string Classifier::getLastFilename() const
void Classifier::saveBest() void Classifier::saveBest()
{ {
getNN()->to(torch::kCPU);
torch::save(getNN(), getBestFilename()); torch::save(getNN(), getBestFilename());
getNN()->to(NeuralNetworkImpl::device);
} }
void Classifier::saveLast() void Classifier::saveLast()
{ {
getNN()->to(torch::kCPU);
torch::save(getNN(), getLastFilename()); torch::save(getNN(), getLastFilename());
getNN()->to(NeuralNetworkImpl::device);
saveOptimizer(); saveOptimizer();
} }
...@@ -296,7 +296,7 @@ void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem: ...@@ -296,7 +296,7 @@ void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem:
int nbClasses = classes[0].size(0); int nbClasses = classes[0].size(0);
auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1); auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1).to(torch::kCPU);
auto filename = fmt::format("{}-{}_{}-{}.{}.{}.tensor", state, nbClasses, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle); auto filename = fmt::format("{}-{}_{}-{}.{}.{}.tensor", state, nbClasses, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle);
torch::save(tensorToSave, dir/filename); torch::save(tensorToSave, dir/filename);
lastSavedIndex = currentExampleIndex; lastSavedIndex = currentExampleIndex;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment