Skip to content
Snippets Groups Projects
Commit b743e4d3 authored by Franck Dary's avatar Franck Dary
Browse files

Added option to decide if dev must be evaluated or not

parent a4fe4f0b
No related branches found
No related tags found
No related merge requests found
...@@ -7,7 +7,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) ...@@ -7,7 +7,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement) void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement)
{ {
machine.getClassifier()->getNN()->train(false); torch::AutoGradMode useGrad(false);
config.addPredicted(machine.getPredicted()); config.addPredicted(machine.getPredicted());
constexpr int printInterval = 50; constexpr int printInterval = 50;
...@@ -88,8 +88,6 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool ...@@ -88,8 +88,6 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
if (debug) if (debug)
fmt::print(stderr, "Forcing EOS transition\n"); fmt::print(stderr, "Forcing EOS transition\n");
} }
machine.getClassifier()->getNN()->train(true);
} }
float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const
......
...@@ -16,16 +16,24 @@ class Trainer ...@@ -16,16 +16,24 @@ class Trainer
ReadingMachine & machine; ReadingMachine & machine;
DataLoader dataLoader{nullptr}; DataLoader dataLoader{nullptr};
DataLoader devDataLoader{nullptr};
std::unique_ptr<torch::optim::Adam> optimizer; std::unique_ptr<torch::optim::Adam> optimizer;
std::size_t epochNumber{0}; std::size_t epochNumber{0};
int batchSize{50}; int batchSize{50};
int nbExamples{0}; int nbExamples{0};
private :
void extractExamples(SubConfig & config, bool debug, std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes);
float processDataset(DataLoader & loader, bool train, bool printAdvancement);
public : public :
Trainer(ReadingMachine & machine); Trainer(ReadingMachine & machine);
void createDataset(SubConfig & goldConfig, bool debug); void createDataset(SubConfig & goldConfig, bool debug);
void createDevDataset(SubConfig & goldConfig, bool debug);
float epoch(bool printAdvancement); float epoch(bool printAdvancement);
float evalOnDev(bool printAdvancement);
}; };
#endif #endif
...@@ -7,12 +7,33 @@ Trainer::Trainer(ReadingMachine & machine) : machine(machine) ...@@ -7,12 +7,33 @@ Trainer::Trainer(ReadingMachine & machine) : machine(machine)
// Builds the training DataLoader from the gold configuration.
// Collects (context, class) example pairs by replaying the configuration via
// extractExamples, wraps them into a batched DataLoader, and (re)creates the
// Adam optimizer over the classifier network's parameters.
// @param config gold configuration to extract training examples from.
// @param debug  if true, extractExamples prints debugging info to stderr.
void Trainer::createDataset(SubConfig & config, bool debug)
{
  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;

  extractExamples(config, debug, contexts, classes);

  nbExamples = classes.size();

  // Stack<> collates the per-example tensors into batched tensors.
  dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));

  // make_unique instead of reset(new ...): exception-safe, no raw new.
  optimizer = std::make_unique<torch::optim::Adam>(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.001).amsgrad(true).beta1(0.9).beta2(0.999));
}
// Builds the dev-set DataLoader used by evalOnDev.
// Unlike createDataset, this neither touches nbExamples nor creates an
// optimizer: the dev set is only evaluated, never trained on.
// @param config gold dev configuration to extract examples from.
// @param debug  if true, extractExamples prints debugging info to stderr.
void Trainer::createDevDataset(SubConfig & config, bool debug)
{
  std::vector<torch::Tensor> devContexts;
  std::vector<torch::Tensor> devClasses;
  extractExamples(config, debug, devContexts, devClasses);
  devDataLoader = torch::data::make_data_loader(Dataset(devContexts, devClasses).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}
void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes)
{
config.addPredicted(machine.getPredicted());
config.setState(machine.getStrategy().getInitialState());
while (true) while (true)
{ {
if (debug) if (debug)
...@@ -59,15 +80,9 @@ void Trainer::createDataset(SubConfig & config, bool debug) ...@@ -59,15 +80,9 @@ void Trainer::createDataset(SubConfig & config, bool debug)
if (config.needsUpdate()) if (config.needsUpdate())
config.update(); config.update();
} }
nbExamples = classes.size();
dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.001).amsgrad(true).beta1(0.9).beta2(0.999)));
} }
float Trainer::epoch(bool printAdvancement) float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement)
{ {
constexpr int printInterval = 50; constexpr int printInterval = 50;
int nbExamplesProcessed = 0; int nbExamplesProcessed = 0;
...@@ -75,12 +90,15 @@ float Trainer::epoch(bool printAdvancement) ...@@ -75,12 +90,15 @@ float Trainer::epoch(bool printAdvancement)
float lossSoFar = 0.0; float lossSoFar = 0.0;
int currentBatchNumber = 0; int currentBatchNumber = 0;
torch::AutoGradMode useGrad(train);
auto lossFct = torch::nn::CrossEntropyLoss(); auto lossFct = torch::nn::CrossEntropyLoss();
auto pastTime = std::chrono::high_resolution_clock::now(); auto pastTime = std::chrono::high_resolution_clock::now();
for (auto & batch : *dataLoader) for (auto & batch : *loader)
{ {
if (train)
optimizer->zero_grad(); optimizer->zero_grad();
auto data = batch.data; auto data = batch.data;
...@@ -99,8 +117,11 @@ float Trainer::epoch(bool printAdvancement) ...@@ -99,8 +117,11 @@ float Trainer::epoch(bool printAdvancement)
lossSoFar += loss.item<float>(); lossSoFar += loss.item<float>();
} catch(std::exception & e) {util::myThrow(e.what());} } catch(std::exception & e) {util::myThrow(e.what());}
if (train)
{
loss.backward(); loss.backward();
optimizer->step(); optimizer->step();
}
if (printAdvancement) if (printAdvancement)
{ {
...@@ -122,3 +143,13 @@ float Trainer::epoch(bool printAdvancement) ...@@ -122,3 +143,13 @@ float Trainer::epoch(bool printAdvancement)
return totalLoss; return totalLoss;
} }
// Runs one full training pass over the training DataLoader
// (gradients enabled, optimizer stepped by processDataset).
// @param printAdvancement print speed/progress info to stderr.
// @return the total loss accumulated over the epoch.
float Trainer::epoch(bool printAdvancement)
{
  constexpr bool trainMode = true;
  return processDataset(dataLoader, trainMode, printAdvancement);
}
// Computes the loss over the dev DataLoader without updating weights
// (processDataset disables autograd and skips the optimizer when train=false).
// @param printAdvancement print speed/progress info to stderr.
// @return the total loss accumulated over the dev set.
float Trainer::evalOnDev(bool printAdvancement)
{
  constexpr bool trainMode = false;
  return processDataset(devDataLoader, trainMode, printAdvancement);
}
...@@ -24,6 +24,7 @@ po::options_description getOptionsDescription() ...@@ -24,6 +24,7 @@ po::options_description getOptionsDescription()
opt.add_options() opt.add_options()
("debug,d", "Print debuging infos on stderr") ("debug,d", "Print debuging infos on stderr")
("silent", "Don't print speed and progress") ("silent", "Don't print speed and progress")
("devScore", "Compute score on dev instead of loss (slower)")
("trainTXT", po::value<std::string>()->default_value(""), ("trainTXT", po::value<std::string>()->default_value(""),
"Raw text file of the training corpus") "Raw text file of the training corpus")
("devTSV", po::value<std::string>()->default_value(""), ("devTSV", po::value<std::string>()->default_value(""),
...@@ -75,6 +76,7 @@ int main(int argc, char * argv[]) ...@@ -75,6 +76,7 @@ int main(int argc, char * argv[])
auto nbEpoch = variables["nbEpochs"].as<int>(); auto nbEpoch = variables["nbEpochs"].as<int>();
bool debug = variables.count("debug") == 0 ? false : true; bool debug = variables.count("debug") == 0 ? false : true;
bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
bool computeDevScore = variables.count("devScore") == 0 ? false : true;
fmt::print(stderr, "Training using device : {}\n", NeuralNetworkImpl::device.str()); fmt::print(stderr, "Training using device : {}\n", NeuralNetworkImpl::device.str());
...@@ -84,38 +86,58 @@ int main(int argc, char * argv[]) ...@@ -84,38 +86,58 @@ int main(int argc, char * argv[])
ReadingMachine machine(machinePath.string()); ReadingMachine machine(machinePath.string());
BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile); BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile);
SubConfig config(goldConfig); SubConfig config(goldConfig);
Trainer trainer(machine); Trainer trainer(machine);
trainer.createDataset(config, debug); trainer.createDataset(config, debug);
if (!computeDevScore)
{
SubConfig devConfig(devGoldConfig);
trainer.createDevDataset(devConfig, debug);
}
Decoder decoder(machine); Decoder decoder(machine);
BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile);
float bestDevScore = 0; float bestDevScore = computeDevScore ? 0 : 100;
for (int i = 0; i < nbEpoch; i++) for (int i = 0; i < nbEpoch; i++)
{ {
float loss = trainer.epoch(printAdvancement); float loss = trainer.epoch(printAdvancement);
machine.getStrategy().reset(); machine.getStrategy().reset();
auto devConfig = devGoldConfig;
if (debug) if (debug)
fmt::print(stderr, "Decoding dev :\n"); fmt::print(stderr, "Decoding dev :\n");
std::vector<std::pair<float,std::string>> devScores;
if (computeDevScore)
{
auto devConfig = devGoldConfig;
decoder.decode(devConfig, 1, debug, printAdvancement); decoder.decode(devConfig, 1, debug, printAdvancement);
machine.getStrategy().reset(); machine.getStrategy().reset();
decoder.evaluate(devConfig, modelPath, devTsvFile); decoder.evaluate(devConfig, modelPath, devTsvFile);
std::vector<std::pair<float,std::string>> devScores = decoder.getF1Scores(machine.getPredicted()); devScores = decoder.getF1Scores(machine.getPredicted());
}
else
{
float devLoss = trainer.evalOnDev(printAdvancement);
devScores.emplace_back(std::make_pair(devLoss, "Loss"));
}
std::string devScoresStr = ""; std::string devScoresStr = "";
float devScoreMean = 0; float devScoreMean = 0;
for (auto & score : devScores) for (auto & score : devScores)
{ {
devScoresStr += fmt::format("{}({:5.2f}%),", score.second, score.first); if (computeDevScore)
devScoresStr += fmt::format("{}({:5.2f}{}),", score.second, score.first, computeDevScore ? "%" : "");
else
devScoresStr += fmt::format("{}({:6.1f}{}),", score.second, score.first, computeDevScore ? "%" : "");
devScoreMean += score.first; devScoreMean += score.first;
} }
if (!devScoresStr.empty()) if (!devScoresStr.empty())
devScoresStr.pop_back(); devScoresStr.pop_back();
devScoreMean /= devScores.size(); devScoreMean /= devScores.size();
bool saved = devScoreMean > bestDevScore; bool saved = devScoreMean > bestDevScore;
if (!computeDevScore)
saved = devScoreMean < bestDevScore;
if (saved) if (saved)
{ {
bestDevScore = devScoreMean; bestDevScore = devScoreMean;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment