Skip to content
Snippets Groups Projects
Commit 3ee8137b authored by Franck Dary
Browse files

Added debug mode

parent dd050cbc
Branches
No related tags found
No related merge requests found
...@@ -15,7 +15,7 @@ class Decoder ...@@ -15,7 +15,7 @@ class Decoder
public : public :
Decoder(ReadingMachine & machine); Decoder(ReadingMachine & machine);
void decode(BaseConfig & config, std::size_t beamSize); void decode(BaseConfig & config, std::size_t beamSize, bool debug);
void evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV); void evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV);
float getMetricScore(const std::string & metric, std::size_t scoreIndex); float getMetricScore(const std::string & metric, std::size_t scoreIndex);
float getPrecision(const std::string & metric); float getPrecision(const std::string & metric);
......
...@@ -5,7 +5,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) ...@@ -5,7 +5,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
{ {
} }
void Decoder::decode(BaseConfig & config, std::size_t beamSize) void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
{ {
config.addPredicted(machine.getPredicted()); config.addPredicted(machine.getPredicted());
...@@ -15,6 +15,9 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize) ...@@ -15,6 +15,9 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize)
while (true) while (true)
{ {
if (debug)
config.printForDebug(stderr);
auto dictState = machine.getDict(config.getState()).getState(); auto dictState = machine.getDict(config.getState()).getState();
auto context = config.extractContext(5,5,machine.getDict(config.getState())); auto context = config.extractContext(5,5,machine.getDict(config.getState()));
machine.getDict(config.getState()).setState(dictState); machine.getDict(config.getState()).setState(dictState);
......
...@@ -22,6 +22,7 @@ po::options_description getOptionsDescription() ...@@ -22,6 +22,7 @@ po::options_description getOptionsDescription()
po::options_description opt("Optional"); po::options_description opt("Optional");
opt.add_options() opt.add_options()
("debug,d", "Print debuging infos on stderr")
("help,h", "Produce this help message"); ("help,h", "Produce this help message");
desc.add(req).add(opt); desc.add(req).add(opt);
...@@ -70,6 +71,7 @@ int main(int argc, char * argv[]) ...@@ -70,6 +71,7 @@ int main(int argc, char * argv[])
auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : ""; auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : "";
auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : ""; auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : "";
auto mcdFile = variables["mcd"].as<std::string>(); auto mcdFile = variables["mcd"].as<std::string>();
bool debug = variables.count("debug") == 0 ? false : true;
if (dictPaths.empty()) if (dictPaths.empty())
util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultDictFilename, ""))); util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultDictFilename, "")));
...@@ -83,7 +85,7 @@ int main(int argc, char * argv[]) ...@@ -83,7 +85,7 @@ int main(int argc, char * argv[])
BaseConfig config(mcdFile, inputTSV, inputTXT); BaseConfig config(mcdFile, inputTSV, inputTXT);
decoder.decode(config, 1); decoder.decode(config, 1, debug);
config.print(stdout); config.print(stdout);
} catch(std::exception & e) {util::error(e);} } catch(std::exception & e) {util::error(e);}
......
...@@ -26,8 +26,8 @@ class Trainer ...@@ -26,8 +26,8 @@ class Trainer
public : public :
Trainer(ReadingMachine & machine); Trainer(ReadingMachine & machine);
void createDataset(SubConfig & goldConfig); void createDataset(SubConfig & goldConfig, bool debug);
float epoch(); float epoch(bool printAdvancement);
}; };
......
...@@ -5,7 +5,7 @@ Trainer::Trainer(ReadingMachine & machine) : machine(machine) ...@@ -5,7 +5,7 @@ Trainer::Trainer(ReadingMachine & machine) : machine(machine)
{ {
} }
void Trainer::createDataset(SubConfig & config) void Trainer::createDataset(SubConfig & config, bool debug)
{ {
config.addPredicted(machine.getPredicted()); config.addPredicted(machine.getPredicted());
config.setState(machine.getStrategy().getInitialState()); config.setState(machine.getStrategy().getInitialState());
...@@ -15,6 +15,9 @@ void Trainer::createDataset(SubConfig & config) ...@@ -15,6 +15,9 @@ void Trainer::createDataset(SubConfig & config)
while (true) while (true)
{ {
if (debug)
config.printForDebug(stderr);
auto * transition = machine.getTransitionSet().getBestAppliableTransition(config); auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
if (!transition) if (!transition)
{ {
...@@ -57,7 +60,7 @@ void Trainer::createDataset(SubConfig & config) ...@@ -57,7 +60,7 @@ void Trainer::createDataset(SubConfig & config)
sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5))); sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5)));
} }
float Trainer::epoch() float Trainer::epoch(bool printAdvancement)
{ {
constexpr int printInterval = 2000; constexpr int printInterval = 2000;
float totalLoss = 0.0; float totalLoss = 0.0;
...@@ -83,6 +86,8 @@ float Trainer::epoch() ...@@ -83,6 +86,8 @@ float Trainer::epoch()
denseOptimizer->step(); denseOptimizer->step();
sparseOptimizer->step(); sparseOptimizer->step();
if (printAdvancement)
{
nbExamplesUntilPrint -= labels.size(0); nbExamplesUntilPrint -= labels.size(0);
++currentBatchNumber; ++currentBatchNumber;
...@@ -93,6 +98,7 @@ float Trainer::epoch() ...@@ -93,6 +98,7 @@ float Trainer::epoch()
lossSoFar = 0; lossSoFar = 0;
} }
} }
}
return totalLoss; return totalLoss;
} }
......
...@@ -21,6 +21,7 @@ po::options_description getOptionsDescription() ...@@ -21,6 +21,7 @@ po::options_description getOptionsDescription()
po::options_description opt("Optional"); po::options_description opt("Optional");
opt.add_options() opt.add_options()
("debug,d", "Print debuging infos on stderr")
("trainTXT", po::value<std::string>()->default_value(""), ("trainTXT", po::value<std::string>()->default_value(""),
"Raw text file of the training corpus") "Raw text file of the training corpus")
("devTSV", po::value<std::string>()->default_value(""), ("devTSV", po::value<std::string>()->default_value(""),
...@@ -70,6 +71,7 @@ int main(int argc, char * argv[]) ...@@ -70,6 +71,7 @@ int main(int argc, char * argv[])
auto devTsvFile = variables["devTSV"].as<std::string>(); auto devTsvFile = variables["devTSV"].as<std::string>();
auto devRawFile = variables["devTXT"].as<std::string>(); auto devRawFile = variables["devTXT"].as<std::string>();
auto nbEpoch = variables["nbEpochs"].as<int>(); auto nbEpoch = variables["nbEpochs"].as<int>();
bool debug = variables.count("debug") == 0 ? false : true;
ReadingMachine machine(machinePath.string()); ReadingMachine machine(machinePath.string());
...@@ -77,7 +79,7 @@ int main(int argc, char * argv[]) ...@@ -77,7 +79,7 @@ int main(int argc, char * argv[])
SubConfig config(goldConfig); SubConfig config(goldConfig);
Trainer trainer(machine); Trainer trainer(machine);
trainer.createDataset(config); trainer.createDataset(config, debug);
Decoder decoder(machine); Decoder decoder(machine);
BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile); BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile);
...@@ -86,10 +88,13 @@ int main(int argc, char * argv[]) ...@@ -86,10 +88,13 @@ int main(int argc, char * argv[])
for (int i = 0; i < nbEpoch; i++) for (int i = 0; i < nbEpoch; i++)
{ {
float loss = trainer.epoch(); float loss = trainer.epoch(!debug);
auto devConfig = devGoldConfig; auto devConfig = devGoldConfig;
if (debug)
fmt::print(stderr, "Decoding dev :\n");
else
fmt::print(stderr, "\r{:80}\rDecoding dev...", " "); fmt::print(stderr, "\r{:80}\rDecoding dev...", " ");
decoder.decode(devConfig, 1); decoder.decode(devConfig, 1, debug);
decoder.evaluate(devConfig, modelPath, devTsvFile); decoder.evaluate(devConfig, modelPath, devTsvFile);
float devScore = decoder.getF1Score("UPOS"); float devScore = decoder.getF1Score("UPOS");
bool saved = devScore > bestDevScore; bool saved = devScore > bestDevScore;
...@@ -98,6 +103,9 @@ int main(int argc, char * argv[]) ...@@ -98,6 +103,9 @@ int main(int argc, char * argv[])
bestDevScore = devScore; bestDevScore = devScore;
machine.save(); machine.save();
} }
if (debug)
fmt::print(stderr, "Epoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : "");
else
fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : ""); fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : "");
} }
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment