From 3ee8137bc1baf95129781919aa7b1986083d92d2 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 13 Feb 2020 18:08:02 +0100
Subject: [PATCH] Added debug mode

---
 decoder/include/Decoder.hpp   |  2 +-
 decoder/src/Decoder.cpp       |  5 ++++-
 decoder/src/macaon_decode.cpp |  4 +++-
 trainer/include/Trainer.hpp   |  4 ++--
 trainer/src/Trainer.cpp       | 24 +++++++++++++++---------
 trainer/src/macaon_train.cpp  | 18 +++++++++++++-----
 6 files changed, 38 insertions(+), 19 deletions(-)

diff --git a/decoder/include/Decoder.hpp b/decoder/include/Decoder.hpp
index 90f3f20..e8f7ecf 100644
--- a/decoder/include/Decoder.hpp
+++ b/decoder/include/Decoder.hpp
@@ -15,7 +15,7 @@ class Decoder
   public :
 
   Decoder(ReadingMachine & machine);
-  void decode(BaseConfig & config, std::size_t beamSize);
+  void decode(BaseConfig & config, std::size_t beamSize, bool debug);
   void evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV);
   float getMetricScore(const std::string & metric, std::size_t scoreIndex);
   float getPrecision(const std::string & metric);
diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 543dbbf..78c3b86 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -5,7 +5,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
 {
 }
 
-void Decoder::decode(BaseConfig & config, std::size_t beamSize)
+void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
 {
   config.addPredicted(machine.getPredicted());
 
@@ -15,6 +15,9 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize)
 
   while (true)
   {
+    if (debug)
+      config.printForDebug(stderr);
+
     auto dictState = machine.getDict(config.getState()).getState();
     auto context = config.extractContext(5,5,machine.getDict(config.getState()));
     machine.getDict(config.getState()).setState(dictState);
diff --git a/decoder/src/macaon_decode.cpp b/decoder/src/macaon_decode.cpp
index 61b849a..f673c2c 100644
--- a/decoder/src/macaon_decode.cpp
+++ b/decoder/src/macaon_decode.cpp
@@ -22,6 +22,7 @@ po::options_description getOptionsDescription()
 
   po::options_description opt("Optional");
   opt.add_options()
+    ("debug,d", "Print debuging infos on stderr")
     ("help,h", "Produce this help message");
 
   desc.add(req).add(opt);
@@ -70,6 +71,7 @@ int main(int argc, char * argv[])
   auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : "";
   auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : "";
   auto mcdFile = variables["mcd"].as<std::string>();
+  bool debug = variables.count("debug") == 0 ? false : true;
 
   if (dictPaths.empty())
     util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultDictFilename, "")));
@@ -83,7 +85,7 @@ int main(int argc, char * argv[])
 
     BaseConfig config(mcdFile, inputTSV, inputTXT);
 
-    decoder.decode(config, 1);
+    decoder.decode(config, 1, debug);
 
     config.print(stdout);
   } catch(std::exception & e) {util::error(e);}
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 45fccbe..69dde8d 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -26,8 +26,8 @@ class Trainer
   public :
 
   Trainer(ReadingMachine & machine);
-  void createDataset(SubConfig & goldConfig);
-  float epoch();
+  void createDataset(SubConfig & goldConfig, bool debug);
+  float epoch(bool printAdvancement);
 
 };
 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 6496aa3..595c70b 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -5,7 +5,7 @@ Trainer::Trainer(ReadingMachine & machine) : machine(machine)
 {
 }
 
-void Trainer::createDataset(SubConfig & config)
+void Trainer::createDataset(SubConfig & config, bool debug)
 {
   config.addPredicted(machine.getPredicted());
   config.setState(machine.getStrategy().getInitialState());
@@ -15,6 +15,9 @@ void Trainer::createDataset(SubConfig & config)
 
   while (true)
   {
+    if (debug)
+      config.printForDebug(stderr);
+
     auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
     if (!transition)
     {
@@ -57,7 +60,7 @@ void Trainer::createDataset(SubConfig & config)
   sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5))); 
 }
 
-float Trainer::epoch()
+float Trainer::epoch(bool printAdvancement)
 {
   constexpr int printInterval = 2000;
   float totalLoss = 0.0;
@@ -83,14 +86,17 @@ float Trainer::epoch()
     denseOptimizer->step();
     sparseOptimizer->step();
 
-    nbExamplesUntilPrint -= labels.size(0);
-
-    ++currentBatchNumber;
-    if (nbExamplesUntilPrint <= 0)
+    if (printAdvancement)
     {
-      nbExamplesUntilPrint = printInterval;
-      fmt::print(stderr, "\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar);
-      lossSoFar = 0;
+      nbExamplesUntilPrint -= labels.size(0);
+
+      ++currentBatchNumber;
+      if (nbExamplesUntilPrint <= 0)
+      {
+        nbExamplesUntilPrint = printInterval;
+        fmt::print(stderr, "\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar);
+        lossSoFar = 0;
+      }
     }
   }
 
diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp
index bd3437c..62d09a6 100644
--- a/trainer/src/macaon_train.cpp
+++ b/trainer/src/macaon_train.cpp
@@ -21,6 +21,7 @@ po::options_description getOptionsDescription()
 
   po::options_description opt("Optional");
   opt.add_options()
+    ("debug,d", "Print debuging infos on stderr")
     ("trainTXT", po::value<std::string>()->default_value(""),
       "Raw text file of the training corpus")
     ("devTSV", po::value<std::string>()->default_value(""),
@@ -70,6 +71,7 @@ int main(int argc, char * argv[])
   auto devTsvFile = variables["devTSV"].as<std::string>();
   auto devRawFile = variables["devTXT"].as<std::string>();
   auto nbEpoch = variables["nbEpochs"].as<int>();
+  bool debug = variables.count("debug") == 0 ? false : true;
 
   ReadingMachine machine(machinePath.string());
 
@@ -77,7 +79,7 @@ int main(int argc, char * argv[])
   SubConfig config(goldConfig);
 
   Trainer trainer(machine);
-  trainer.createDataset(config);
+  trainer.createDataset(config, debug);
 
   Decoder decoder(machine);
   BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile);
@@ -86,10 +88,13 @@ int main(int argc, char * argv[])
 
   for (int i = 0; i < nbEpoch; i++)
   {
-    float loss = trainer.epoch();
+    float loss = trainer.epoch(!debug);
     auto devConfig = devGoldConfig;
-    fmt::print(stderr, "\r{:80}\rDecoding dev...", " ");
-    decoder.decode(devConfig, 1);
+    if (debug)
+      fmt::print(stderr, "Decoding dev :\n");
+    else
+      fmt::print(stderr, "\r{:80}\rDecoding dev...", " ");
+    decoder.decode(devConfig, 1, debug);
     decoder.evaluate(devConfig, modelPath, devTsvFile);
     float devScore = decoder.getF1Score("UPOS");
     bool saved = devScore > bestDevScore;
@@ -98,7 +103,10 @@ int main(int argc, char * argv[])
       bestDevScore = devScore;
       machine.save();
     }
-    fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : "");
+    if (debug)
+      fmt::print(stderr, "Epoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : "");
+    else
+      fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : "");
   }
 
   return 0;
-- 
GitLab