Skip to content
Snippets Groups Projects
Commit fb4e0869 authored by Franck Dary's avatar Franck Dary
Browse files

Print speed during training

parent a7558853
No related branches found
No related tags found
No related merge requests found
......@@ -64,13 +64,15 @@ void Trainer::createDataset(SubConfig & config, bool debug)
float Trainer::epoch(bool printAdvancement)
{
constexpr int printInterval = 2000;
int nbExamplesProcessed = 0;
float totalLoss = 0.0;
float lossSoFar = 0.0;
int nbExamplesUntilPrint = printInterval;
int currentBatchNumber = 0;
auto lossFct = torch::nn::CrossEntropyLoss();
auto pastTime = std::chrono::high_resolution_clock::now();
for (auto & batch : *dataLoader)
{
optimizer->zero_grad();
......@@ -92,14 +94,17 @@ float Trainer::epoch(bool printAdvancement)
if (printAdvancement)
{
nbExamplesUntilPrint -= labels.size(0);
nbExamplesProcessed += labels.size(0);
++currentBatchNumber;
if (nbExamplesUntilPrint <= 0)
if (nbExamplesProcessed >= printInterval)
{
nbExamplesUntilPrint = printInterval;
fmt::print(stderr, "\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar);
auto actualTime = std::chrono::high_resolution_clock::now();
double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
pastTime = actualTime;
fmt::print(stderr, "\rcurrent epoch : {:6.2f}% loss={:<7.3f} speed={:<5}ex/s", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar, (int)(nbExamplesProcessed/seconds));
lossSoFar = 0;
nbExamplesProcessed = 0;
}
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment