From f93305a5459af6b1aa689f1701391aa402c756e0 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 3 Mar 2020 17:53:38 +0100
Subject: [PATCH] Detect device (cpu or cuda) and use cuda if possible

---
 CMakeLists.txt                          | 3 +--
 decoder/src/Decoder.cpp                 | 2 +-
 decoder/src/macaon_decode.cpp           | 2 ++
 reading_machine/src/Classifier.cpp      | 3 +++
 torch_modules/include/NeuralNetwork.hpp | 4 ++++
 torch_modules/src/NeuralNetwork.cpp     | 2 ++
 trainer/src/Trainer.cpp                 | 4 ++--
 trainer/src/macaon_train.cpp            | 3 +++
 8 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 78867d2..7c6507c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,5 +1,5 @@
 cmake_minimum_required(VERSION 3.0.2)
-project(test_torch)
+project(macaon)
 
 add_compile_definitions(BOOST_DISABLE_THREADS)
 
@@ -9,7 +9,6 @@ find_package(Boost 1.53.0 REQUIRED COMPONENTS program_options)
 include_directories(SYSTEM ${TORCH_INCLUDE_DIRS})
 
 add_library(Torch SHARED IMPORTED)
-set_target_properties(Torch PROPERTIES IMPORTED_LOCATION ${TORCH_LIBRARIES})
 add_library(Boost SHARED IMPORTED)
 set_target_properties(Boost PROPERTIES IMPORTED_LOCATION ${Boost_PROGRAM_OPTIONS_LIBRARY_RELEASE})
 
diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index e7b2285..fc5f386 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -23,7 +23,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
     auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState()));
     machine.getDict(config.getState()).setState(dictState);
 
-    auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong);
+    auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone().to(NeuralNetworkImpl::device);
     auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
 
     int chosenTransition = -1;
diff --git a/decoder/src/macaon_decode.cpp b/decoder/src/macaon_decode.cpp
index f673c2c..5b7272b 100644
--- a/decoder/src/macaon_decode.cpp
+++ b/decoder/src/macaon_decode.cpp
@@ -78,6 +78,8 @@ int main(int argc, char * argv[])
   if (modelPaths.empty())
     util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, "")));
 
+  fmt::print(stderr, "Decoding using device : {}\n", NeuralNetworkImpl::device.str());
+
   try
   {
     ReadingMachine machine(machinePath, modelPaths, dictPaths);
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index f234c73..ad03c31 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -67,7 +67,10 @@ void Classifier::initNeuralNetwork(const std::string & topology)
 
   for (auto & initializer : initializers)
     if (util::doIfNameMatch(std::get<0>(initializer),topology,std::get<2>(initializer)))
+    {
+      this->nn->to(NeuralNetworkImpl::device);
       return;
+    }
 
   std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology);
   for (auto & initializer : initializers)
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index cdc5514..4709216 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -7,6 +7,10 @@
 
 class NeuralNetworkImpl : public torch::nn::Module
 {
+  public :
+
+  static torch::Device device;
+
   protected : 
 
   int leftBorder{5};
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index cdb8ad3..0719d84 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -1,5 +1,7 @@
 #include "NeuralNetwork.hpp"
 
+torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
+
 std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
 {
   std::stack<int> leftContext;
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index a74078c..191ac47 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -26,10 +26,10 @@ void Trainer::createDataset(SubConfig & config, bool debug)
     }
 
     auto context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
-    contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
+    contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(at::kLong)).clone().to(NeuralNetworkImpl::device));
 
     int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
-    auto gold = torch::zeros(1, at::kLong);
+    auto gold = torch::zeros(1, torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device));
     gold[0] = goldIndex;
 
     classes.emplace_back(gold);
diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp
index 34a009a..70aee20 100644
--- a/trainer/src/macaon_train.cpp
+++ b/trainer/src/macaon_train.cpp
@@ -3,6 +3,7 @@
 #include "util.hpp"
 #include "Trainer.hpp"
 #include "Decoder.hpp"
+#include "NeuralNetwork.hpp"
 
 namespace po = boost::program_options;
 
@@ -73,6 +74,8 @@ int main(int argc, char * argv[])
   auto nbEpoch = variables["nbEpochs"].as<int>();
   bool debug = variables.count("debug") == 0 ? false : true;
 
+  fmt::print(stderr, "Training using device : {}\n", NeuralNetworkImpl::device.str());
+
   try
   {
 
-- 
GitLab