diff --git a/decoder/include/Decoder.hpp b/decoder/include/Decoder.hpp index 4eabde1a2cf4fe8e664494fd974e42952ad2dfa2..b6f0eb3a83990ce479e59615306550434ce8ff22 100644 --- a/decoder/include/Decoder.hpp +++ b/decoder/include/Decoder.hpp @@ -25,7 +25,7 @@ class Decoder public : Decoder(ReadingMachine & machine); - void decode(BaseConfig & config, std::size_t beamSize, bool debug); + void decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement); void evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV); std::vector<std::pair<float,std::string>> getF1Scores(const std::set<std::string> & colNames) const; std::vector<std::pair<float,std::string>> getAlignedAccs(const std::set<std::string> & colNames) const; diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 60fbf2acb20760a6677f01f891ba27f8ea308b16..147396e0bd5eccb60c4740bca66f58eedc1a56a5 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -5,11 +5,15 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) { } -void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug) +void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement) { machine.getClassifier()->getNN()->train(false); config.addPredicted(machine.getPredicted()); + constexpr int printInterval = 50; + int nbExamplesProcessed = 0; + auto pastTime = std::chrono::high_resolution_clock::now(); + try { config.setState(machine.getStrategy().getInitialState()); @@ -53,6 +57,16 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug) transition->apply(config); config.addToHistory(transition->getName()); + if (printAdvancement) + if (++nbExamplesProcessed >= printInterval) + { + auto actualTime = std::chrono::high_resolution_clock::now(); + double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0; + pastTime = actualTime; + fmt::print(stderr, "\rdecoding... speed={:<5}ex/s\r", (int)(nbExamplesProcessed/seconds)); + nbExamplesProcessed = 0; + } + auto movement = machine.getStrategy().getMovement(config, transition->getName()); if (debug) fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second); diff --git a/decoder/src/macaon_decode.cpp b/decoder/src/macaon_decode.cpp index 5b7272b0e7d1fb4aa150eca63eb1287902a633a8..d5dda544924b1642663865bf09438e2afa3eb488 100644 --- a/decoder/src/macaon_decode.cpp +++ b/decoder/src/macaon_decode.cpp @@ -23,6 +23,7 @@ po::options_description getOptionsDescription() po::options_description opt("Optional"); opt.add_options() ("debug,d", "Print debuging infos on stderr") + ("silent", "Don't print speed and progress") ("help,h", "Produce this help message"); desc.add(req).add(opt); @@ -72,6 +73,7 @@ int main(int argc, char * argv[]) auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : ""; auto mcdFile = variables["mcd"].as<std::string>(); bool debug = variables.count("debug") == 0 ? false : true; + bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; if (dictPaths.empty()) util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultDictFilename, ""))); @@ -87,7 +89,7 @@ int main(int argc, char * argv[]) BaseConfig config(mcdFile, inputTSV, inputTXT); - decoder.decode(config, 1, debug); + decoder.decode(config, 1, debug, printAdvancement); config.print(stdout); } catch(std::exception & e) {util::error(e);} diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp index 70aee20d3e45bd0b6423b023e2a83a44607cd002..dc48e459b3357c20e07ad772acc7af839ea262b5 100644 --- a/trainer/src/macaon_train.cpp +++ b/trainer/src/macaon_train.cpp @@ -23,6 +23,7 @@ po::options_description getOptionsDescription() po::options_description opt("Optional"); opt.add_options() ("debug,d", "Print debuging infos on stderr") + ("silent", "Don't print speed and progress") ("trainTXT", po::value<std::string>()->default_value(""), "Raw text file of the training corpus") ("devTSV", po::value<std::string>()->default_value(""), @@ -73,6 +74,7 @@ int main(int argc, char * argv[]) auto devRawFile = variables["devTXT"].as<std::string>(); auto nbEpoch = variables["nbEpochs"].as<int>(); bool debug = variables.count("debug") == 0 ? false : true; + bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; fmt::print(stderr, "Training using device : {}\n", NeuralNetworkImpl::device.str()); @@ -94,14 +96,12 @@ int main(int argc, char * argv[]) for (int i = 0; i < nbEpoch; i++) { - float loss = trainer.epoch(!debug); + float loss = trainer.epoch(printAdvancement); machine.getStrategy().reset(); auto devConfig = devGoldConfig; if (debug) fmt::print(stderr, "Decoding dev :\n"); - else - fmt::print(stderr, "\r{:80}\rDecoding dev...", " "); - decoder.decode(devConfig, 1, debug); + decoder.decode(devConfig, 1, debug, printAdvancement); machine.getStrategy().reset(); decoder.evaluate(devConfig, modelPath, devTsvFile); std::vector<std::pair<float,std::string>> devScores = decoder.getF1Scores(machine.getPredicted());