Commit a4487b91 authored by Franck Dary's avatar Franck Dary
Browse files

Sending torch object to CPU before saving them to disk

parent 51f41f87
......@@ -229,12 +229,16 @@ std::string Classifier::getLastFilename() const
void Classifier::saveBest()
{
getNN()->to(torch::kCPU);
torch::save(getNN(), getBestFilename());
getNN()->to(NeuralNetworkImpl::device);
}
void Classifier::saveLast()
{
getNN()->to(torch::kCPU);
torch::save(getNN(), getLastFilename());
getNN()->to(NeuralNetworkImpl::device);
saveOptimizer();
}
......@@ -296,7 +296,7 @@ void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem:
int nbClasses = classes[0].size(0);
auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1).to(torch::kCPU);
auto filename = fmt::format("{}-{}_{}-{}.{}.{}.tensor", state, nbClasses, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle);
torch::save(tensorToSave, dir/filename);
lastSavedIndex = currentExampleIndex;
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment