From fb4e086976db793b776c875b8f40f6e01a665db8 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Mon, 24 Feb 2020 18:04:02 +0100 Subject: [PATCH] Print speed during training --- trainer/src/Trainer.cpp | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index d51dd9c..64466b5 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -64,13 +64,15 @@ void Trainer::createDataset(SubConfig & config, bool debug) float Trainer::epoch(bool printAdvancement) { constexpr int printInterval = 2000; + int nbExamplesProcessed = 0; float totalLoss = 0.0; float lossSoFar = 0.0; - int nbExamplesUntilPrint = printInterval; int currentBatchNumber = 0; auto lossFct = torch::nn::CrossEntropyLoss(); + auto pastTime = std::chrono::high_resolution_clock::now(); + for (auto & batch : *dataLoader) { optimizer->zero_grad(); @@ -92,14 +94,17 @@ float Trainer::epoch(bool printAdvancement) if (printAdvancement) { - nbExamplesUntilPrint -= labels.size(0); + nbExamplesProcessed += labels.size(0); ++currentBatchNumber; - if (nbExamplesUntilPrint <= 0) + if (nbExamplesProcessed >= printInterval) { - nbExamplesUntilPrint = printInterval; - fmt::print(stderr, "\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar); + auto actualTime = std::chrono::high_resolution_clock::now(); + double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0; + pastTime = actualTime; + fmt::print(stderr, "\rcurrent epoch : {:6.2f}% loss={:<7.3f} speed={:<5}ex/s", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar, (int)(nbExamplesProcessed/seconds)); lossSoFar = 0; + nbExamplesProcessed = 0; } } } -- GitLab