From fb4e086976db793b776c875b8f40f6e01a665db8 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 24 Feb 2020 18:04:02 +0100
Subject: [PATCH] Print speed during training

---
 trainer/src/Trainer.cpp | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index d51dd9c..64466b5 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -64,13 +64,15 @@ void Trainer::createDataset(SubConfig & config, bool debug)
 float Trainer::epoch(bool printAdvancement)
 {
   constexpr int printInterval = 2000;
+  int nbExamplesProcessed = 0;
   float totalLoss = 0.0;
   float lossSoFar = 0.0;
-  int nbExamplesUntilPrint = printInterval;
   int currentBatchNumber = 0;
 
   auto lossFct = torch::nn::CrossEntropyLoss();
 
+  auto pastTime = std::chrono::high_resolution_clock::now();
+
   for (auto & batch : *dataLoader)
   {
     optimizer->zero_grad();
@@ -92,14 +94,17 @@ float Trainer::epoch(bool printAdvancement)
 
     if (printAdvancement)
     {
-      nbExamplesUntilPrint -= labels.size(0);
+      nbExamplesProcessed += labels.size(0);
 
       ++currentBatchNumber;
-      if (nbExamplesUntilPrint <= 0)
+      if (nbExamplesProcessed >= printInterval)
       {
-        nbExamplesUntilPrint = printInterval;
-        fmt::print(stderr, "\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar);
+        auto actualTime = std::chrono::high_resolution_clock::now();
+        double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
+        pastTime = actualTime;
+        fmt::print(stderr, "\rcurrent epoch : {:6.2f}% loss={:<7.3f} speed={:<5}ex/s", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar, (int)(nbExamplesProcessed/seconds));
         lossSoFar = 0;
+        nbExamplesProcessed = 0;
       }
     }
   }
-- 
GitLab