From f6de0f300a7b2141bb5d09763c37734078d58545 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 15 Apr 2020 18:52:00 +0200
Subject: [PATCH] Trainer now returns loss per example

---
 trainer/include/Trainer.hpp |  2 +-
 trainer/src/MacaonTrain.cpp |  4 ++--
 trainer/src/Trainer.cpp     | 29 +++++++++++++++--------------
 3 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 03e7616..713cd4f 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -27,7 +27,7 @@ class Trainer
   private :
 
   void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
-  float processDataset(DataLoader & loader, bool train, bool printAdvancement);
+  float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples);
   void saveExamples(std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes, int & lastSavedIndex, int & currentExampleIndex, std::filesystem::path dir);
 
   public :
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 08f0684..1278b6d 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -189,7 +189,7 @@ int MacaonTrain::main()
       if (computeDevScore)
         devScoresStr += fmt::format("{}({:5.2f}{}),", score.second, score.first, computeDevScore ? "%" : "");
       else
-        devScoresStr += fmt::format("{}({:6.1f}{}),", score.second, score.first, computeDevScore ? "%" : "");
+        devScoresStr += fmt::format("{}({:6.4f}{}),", score.second, score.first, computeDevScore ? "%" : "");
       devScoreMean += score.first;
     }
     if (!devScoresStr.empty())
@@ -207,7 +207,7 @@ int MacaonTrain::main()
     trainer.saveOptimizer(optimizerCheckpoint);
     if (printAdvancement)
       fmt::print(stderr, "\r{:80}\r", "");
-    std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.1f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
+    std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.4f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
     fmt::print(stderr, "{}\n", iterStr);
     std::FILE * f = std::fopen(trainInfos.c_str(), "a");
     fmt::print(f, "{}\t{}\n", iterStr, devScoreMean);
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index efe4c2b..071abf6 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -131,11 +131,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
     auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
     gold[0] = goldIndex;
 
-    for (auto & element : context)
-    {
-      currentExampleIndex++;
-      classes.emplace_back(gold);
-    }
+    currentExampleIndex += context.size();
+    classes.insert(classes.end(), context.size(), gold);
 
     if (currentExampleIndex-lastSavedIndex >= maxNbExamplesPerFile)
       saveExamples(contexts, classes, lastSavedIndex, currentExampleIndex, dir);
@@ -169,13 +166,13 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
   fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(currentExampleIndex));
 }
 
-float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement)
+float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples)
 {
   constexpr int printInterval = 50;
   int nbExamplesProcessed = 0;
+  int totalNbExamplesProcessed = 0;
   float totalLoss = 0.0;
   float lossSoFar = 0.0;
-  int currentBatchNumber = 0;
 
   torch::AutoGradMode useGrad(train);
   machine.trainMode(train);
@@ -212,37 +209,41 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
       optimizer->step();
     }
 
+    totalNbExamplesProcessed += torch::numel(labels);
+
     if (printAdvancement)
     {
-      nbExamplesProcessed += labels.size(0);
+      nbExamplesProcessed += torch::numel(labels);
 
-      ++currentBatchNumber;
       if (nbExamplesProcessed >= printInterval)
       {
         auto actualTime = std::chrono::high_resolution_clock::now();
         double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
         pastTime = actualTime;
+        auto speed = (int)(nbExamplesProcessed/seconds);
+        auto progression = 100.0*totalNbExamplesProcessed / nbExamples;
+        auto statusStr = fmt::format("{:6.2f}% loss={:<7.3f} speed={:<6}ex/s", progression, lossSoFar, speed);
         if (train)
-          fmt::print(stderr, "\r{:80}\rcurrent epoch : {:6.2f}% loss={:<7.3f} speed={:<6}ex/s", "", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar, (int)(nbExamplesProcessed/seconds));
+          fmt::print(stderr, "\r{:80}\rtraining : {}", "", statusStr);
         else
-          fmt::print(stderr, "\r{:80}\reval on dev : loss={:<7.3f} speed={:<6}ex/s", "", lossSoFar, (int)(nbExamplesProcessed/seconds));
+          fmt::print(stderr, "\r{:80}\reval on dev : {}", "", statusStr);
         lossSoFar = 0;
         nbExamplesProcessed = 0;
       }
     }
   }
 
-  return totalLoss;
+  return totalLoss / nbExamples;
 }
 
 float Trainer::epoch(bool printAdvancement)
 {
-  return processDataset(dataLoader, true, printAdvancement);
+  return processDataset(dataLoader, true, printAdvancement, trainDataset->size().value());
 }
 
 float Trainer::evalOnDev(bool printAdvancement)
 {
-  return processDataset(devDataLoader, false, printAdvancement);
+  return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
 }
 
 void Trainer::loadOptimizer(std::filesystem::path path)
-- 
GitLab