From 17e6ebe9d4ea7af82e9980b9a2a41125121ad3e5 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Sat, 6 Mar 2021 18:45:41 +0100
Subject: [PATCH] Macaon train decoding (devScore) on CPU in parallel

---
 decoder/include/Decoder.hpp                |  2 +-
 decoder/src/Decoder.cpp                    |  6 +++++-
 reading_machine/include/Classifier.hpp     |  1 +
 reading_machine/include/ReadingMachine.hpp |  1 +
 reading_machine/src/Classifier.cpp         |  5 +++++
 reading_machine/src/ReadingMachine.cpp     |  6 ++++++
 torch_modules/include/NeuralNetwork.hpp    |  1 +
 torch_modules/src/NeuralNetwork.cpp        |  7 ++++++-
 trainer/src/MacaonTrain.cpp                | 23 ++++++++++++++++------
 9 files changed, 43 insertions(+), 9 deletions(-)

diff --git a/decoder/include/Decoder.hpp b/decoder/include/Decoder.hpp
index 01bc7cc..fe8c870 100644
--- a/decoder/include/Decoder.hpp
+++ b/decoder/include/Decoder.hpp
@@ -25,7 +25,7 @@ class Decoder
   public :
 
   Decoder(ReadingMachine & machine);
-  void decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement);
+  std::size_t decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement);
   void evaluate(const std::vector<const Config *> & configs, std::filesystem::path modelPath, const std::string goldTSV, const std::set<std::string> & predicted);
   std::vector<std::pair<float,std::string>> getF1Scores(const std::set<std::string> & colNames) const;
   std::vector<std::pair<float,std::string>> getAlignedAccs(const std::set<std::string> & colNames) const;
diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 38957af..70eba0e 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -6,11 +6,12 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
 {
 }
 
-void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement)
+std::size_t Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement)
 {
   constexpr int printInterval = 50;
 
   int nbExamplesProcessed = 0;
+  std::size_t totalNbExamplesProcessed = 0;
   auto pastTime = std::chrono::high_resolution_clock::now();
 
   Beam beam(beamSize, beamThreshold, baseConfig, machine);
@@ -20,6 +21,7 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamTh
     while (!beam.isEnded())
     {
       beam.update(machine, debug);
+      ++totalNbExamplesProcessed;
 
       if (printAdvancement)
         if (++nbExamplesProcessed >= printInterval)
@@ -49,6 +51,8 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamTh
   // Fill holes in important columns like "ID" and "HEAD" to be compatible with eval script
   try {baseConfig.addMissingColumns();}
   catch (std::exception & e) {util::myThrow(e.what());}
+
+  return totalNbExamplesProcessed;
 }
 
 float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const
diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index e4b2208..41285a3 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -54,6 +54,7 @@ class Classifier
   bool isRegression() const;
   LossFunction & getLossFunction();
   bool exampleIsBanned(const Config & config);
+  void to(torch::Device device);
 };
 
 #endif
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index 3135635..5035070 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -52,6 +52,7 @@ class ReadingMachine
   void loadPretrainedClassifiers();
   int getNbParameters() const;
   void resetOptimizers();
+  void to(torch::Device device);
 };
 
 #endif
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 68e89b7..a2361c8 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -300,3 +300,8 @@ bool Classifier::exampleIsBanned(const Config & config)
   return false;
 }
 
+void Classifier::to(torch::Device device)
+{
+  getNN()->to(device);
+}
+
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index 33e08cd..7c06ebd 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -194,3 +194,9 @@ void ReadingMachine::resetOptimizers()
     classifier->resetOptimizer();
 }
 
+void ReadingMachine::to(torch::Device device)
+{
+  for (auto & classifier : classifiers)
+    classifier->to(device);
+}
+
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index ffbcdea..d96f264 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -23,6 +23,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
   virtual void setCountOcc(bool countOcc) = 0;
   virtual void removeRareDictElements(float rarityThreshold) = 0;
 
+  static torch::Device getPreferredDevice();
   static float entropy(torch::Tensor probabilities);
 };
 TORCH_MODULE(NeuralNetwork);
diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp
index 785c8d9..c85c160 100644
--- a/torch_modules/src/NeuralNetwork.cpp
+++ b/torch_modules/src/NeuralNetwork.cpp
@@ -1,6 +1,6 @@
 #include "NeuralNetwork.hpp"
 
-torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
+torch::Device NeuralNetworkImpl::device(getPreferredDevice());
 
 float NeuralNetworkImpl::entropy(torch::Tensor probabilities)
 {
@@ -13,3 +13,8 @@ float NeuralNetworkImpl::entropy(torch::Tensor probabilities)
   return entropy;
 }
 
+torch::Device NeuralNetworkImpl::getPreferredDevice()
+{
+  return torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
+}
+
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 900ab29..efc6341 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -47,7 +47,7 @@ po::options_description MacaonTrain::getOptionsDescription()
     ("maxNorm", po::value<float>()->default_value(std::numeric_limits<float>::max()),
       "Max norm for the embeddings")
     ("lockPretrained", "Disable fine tuning of all pretrained word embeddings.")
-    ("lineByLine", "Treat the TXT input as being one different text per line.")
+    ("lineByLine", "Process the TXT input as being one different text per line.")
     ("help,h", "Produce this help message")
     ("oracleMode", "Don't train a model, transform the corpus into a sequence of transitions.");
 
@@ -323,11 +323,22 @@ int MacaonTrain::main()
       machine.trainMode(false);
       machine.setDictsState(Dict::State::Closed);
 
-      std::for_each(std::execution::par_unseq, devConfigs.begin(), devConfigs.end(),
-        [&decoder, &debug, &printAdvancement](BaseConfig & devConfig)
-        {
-          decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);
-        });
+      if (devConfigs.size() > 1)
+      {
+        NeuralNetworkImpl::device = torch::kCPU;
+        machine.to(NeuralNetworkImpl::device);
+        std::for_each(std::execution::par_unseq, devConfigs.begin(), devConfigs.end(),
+          [&decoder, debug, printAdvancement](BaseConfig & devConfig)
+          {
+            decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);
+          });
+        NeuralNetworkImpl::device = NeuralNetworkImpl::getPreferredDevice();
+        machine.to(NeuralNetworkImpl::device);
+      }
+      else
+      {
+        decoder.decode(devConfigs[0], 1, 0.0, debug, printAdvancement);
+      }
 
       std::vector<const Config *> devConfigsPtrs;
       for (auto & devConfig : devConfigs)
-- 
GitLab