Skip to content
Snippets Groups Projects
Commit 3ee8137b authored by Franck Dary's avatar Franck Dary
Browse files

Added debug mode

parent dd050cbc
No related branches found
No related tags found
No related merge requests found
......@@ -15,7 +15,7 @@ class Decoder
public :
Decoder(ReadingMachine & machine);
void decode(BaseConfig & config, std::size_t beamSize);
void decode(BaseConfig & config, std::size_t beamSize, bool debug);
void evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV);
float getMetricScore(const std::string & metric, std::size_t scoreIndex);
float getPrecision(const std::string & metric);
......
......@@ -5,7 +5,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
{
}
void Decoder::decode(BaseConfig & config, std::size_t beamSize)
void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug)
{
config.addPredicted(machine.getPredicted());
......@@ -15,6 +15,9 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize)
while (true)
{
if (debug)
config.printForDebug(stderr);
auto dictState = machine.getDict(config.getState()).getState();
auto context = config.extractContext(5,5,machine.getDict(config.getState()));
machine.getDict(config.getState()).setState(dictState);
......
......@@ -22,6 +22,7 @@ po::options_description getOptionsDescription()
po::options_description opt("Optional");
opt.add_options()
("debug,d", "Print debuging infos on stderr")
("help,h", "Produce this help message");
desc.add(req).add(opt);
......@@ -70,6 +71,7 @@ int main(int argc, char * argv[])
auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : "";
auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : "";
auto mcdFile = variables["mcd"].as<std::string>();
bool debug = variables.count("debug") == 0 ? false : true;
if (dictPaths.empty())
util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultDictFilename, "")));
......@@ -83,7 +85,7 @@ int main(int argc, char * argv[])
BaseConfig config(mcdFile, inputTSV, inputTXT);
decoder.decode(config, 1);
decoder.decode(config, 1, debug);
config.print(stdout);
} catch(std::exception & e) {util::error(e);}
......
......@@ -26,8 +26,8 @@ class Trainer
public :
Trainer(ReadingMachine & machine);
void createDataset(SubConfig & goldConfig);
float epoch();
void createDataset(SubConfig & goldConfig, bool debug);
float epoch(bool printAdvancement);
};
......
......@@ -5,7 +5,7 @@ Trainer::Trainer(ReadingMachine & machine) : machine(machine)
{
}
void Trainer::createDataset(SubConfig & config)
void Trainer::createDataset(SubConfig & config, bool debug)
{
config.addPredicted(machine.getPredicted());
config.setState(machine.getStrategy().getInitialState());
......@@ -15,6 +15,9 @@ void Trainer::createDataset(SubConfig & config)
while (true)
{
if (debug)
config.printForDebug(stderr);
auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
if (!transition)
{
......@@ -57,7 +60,7 @@ void Trainer::createDataset(SubConfig & config)
sparseOptimizer.reset(new torch::optim::SparseAdam(machine.getClassifier()->getNN()->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5)));
}
float Trainer::epoch()
float Trainer::epoch(bool printAdvancement)
{
constexpr int printInterval = 2000;
float totalLoss = 0.0;
......@@ -83,14 +86,17 @@ float Trainer::epoch()
denseOptimizer->step();
sparseOptimizer->step();
nbExamplesUntilPrint -= labels.size(0);
++currentBatchNumber;
if (nbExamplesUntilPrint <= 0)
if (printAdvancement)
{
nbExamplesUntilPrint = printInterval;
fmt::print(stderr, "\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar);
lossSoFar = 0;
nbExamplesUntilPrint -= labels.size(0);
++currentBatchNumber;
if (nbExamplesUntilPrint <= 0)
{
nbExamplesUntilPrint = printInterval;
fmt::print(stderr, "\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar);
lossSoFar = 0;
}
}
}
......
......@@ -21,6 +21,7 @@ po::options_description getOptionsDescription()
po::options_description opt("Optional");
opt.add_options()
("debug,d", "Print debuging infos on stderr")
("trainTXT", po::value<std::string>()->default_value(""),
"Raw text file of the training corpus")
("devTSV", po::value<std::string>()->default_value(""),
......@@ -70,6 +71,7 @@ int main(int argc, char * argv[])
auto devTsvFile = variables["devTSV"].as<std::string>();
auto devRawFile = variables["devTXT"].as<std::string>();
auto nbEpoch = variables["nbEpochs"].as<int>();
bool debug = variables.count("debug") == 0 ? false : true;
ReadingMachine machine(machinePath.string());
......@@ -77,7 +79,7 @@ int main(int argc, char * argv[])
SubConfig config(goldConfig);
Trainer trainer(machine);
trainer.createDataset(config);
trainer.createDataset(config, debug);
Decoder decoder(machine);
BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile);
......@@ -86,10 +88,13 @@ int main(int argc, char * argv[])
for (int i = 0; i < nbEpoch; i++)
{
float loss = trainer.epoch();
float loss = trainer.epoch(!debug);
auto devConfig = devGoldConfig;
fmt::print(stderr, "\r{:80}\rDecoding dev...", " ");
decoder.decode(devConfig, 1);
if (debug)
fmt::print(stderr, "Decoding dev :\n");
else
fmt::print(stderr, "\r{:80}\rDecoding dev...", " ");
decoder.decode(devConfig, 1, debug);
decoder.evaluate(devConfig, modelPath, devTsvFile);
float devScore = decoder.getF1Score("UPOS");
bool saved = devScore > bestDevScore;
......@@ -98,7 +103,10 @@ int main(int argc, char * argv[])
bestDevScore = devScore;
machine.save();
}
fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : "");
if (debug)
fmt::print(stderr, "Epoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : "");
else
fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : "");
}
return 0;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment