From 3ee8137bc1baf95129781919aa7b1986083d92d2 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Thu, 13 Feb 2020 18:08:02 +0100 Subject: [PATCH] Added debug mode --- decoder/include/Decoder.hpp | 2 +- decoder/src/Decoder.cpp | 5 ++++- decoder/src/macaon_decode.cpp | 4 +++- trainer/include/Trainer.hpp | 4 ++-- trainer/src/Trainer.cpp | 24 +++++++++++++++--------- trainer/src/macaon_train.cpp | 18 +++++++++++++----- 6 files changed, 38 insertions(+), 19 deletions(-) diff --git a/decoder/include/Decoder.hpp b/decoder/include/Decoder.hpp index 90f3f20..e8f7ecf 100644 --- a/decoder/include/Decoder.hpp +++ b/decoder/include/Decoder.hpp @@ -15,7 +15,7 @@ class Decoder public : Decoder(ReadingMachine & machine); - void decode(BaseConfig & config, std::size_t beamSize); + void decode(BaseConfig & config, std::size_t beamSize, bool debug); void evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV); float getMetricScore(const std::string & metric, std::size_t scoreIndex); float getPrecision(const std::string & metric); diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 543dbbf..78c3b86 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -5,7 +5,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) { } -void Decoder::decode(BaseConfig & config, std::size_t beamSize) +void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug) { config.addPredicted(machine.getPredicted()); @@ -15,6 +15,9 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize) while (true) { + if (debug) + config.printForDebug(stderr); + auto dictState = machine.getDict(config.getState()).getState(); auto context = config.extractContext(5,5,machine.getDict(config.getState())); machine.getDict(config.getState()).setState(dictState); diff --git a/decoder/src/macaon_decode.cpp b/decoder/src/macaon_decode.cpp index 61b849a..f673c2c 100644 --- a/decoder/src/macaon_decode.cpp +++ b/decoder/src/macaon_decode.cpp @@ -22,6 +22,7 @@ po::options_description getOptionsDescription() po::options_description opt("Optional"); opt.add_options() + ("debug,d", "Print debuging infos on stderr") ("help,h", "Produce this help message"); desc.add(req).add(opt); @@ -70,6 +71,7 @@ int main(int argc, char * argv[]) auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : ""; auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : ""; auto mcdFile = variables["mcd"].as<std::string>(); + bool debug = variables.count("debug") == 0 ? false : true; if (dictPaths.empty()) util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultDictFilename, ""))); @@ -83,7 +85,7 @@ int main(int argc, char * argv[]) BaseConfig config(mcdFile, inputTSV, inputTXT); - decoder.decode(config, 1); + decoder.decode(config, 1, debug); config.print(stdout); } catch(std::exception & e) {util::error(e);} diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index 45fccbe..69dde8d 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -26,8 +26,8 @@ class Trainer public : Trainer(ReadingMachine & machine); - void createDataset(SubConfig & goldConfig); - float epoch(); + void createDataset(SubConfig & goldConfig, bool debug); + float epoch(bool printAdvancement); }; diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 6496aa3..595c70b 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -5,7 +5,7 @@ Trainer::Trainer(ReadingMachine & machine) : machine(machine) { } -void Trainer::createDataset(SubConfig & config) +void Trainer::createDataset(SubConfig & config, bool debug) { config.addPredicted(machine.getPredicted()); config.setState(machine.getStrategy().getInitialState()); @@ -15,6 +15,9 @@ void Trainer::createDataset(SubConfig & config) while (true) { + if (debug) + config.printForDebug(stderr); + auto * transition = machine.getTransitionSet().getBestAppliableTransition(config); if (!transition) { @@ -57,7 +60,7 @@ void Trainer::createDataset(SubConfig & config) sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5))); } -float Trainer::epoch() +float Trainer::epoch(bool printAdvancement) { constexpr int printInterval = 2000; float totalLoss = 0.0; @@ -83,14 +86,17 @@ float Trainer::epoch() denseOptimizer->step(); sparseOptimizer->step(); - nbExamplesUntilPrint -= labels.size(0); - - ++currentBatchNumber; - if (nbExamplesUntilPrint <= 0) + if (printAdvancement) { - nbExamplesUntilPrint = printInterval; - fmt::print(stderr, "\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar); - lossSoFar = 0; + nbExamplesUntilPrint -= labels.size(0); + + ++currentBatchNumber; + if (nbExamplesUntilPrint <= 0) + { + nbExamplesUntilPrint = printInterval; + fmt::print(stderr, "\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar); + lossSoFar = 0; + } } } diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp index bd3437c..62d09a6 100644 --- a/trainer/src/macaon_train.cpp +++ b/trainer/src/macaon_train.cpp @@ -21,6 +21,7 @@ po::options_description getOptionsDescription() po::options_description opt("Optional"); opt.add_options() + ("debug,d", "Print debuging infos on stderr") ("trainTXT", po::value<std::string>()->default_value(""), "Raw text file of the training corpus") ("devTSV", po::value<std::string>()->default_value(""), @@ -70,6 +71,7 @@ int main(int argc, char * argv[]) auto devTsvFile = variables["devTSV"].as<std::string>(); auto devRawFile = variables["devTXT"].as<std::string>(); auto nbEpoch = variables["nbEpochs"].as<int>(); + bool debug = variables.count("debug") == 0 ? false : true; ReadingMachine machine(machinePath.string()); @@ -77,7 +79,7 @@ int main(int argc, char * argv[]) SubConfig config(goldConfig); Trainer trainer(machine); - trainer.createDataset(config); + trainer.createDataset(config, debug); Decoder decoder(machine); BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile); @@ -86,10 +88,13 @@ int main(int argc, char * argv[]) for (int i = 0; i < nbEpoch; i++) { - float loss = trainer.epoch(); + float loss = trainer.epoch(!debug); auto devConfig = devGoldConfig; - fmt::print(stderr, "\r{:80}\rDecoding dev...", " "); - decoder.decode(devConfig, 1); + if (debug) + fmt::print(stderr, "Decoding dev :\n"); + else + fmt::print(stderr, "\r{:80}\rDecoding dev...", " "); + decoder.decode(devConfig, 1, debug); decoder.evaluate(devConfig, modelPath, devTsvFile); float devScore = decoder.getF1Score("UPOS"); bool saved = devScore > bestDevScore; @@ -98,7 +103,10 @@ int main(int argc, char * argv[]) bestDevScore = devScore; machine.save(); } - fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : ""); + if (debug) + fmt::print(stderr, "Epoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : ""); + else + fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : ""); } return 0; -- GitLab