Skip to content
Snippets Groups Projects
Commit a4fe4f0b authored by Franck Dary's avatar Franck Dary
Browse files

Printing decoding speed

parent 764c2fc7
No related branches found
No related tags found
No related merge requests found
......@@ -25,7 +25,7 @@ class Decoder
public :
Decoder(ReadingMachine & machine);
void decode(BaseConfig & config, std::size_t beamSize, bool debug);
void decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement);
void evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV);
std::vector<std::pair<float,std::string>> getF1Scores(const std::set<std::string> & colNames) const;
std::vector<std::pair<float,std::string>> getAlignedAccs(const std::set<std::string> & colNames) const;
......
......@@ -5,11 +5,15 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
{
}
void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement)
{
machine.getClassifier()->getNN()->train(false);
config.addPredicted(machine.getPredicted());
constexpr int printInterval = 50;
int nbExamplesProcessed = 0;
auto pastTime = std::chrono::high_resolution_clock::now();
try
{
config.setState(machine.getStrategy().getInitialState());
......@@ -53,6 +57,16 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
transition->apply(config);
config.addToHistory(transition->getName());
if (printAdvancement)
if (++nbExamplesProcessed >= printInterval)
{
auto actualTime = std::chrono::high_resolution_clock::now();
double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
pastTime = actualTime;
fmt::print(stderr, "\rdecoding... speed={:<5}ex/s\r", (int)(nbExamplesProcessed/seconds));
nbExamplesProcessed = 0;
}
auto movement = machine.getStrategy().getMovement(config, transition->getName());
if (debug)
fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
......
......@@ -23,6 +23,7 @@ po::options_description getOptionsDescription()
po::options_description opt("Optional");
opt.add_options()
("debug,d", "Print debuging infos on stderr")
("silent", "Don't print speed and progress")
("help,h", "Produce this help message");
desc.add(req).add(opt);
......@@ -72,6 +73,7 @@ int main(int argc, char * argv[])
auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : "";
auto mcdFile = variables["mcd"].as<std::string>();
bool debug = variables.count("debug") == 0 ? false : true;
bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
if (dictPaths.empty())
util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultDictFilename, "")));
......@@ -87,7 +89,7 @@ int main(int argc, char * argv[])
BaseConfig config(mcdFile, inputTSV, inputTXT);
decoder.decode(config, 1, debug);
decoder.decode(config, 1, debug, printAdvancement);
config.print(stdout);
} catch(std::exception & e) {util::error(e);}
......
......@@ -23,6 +23,7 @@ po::options_description getOptionsDescription()
po::options_description opt("Optional");
opt.add_options()
("debug,d", "Print debuging infos on stderr")
("silent", "Don't print speed and progress")
("trainTXT", po::value<std::string>()->default_value(""),
"Raw text file of the training corpus")
("devTSV", po::value<std::string>()->default_value(""),
......@@ -73,6 +74,7 @@ int main(int argc, char * argv[])
auto devRawFile = variables["devTXT"].as<std::string>();
auto nbEpoch = variables["nbEpochs"].as<int>();
bool debug = variables.count("debug") == 0 ? false : true;
bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
fmt::print(stderr, "Training using device : {}\n", NeuralNetworkImpl::device.str());
......@@ -94,14 +96,12 @@ int main(int argc, char * argv[])
for (int i = 0; i < nbEpoch; i++)
{
float loss = trainer.epoch(!debug);
float loss = trainer.epoch(printAdvancement);
machine.getStrategy().reset();
auto devConfig = devGoldConfig;
if (debug)
fmt::print(stderr, "Decoding dev :\n");
else
fmt::print(stderr, "\r{:80}\rDecoding dev...", " ");
decoder.decode(devConfig, 1, debug);
decoder.decode(devConfig, 1, debug, printAdvancement);
machine.getStrategy().reset();
decoder.evaluate(devConfig, modelPath, devTsvFile);
std::vector<std::pair<float,std::string>> devScores = decoder.getF1Scores(machine.getPredicted());
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment