From 5769657d611f6763d8d4b047700c75b8ed0dd066 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 1 Mar 2021 14:49:45 +0100
Subject: [PATCH] Trying to load gpu tansor onto cpu mem

---
 decoder/src/Beam.cpp                      |  2 +-
 reading_machine/src/Classifier.cpp        | 12 +++---------
 torch_modules/src/NumericColumnModule.cpp |  2 +-
 trainer/src/Trainer.cpp                   |  4 ++--
 4 files changed, 7 insertions(+), 13 deletions(-)

diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp
index 32e7360..1b0aafa 100644
--- a/decoder/src/Beam.cpp
+++ b/decoder/src/Beam.cpp
@@ -48,7 +48,7 @@ void Beam::update(ReadingMachine & machine, bool debug)
     elements[index].config.setAppliableTransitions(appliableTransitions);
 
     auto context = classifier.getNN()->extractContext(elements[index].config).back();
-    auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
+    auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
 
     auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0);
     float entropy = classifier.isRegression() ? 0.0 : NeuralNetworkImpl::entropy(prediction);
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index aa11dc9..b5e929b 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -81,17 +81,15 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
   getNN()->loadDicts(path);
   getNN()->registerEmbeddings();
 
-  getNN()->to(torch::kCPU);
-
   if (!train)
   {
-    torch::load(getNN(), getBestFilename());
+    torch::load(getNN(), getBestFilename(), NeuralNetworkImpl::device);
     getNN()->registerEmbeddings();
     getNN()->to(NeuralNetworkImpl::device);
   }
   else if (std::filesystem::exists(getLastFilename()))
   {
-    torch::load(getNN(), getLastFilename());
+    torch::load(getNN(), getLastFilename(), NeuralNetworkImpl::device);
     getNN()->to(NeuralNetworkImpl::device);
     resetOptimizer();
     loadOptimizer();
@@ -185,7 +183,7 @@ void Classifier::loadOptimizer()
 {
   auto optimizerPath = std::filesystem::path(fmt::format("{}/{}_optimizer.pt", path.string(), name));
   if (std::filesystem::exists(optimizerPath))
-    torch::load(*optimizer, optimizerPath);
+    torch::load(*optimizer, optimizerPath, NeuralNetworkImpl::device);
 }
 
 void Classifier::saveOptimizer()
@@ -273,16 +271,12 @@ std::string Classifier::getLastFilename() const
 
 void Classifier::saveBest()
 {
-  getNN()->to(torch::kCPU);
   torch::save(getNN(), getBestFilename());
-  getNN()->to(NeuralNetworkImpl::device);
 }
 
 void Classifier::saveLast()
 {
-  getNN()->to(torch::kCPU);
   torch::save(getNN(), getLastFilename());
-  getNN()->to(NeuralNetworkImpl::device);
   saveOptimizer();
 }
 
diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp
index 4899d59..15d2b19 100644
--- a/torch_modules/src/NumericColumnModule.cpp
+++ b/torch_modules/src/NumericColumnModule.cpp
@@ -46,7 +46,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st
 torch::Tensor NumericColumnModuleImpl::forward(torch::Tensor input)
 {
   auto context = input.narrow(1, firstInputIndex, getInputSize());
-  auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1).clone();
+  auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1);
   return myModule->forward(values).reshape({input.size(0), -1});
 }
 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index d315742..a85fff0 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -93,7 +93,7 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
       if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
       {
         auto & classifier = *machine.getClassifier(config.getState());
-        auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
+        auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
         auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0);
         entropy  = NeuralNetworkImpl::entropy(prediction);
     
@@ -291,7 +291,7 @@ void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem:
 
   int nbClasses = classes[0].size(0);
 
-  auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1).to(torch::kCPU);
+  auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
   auto filename = fmt::format("{}-{}_{}-{}.{}.{}.tensor", state, nbClasses, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle);
   torch::save(tensorToSave, dir/filename);
   lastSavedIndex = currentExampleIndex;
-- 
GitLab