From 8c6cc68d2117451057dffbb7c4f8bbaf618fde01 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sat, 20 Mar 2021 20:18:23 +0100
Subject: [PATCH] Added functions getDevice and setDevice

---
 decoder/src/MacaonDecode.cpp            |  6 +++---
 reading_machine/src/Classifier.cpp      | 11 +++++------
 torch_modules/include/NeuralNetwork.hpp |  4 +++-
 torch_modules/src/ConfigDataset.cpp     |  8 ++++----
 torch_modules/src/CustomHingeLoss.cpp   |  2 +-
 torch_modules/src/LossFunction.cpp      |  4 ++--
 torch_modules/src/ModularNetwork.cpp    |  2 +-
 torch_modules/src/NeuralNetwork.cpp     | 10 ++++++++++
 torch_modules/src/RandomNetwork.cpp     |  2 +-
 trainer/src/MacaonTrain.cpp             | 10 +++++-----
 trainer/src/Trainer.cpp                 |  8 ++++----
 11 files changed, 39 insertions(+), 28 deletions(-)

diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp
index 64dda23..22e715f 100644
--- a/decoder/src/MacaonDecode.cpp
+++ b/decoder/src/MacaonDecode.cpp
@@ -90,7 +90,7 @@ int MacaonDecode::main()
   if (modelPaths.empty())
     util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, "")));
 
-  fmt::print(stderr, "Decoding using device : {}\n", NeuralNetworkImpl::device.str());
+  fmt::print(stderr, "Decoding using device : {}\n", NeuralNetworkImpl::getDevice().str());
 
   try
   {
@@ -127,8 +127,8 @@ int MacaonDecode::main()
 
     if (configs.size() > 1)
     {
-      NeuralNetworkImpl::device = torch::kCPU;
-      machine.to(NeuralNetworkImpl::device);
+      NeuralNetworkImpl::setDevice(torch::kCPU);
+      machine.to(NeuralNetworkImpl::getDevice());
       std::for_each(std::execution::par, configs.begin(), configs.end(),
         [&decoder, debug, printAdvancement, beamSize, beamThreshold](BaseConfig & config)
         {
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index a2361c8..a9de8bb 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -83,20 +83,19 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
 
   if (!train)
   {
-    torch::load(getNN(), getBestFilename(), NeuralNetworkImpl::device);
+    torch::load(getNN(), getBestFilename(), NeuralNetworkImpl::getDevice());
     getNN()->registerEmbeddings();
-    getNN()->to(NeuralNetworkImpl::device);
+    getNN()->to(NeuralNetworkImpl::getDevice());
   }
   else if (std::filesystem::exists(getLastFilename()))
   {
-    torch::load(getNN(), getLastFilename(), NeuralNetworkImpl::device);
-    getNN()->to(NeuralNetworkImpl::device);
+    torch::load(getNN(), getLastFilename(), NeuralNetworkImpl::getDevice());
     resetOptimizer();
     loadOptimizer();
   }
   else
   {
-    getNN()->to(NeuralNetworkImpl::device);
+    getNN()->to(NeuralNetworkImpl::getDevice());
   }
 }
 
@@ -183,7 +182,7 @@ void Classifier::loadOptimizer()
 {
   auto optimizerPath = std::filesystem::path(fmt::format("{}/{}_optimizer.pt", path.string(), name));
   if (std::filesystem::exists(optimizerPath))
-    torch::load(*optimizer, optimizerPath, NeuralNetworkImpl::device);
+    torch::load(*optimizer, optimizerPath, NeuralNetworkImpl::getDevice());
 }
 
 void Classifier::saveOptimizer()
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index d96f264..6e2319b 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -8,7 +8,7 @@
 
 class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
 {
-  public :
+  private :
 
   static torch::Device device;
 
@@ -24,6 +24,8 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
   virtual void removeRareDictElements(float rarityThreshold) = 0;
 
   static torch::Device getPreferredDevice();
+  static torch::Device getDevice();
+  static void setDevice(torch::Device device);
   static float entropy(torch::Tensor probabilities);
 };
 TORCH_MODULE(NeuralNetwork);
diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp
index 91695ad..e8f40be 100644
--- a/torch_modules/src/ConfigDataset.cpp
+++ b/torch_modules/src/ConfigDataset.cpp
@@ -93,8 +93,8 @@ void ConfigDataset::Holder::reset()
   loadedTensorIndex = 0;
   nextIndexToGive = 0;
   nbGiven = 0;
-  torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
-  loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device)));
+  torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::getDevice());
+  loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::getDevice())));
 }
 
 c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset::Holder::get_batch(std::size_t batchSize)
@@ -107,8 +107,8 @@ c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset
     if (loadedTensorIndex >= (int)files.size())
       return c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>>();
     nextIndexToGive = 0;
-    torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::device);
-    loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device)));
+    torch::load(loadedTensor, files[loadedTensorIndex], NeuralNetworkImpl::getDevice());
+    loadedTensor = torch::index_select(loadedTensor, 0, torch::randperm(loadedTensor.size(0), torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::getDevice())));
   }
 
   int nbElementsToGive = std::min<int>(batchSize, loadedTensor.size(0)-nextIndexToGive);
diff --git a/torch_modules/src/CustomHingeLoss.cpp b/torch_modules/src/CustomHingeLoss.cpp
index ec44b6c..e2d0b51 100644
--- a/torch_modules/src/CustomHingeLoss.cpp
+++ b/torch_modules/src/CustomHingeLoss.cpp
@@ -3,7 +3,7 @@
 
 torch::Tensor CustomHingeLoss::operator()(torch::Tensor prediction, torch::Tensor gold)
 {
-  torch::Tensor loss = torch::zeros(1).to(NeuralNetworkImpl::device);
+  torch::Tensor loss = torch::zeros(1, NeuralNetworkImpl::getDevice());
 
   for (unsigned int i = 0; i < prediction.size(0); i++)
   {
diff --git a/torch_modules/src/LossFunction.cpp b/torch_modules/src/LossFunction.cpp
index 2f8f1be..bb57249 100644
--- a/torch_modules/src/LossFunction.cpp
+++ b/torch_modules/src/LossFunction.cpp
@@ -51,13 +51,13 @@ torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::
 
   if (index == 0 or index == 2 or index == 4)
   {
-    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
+    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
     gold[0] = goldIndexes.at(0);
     return gold;
   }
   if (index == 1 or index == 3)
   {
-    auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
+    auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
     for (auto goldIndex : goldIndexes)
       gold[goldIndex] = 1;
     return gold;
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index 1c39f18..066b79d 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -99,7 +99,7 @@ torch::Tensor ModularNetworkImpl::forward(torch::Tensor input, const std::string
 
 torch::Tensor ModularNetworkImpl::extractContext(Config & config)
 {
-  torch::Tensor context = torch::zeros({totalInputSize}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
+  torch::Tensor context = torch::zeros({totalInputSize}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::getDevice()));
   for (auto & mod : modules)
     mod->addToContext(context, config);
   return context;
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index c85c160..fe3727d 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -18,3 +18,13 @@ torch::Device NeuralNetworkImpl::getPreferredDevice()
   return torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
 }
 
+torch::Device NeuralNetworkImpl::getDevice()
+{
+  return device;
+}
+
+void NeuralNetworkImpl::setDevice(torch::Device device)
+{
+  NeuralNetworkImpl::device = device;
+}
+
diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp
index b05d1aa..d27ffe9 100644
--- a/torch_modules/src/RandomNetwork.cpp
+++ b/torch_modules/src/RandomNetwork.cpp
@@ -10,7 +10,7 @@ torch::Tensor RandomNetworkImpl::forward(torch::Tensor input, const std::string
   if (input.dim() == 1)
     input = input.unsqueeze(0);
 
-  return torch::randn({input.size(0), (long)nbOutputsPerState[state]}, torch::TensorOptions().device(device).requires_grad(true));
+  return torch::randn({input.size(0), (long)nbOutputsPerState[state]}, torch::TensorOptions().device(NeuralNetworkImpl::getDevice()).requires_grad(true));
 }
 
 torch::Tensor RandomNetworkImpl::extractContext(Config &)
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 79d608a..9409451 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -161,7 +161,7 @@ int MacaonTrain::main()
     std::fclose(file);
   }
 
-  fmt::print(stderr, "[{}] Training using device : {}\n", util::getTime(), NeuralNetworkImpl::device.str());
+  fmt::print(stderr, "[{}] Training using device : {}\n", util::getTime(), NeuralNetworkImpl::getDevice().str());
 
   try
   {
@@ -325,15 +325,15 @@ int MacaonTrain::main()
 
       if (devConfigs.size() > 1)
       {
-        NeuralNetworkImpl::device = torch::kCPU;
-        machine.to(NeuralNetworkImpl::device);
+        NeuralNetworkImpl::setDevice(torch::kCPU);
+        machine.to(NeuralNetworkImpl::getDevice());
         std::for_each(std::execution::par, devConfigs.begin(), devConfigs.end(),
           [&decoder, debug, printAdvancement](BaseConfig & devConfig)
           {
             decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);
           });
-        NeuralNetworkImpl::device = NeuralNetworkImpl::getPreferredDevice();
-        machine.to(NeuralNetworkImpl::device);
+        NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice());
+        machine.to(NeuralNetworkImpl::getDevice());
       }
       else
       {
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 628386c..2237a07 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -50,8 +50,8 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
 
   std::atomic<int> totalNbExamples = 0;
 
-  NeuralNetworkImpl::device = torch::kCPU;
-  machine.to(NeuralNetworkImpl::device);
+  NeuralNetworkImpl::setDevice(torch::kCPU);
+  machine.to(NeuralNetworkImpl::getDevice());
   std::for_each(std::execution::par, configs.begin(), configs.end(),
     [this, maxNbExamplesPerFile, &examplesPerState, &totalNbExamples, debug, dynamicOracle, explorationThreshold, dir, epoch, &examplesMutex](SubConfig & config)
     {
@@ -191,8 +191,8 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
   for (auto & it : examplesPerState)
     it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle);
 
-  NeuralNetworkImpl::device = NeuralNetworkImpl::getPreferredDevice();
-  machine.to(NeuralNetworkImpl::device);
+  NeuralNetworkImpl::setDevice(NeuralNetworkImpl::getPreferredDevice());
+  machine.to(NeuralNetworkImpl::getDevice());
 
   std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
   if (!f)
-- 
GitLab