From 8b6ba7d90ed6370b8f284b7307f941998a974530 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Mon, 6 May 2019 13:43:50 +0200 Subject: [PATCH] Added program option randomDebug, where debug infos are only printed with a certain probability in order to speed things up --- maca_common/include/ProgramParameters.hpp | 2 ++ maca_common/src/Dict.cpp | 3 ++- maca_common/src/ProgramParameters.cpp | 2 ++ trainer/include/Trainer.hpp | 2 ++ trainer/src/TrainInfos.cpp | 2 +- trainer/src/Trainer.cpp | 13 +++++++++++++ trainer/src/macaon_train.cpp | 5 +++++ 7 files changed, 27 insertions(+), 2 deletions(-) diff --git a/maca_common/include/ProgramParameters.hpp b/maca_common/include/ProgramParameters.hpp index a149a50..e7ab85c 100644 --- a/maca_common/include/ProgramParameters.hpp +++ b/maca_common/include/ProgramParameters.hpp @@ -73,6 +73,8 @@ struct ProgramParameters static bool featureExtraction; static bool devEvalOnGold; static bool devLoss; + static bool randomDebug; + static float randomDebugProbability; private : diff --git a/maca_common/src/Dict.cpp b/maca_common/src/Dict.cpp index b28df18..03f8753 100644 --- a/maca_common/src/Dict.cpp +++ b/maca_common/src/Dict.cpp @@ -375,7 +375,8 @@ unsigned int Dict::addEntry(const std::string & s) if ((int)str2index.size() >= ProgramParameters::dictCapacity) { - fprintf(stderr, "ERROR (%s) : Dict %s of maximal capacity %d is full. Aborting.\n", ERRINFO, name.c_str(), ProgramParameters::dictCapacity); + fprintf(stderr, "ERROR (%s) : Dict %s of maximal capacity %d is full. Saving dict than aborting.\n", ERRINFO, name.c_str(), ProgramParameters::dictCapacity); + save(); exit(1); } diff --git a/maca_common/src/ProgramParameters.cpp b/maca_common/src/ProgramParameters.cpp index 07be5b4..13a6629 100644 --- a/maca_common/src/ProgramParameters.cpp +++ b/maca_common/src/ProgramParameters.cpp @@ -67,4 +67,6 @@ float ProgramParameters::maskRate; bool ProgramParameters::featureExtraction; bool ProgramParameters::devEvalOnGold; bool ProgramParameters::devLoss; +bool ProgramParameters::randomDebug; +float ProgramParameters::randomDebugProbability; diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index 7568a73..1294283 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -83,6 +83,8 @@ class Trainer void doStepTrain(); /// @brief Compute and print dev scores, increase epoch counter. void prepareNextEpoch(); + /// @brief Set the debug variable ProgramParameters::debug. + void setDebugValue(); public : diff --git a/trainer/src/TrainInfos.cpp b/trainer/src/TrainInfos.cpp index d754cd6..bc893e3 100644 --- a/trainer/src/TrainInfos.cpp +++ b/trainer/src/TrainInfos.cpp @@ -281,7 +281,7 @@ void TrainInfos::printScores(FILE * output) } if (ProgramParameters::interactive) - fprintf(stderr, " \r"); + fprintf(output, " \r"); if (ProgramParameters::printTime) fprintf(output, "[%s] ", getTime().c_str()); fprintf(output, "Iteration %d/%d : \n", getEpoch(), ProgramParameters::nbIter); diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 6076320..f33d538 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -23,6 +23,19 @@ Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, C pastTime = std::chrono::high_resolution_clock::now(); } +void Trainer::setDebugValue() +{ + if (!ProgramParameters::randomDebug) + return; + + if (ProgramParameters::interactive) + fprintf(stderr, " \r"); + if (ProgramParameters::printTime) + fprintf(stderr, "[%s] :\n", getTime().c_str()); + + ProgramParameters::debug = choiceWithProbability(ProgramParameters::randomDebugProbability); +} + void Trainer::computeScoreOnDev() { if (!devConfig) diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp index 7a9a2b8..ede41e2 100644 --- a/trainer/src/macaon_train.cpp +++ b/trainer/src/macaon_train.cpp @@ -40,6 +40,9 @@ po::options_description getOptionsDescription() opt.add_options() ("help,h", "Produce this help message") ("debug,d", "Print infos on stderr") + ("randomDebug", "Print infos on stderr with a probability of randomDebugProbability") + ("randomDebugProbability", po::value<float>()->default_value(0.001), + "Probability that debug infos will be printed") ("printEntropy", "Print mean entropy and standard deviation accross sequences") ("dicts", po::value<std::string>()->default_value(""), "The .dict file describing all the dictionaries to be used in the experiement. By default the filename specified in the .tm file will be used") @@ -270,6 +273,7 @@ int main(int argc, char * argv[]) ProgramParameters::bdName = vm["bd"].as<std::string>(); ProgramParameters::mcdName = vm["mcd"].as<std::string>(); ProgramParameters::debug = vm.count("debug") == 0 ? false : true; + ProgramParameters::randomDebug = vm.count("randomDebug") == 0 ? false : true; ProgramParameters::printEntropy = vm.count("printEntropy") == 0 ? false : true; ProgramParameters::printTime = vm.count("printTime") == 0 ? false : true; ProgramParameters::featureExtraction = vm.count("featureExtraction") == 0 ? false : true; @@ -291,6 +295,7 @@ int main(int argc, char * argv[]) ProgramParameters::sequenceDelimiterTape = vm["sequenceDelimiterTape"].as<std::string>(); ProgramParameters::sequenceDelimiter = vm["sequenceDelimiter"].as<std::string>(); ProgramParameters::learningRate = vm["lr"].as<float>(); + ProgramParameters::randomDebugProbability = vm["randomDebugProbability"].as<float>(); ProgramParameters::beta1 = vm["b1"].as<float>(); ProgramParameters::beta2 = vm["b2"].as<float>(); ProgramParameters::bias = vm["bias"].as<float>(); -- GitLab