From f6de0f300a7b2141bb5d09763c37734078d58545 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 15 Apr 2020 18:52:00 +0200
Subject: [PATCH] Trainer now returns loss per example

---
 trainer/include/Trainer.hpp |  2 +-
 trainer/src/MacaonTrain.cpp |  4 ++--
 trainer/src/Trainer.cpp     | 29 +++++++++++++++--------------
 3 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 03e7616..713cd4f 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -27,7 +27,7 @@ class Trainer
   private :

   void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
-  float processDataset(DataLoader & loader, bool train, bool printAdvancement);
+  float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples);
   void saveExamples(std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes, int & lastSavedIndex, int & currentExampleIndex, std::filesystem::path dir);

   public :
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 08f0684..1278b6d 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -189,7 +189,7 @@ int MacaonTrain::main()
         if (computeDevScore)
           devScoresStr += fmt::format("{}({:5.2f}{}),", score.second, score.first, computeDevScore ? "%" : "");
         else
-          devScoresStr += fmt::format("{}({:6.1f}{}),", score.second, score.first, computeDevScore ? "%" : "");
+          devScoresStr += fmt::format("{}({:6.4f}{}),", score.second, score.first, computeDevScore ? "%" : "");
         devScoreMean += score.first;
       }
       if (!devScoresStr.empty())
@@ -207,7 +207,7 @@ int MacaonTrain::main()
       trainer.saveOptimizer(optimizerCheckpoint);
       if (printAdvancement)
         fmt::print(stderr, "\r{:80}\r", "");
-      std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.1f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
+      std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.4f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
       fmt::print(stderr, "{}\n", iterStr);
       std::FILE * f = std::fopen(trainInfos.c_str(), "a");
       fmt::print(f, "{}\t{}\n", iterStr, devScoreMean);
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index efe4c2b..071abf6 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -131,11 +131,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
       auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
       gold[0] = goldIndex;

-      for (auto & element : context)
-      {
-        currentExampleIndex++;
-        classes.emplace_back(gold);
-      }
+      currentExampleIndex += context.size();
+      classes.insert(classes.end(), context.size(), gold);

       if (currentExampleIndex-lastSavedIndex >= maxNbExamplesPerFile)
         saveExamples(contexts, classes, lastSavedIndex, currentExampleIndex, dir);
@@ -169,13 +166,13 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
   fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(currentExampleIndex));
 }

-float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement)
+float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples)
 {
   constexpr int printInterval = 50;
   int nbExamplesProcessed = 0;
+  int totalNbExamplesProcessed = 0;
   float totalLoss = 0.0;
   float lossSoFar = 0.0;
-  int currentBatchNumber = 0;

   torch::AutoGradMode useGrad(train);
   machine.trainMode(train);
@@ -212,37 +209,41 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
       optimizer->step();
     }

+    totalNbExamplesProcessed += torch::numel(labels);
+
     if (printAdvancement)
     {
-      nbExamplesProcessed += labels.size(0);
+      nbExamplesProcessed += torch::numel(labels);

-      ++currentBatchNumber;
       if (nbExamplesProcessed >= printInterval)
       {
         auto actualTime = std::chrono::high_resolution_clock::now();
         double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
         pastTime = actualTime;
+        auto speed = (int)(nbExamplesProcessed/seconds);
+        auto progression = 100.0*totalNbExamplesProcessed / nbExamples;
+        auto statusStr = fmt::format("{:6.2f}% loss={:<7.3f} speed={:<6}ex/s", progression, lossSoFar, speed);
         if (train)
-          fmt::print(stderr, "\r{:80}\rcurrent epoch : {:6.2f}% loss={:<7.3f} speed={:<6}ex/s", "", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar, (int)(nbExamplesProcessed/seconds));
+          fmt::print(stderr, "\r{:80}\rtraining : {}", "", statusStr);
         else
-          fmt::print(stderr, "\r{:80}\reval on dev : loss={:<7.3f} speed={:<6}ex/s", "", lossSoFar, (int)(nbExamplesProcessed/seconds));
+          fmt::print(stderr, "\r{:80}\reval on dev : {}", "", statusStr);
         lossSoFar = 0;
         nbExamplesProcessed = 0;
       }
     }
   }

-  return totalLoss;
+  return totalLoss / nbExamples;
 }

 float Trainer::epoch(bool printAdvancement)
 {
-  return processDataset(dataLoader, true, printAdvancement);
+  return processDataset(dataLoader, true, printAdvancement, trainDataset->size().value());
 }

 float Trainer::evalOnDev(bool printAdvancement)
 {
-  return processDataset(devDataLoader, false, printAdvancement);
+  return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
 }

 void Trainer::loadOptimizer(std::filesystem::path path)
--
GitLab