Commit b743e4d3 authored by Franck Dary

Added option to decide if dev must be evaluated or not

parent a4fe4f0b
@@ -7,7 +7,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
 void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement)
 {
-  machine.getClassifier()->getNN()->train(false);
+  torch::AutoGradMode useGrad(false);
   config.addPredicted(machine.getPredicted());

   constexpr int printInterval = 50;
@@ -88,8 +88,6 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement)
     if (debug)
       fmt::print(stderr, "Forcing EOS transition\n");
   }
-
-  machine.getClassifier()->getNN()->train(true);
 }

 float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const
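Editor's note: the Decoder change replaces the manual train(false)/train(true) toggling with an RAII torch::AutoGradMode guard. Gradient recording is disabled for the whole scope of decode() and restored automatically when the guard is destroyed, which is why the explicit re-enable at the end of the function could be dropped. A minimal standalone sketch of the guard's behavior (the linear model is illustrative, not part of the repository):

#include <torch/torch.h>
#include <iostream>

int main()
{
  torch::nn::Linear model(4, 2);

  {
    // Same guard as in Decoder::decode(): no computation graph is
    // recorded while this object is alive.
    torch::AutoGradMode useGrad(false);
    auto out = model(torch::randn({1, 4}));
    std::cout << out.requires_grad() << "\n"; // prints 0
  } // guard destroyed here, gradient mode restored

  auto out = model(torch::randn({1, 4}));
  std::cout << out.requires_grad() << "\n"; // prints 1
}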
......
@@ -16,16 +16,24 @@ class Trainer
   ReadingMachine & machine;
   DataLoader dataLoader{nullptr};
+  DataLoader devDataLoader{nullptr};
   std::unique_ptr<torch::optim::Adam> optimizer;
   std::size_t epochNumber{0};
   int batchSize{50};
   int nbExamples{0};

+  private :
+
+  void extractExamples(SubConfig & config, bool debug, std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes);
+  float processDataset(DataLoader & loader, bool train, bool printAdvancement);
+
   public :

   Trainer(ReadingMachine & machine);
   void createDataset(SubConfig & goldConfig, bool debug);
+  void createDevDataset(SubConfig & goldConfig, bool debug);
   float epoch(bool printAdvancement);
+  float evalOnDev(bool printAdvancement);
 };

 #endif
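Editor's note: the header now exposes two public dataset builders (train and dev) that both delegate to the private extractExamples(), while epoch() and evalOnDev() both delegate to the private processDataset(). For readers unfamiliar with the libtorch data API used in the createDataset()/createDevDataset() bodies below, here is a hedged, self-contained sketch of how such a DataLoader member is typically produced; ToyDataset and makeLoader are stand-ins, not the project's classes:

#include <torch/torch.h>
#include <vector>

// A dataset of (context, class) tensor pairs, comparable to the project's
// Dataset(contexts, classes); it yields Example<> pairs that Stack<> can
// collate into batched tensors.
struct ToyDataset : torch::data::datasets::Dataset<ToyDataset>
{
  std::vector<torch::Tensor> contexts, classes;

  ToyDataset(std::vector<torch::Tensor> c, std::vector<torch::Tensor> y)
    : contexts(std::move(c)), classes(std::move(y)) {}

  torch::data::Example<> get(std::size_t index) override
  {
    return {contexts[index], classes[index]};
  }

  torch::optional<std::size_t> size() const override
  {
    return contexts.size();
  }
};

// Collate single examples into stacked batch tensors, then wrap them in a
// loader, as createDataset() and createDevDataset() both do below.
auto makeLoader(std::vector<torch::Tensor> contexts,
                std::vector<torch::Tensor> classes, int batchSize)
{
  return torch::data::make_data_loader(
      ToyDataset(std::move(contexts), std::move(classes))
          .map(torch::data::transforms::Stack<>()),
      torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}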
@@ -7,12 +7,33 @@ Trainer::Trainer(ReadingMachine & machine) : machine(machine)
 void Trainer::createDataset(SubConfig & config, bool debug)
 {
-  config.addPredicted(machine.getPredicted());
-  config.setState(machine.getStrategy().getInitialState());
   std::vector<torch::Tensor> contexts;
   std::vector<torch::Tensor> classes;

+  extractExamples(config, debug, contexts, classes);
+
+  nbExamples = classes.size();
+
+  dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
+  optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.001).amsgrad(true).beta1(0.9).beta2(0.999)));
+}
+
+void Trainer::createDevDataset(SubConfig & config, bool debug)
+{
+  std::vector<torch::Tensor> contexts;
+  std::vector<torch::Tensor> classes;
+
+  extractExamples(config, debug, contexts, classes);
+
+  devDataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
+}
+
+void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes)
+{
+  config.addPredicted(machine.getPredicted());
+  config.setState(machine.getStrategy().getInitialState());
+
   while (true)
   {
     if (debug)
@@ -59,15 +80,9 @@ void Trainer::createDataset(SubConfig & config, bool debug)
     if (config.needsUpdate())
       config.update();
   }
-
-  nbExamples = classes.size();
-
-  dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
-  optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.001).amsgrad(true).beta1(0.9).beta2(0.999)));
 }

-float Trainer::epoch(bool printAdvancement)
+float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement)
 {
   constexpr int printInterval = 50;
   int nbExamplesProcessed = 0;
@@ -75,13 +90,16 @@ float Trainer::epoch(bool printAdvancement)
   float lossSoFar = 0.0;
   int currentBatchNumber = 0;

+  torch::AutoGradMode useGrad(train);
+
   auto lossFct = torch::nn::CrossEntropyLoss();
   auto pastTime = std::chrono::high_resolution_clock::now();

-  for (auto & batch : *dataLoader)
+  for (auto & batch : *loader)
   {
-    optimizer->zero_grad();
+    if (train)
+      optimizer->zero_grad();

     auto data = batch.data;
     auto labels = batch.target.squeeze();
@@ -99,8 +117,11 @@ float Trainer::epoch(bool printAdvancement)
       lossSoFar += loss.item<float>();
     } catch(std::exception & e) {util::myThrow(e.what());}

-    loss.backward();
-    optimizer->step();
+    if (train)
+    {
+      loss.backward();
+      optimizer->step();
+    }

     if (printAdvancement)
     {
@@ -122,3 +143,13 @@ float Trainer::epoch(bool printAdvancement)

   return totalLoss;
 }
+
+float Trainer::epoch(bool printAdvancement)
+{
+  return processDataset(dataLoader, true, printAdvancement);
+}
+
+float Trainer::evalOnDev(bool printAdvancement)
+{
+  return processDataset(devDataLoader, false, printAdvancement);
+}
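Editor's note: taken together, the Trainer hunks implement one shared loop. processDataset() runs the same batch iteration for training and dev evaluation, with the train flag gating autograd, zero_grad(), backward(), and step(); epoch() and evalOnDev() are now thin wrappers. A self-contained sketch of that pattern (illustrative names and a plain vector of batches instead of the project's DataLoader):

#include <torch/torch.h>
#include <vector>
#include <utility>

// Illustrative stand-in for Trainer::processDataset(): one loop serves both
// epoch() (train == true) and evalOnDev() (train == false).
float runEpoch(torch::nn::Linear & model, torch::optim::Adam & optimizer,
               const std::vector<std::pair<torch::Tensor, torch::Tensor>> & batches,
               bool train)
{
  torch::AutoGradMode useGrad(train);           // record gradients only when training
  auto lossFct = torch::nn::CrossEntropyLoss(); // expects float logits, long targets
  float totalLoss = 0.0;

  for (auto & batch : batches)
  {
    if (train)
      optimizer.zero_grad();                    // skipped during dev evaluation

    auto loss = lossFct(model(batch.first), batch.second);
    totalLoss += loss.item<float>();

    if (train)
    {
      loss.backward();
      optimizer.step();
    }
  }

  return totalLoss;
}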
@@ -24,6 +24,7 @@ po::options_description getOptionsDescription()
   opt.add_options()
     ("debug,d", "Print debuging infos on stderr")
     ("silent", "Don't print speed and progress")
+    ("devScore", "Compute score on dev instead of loss (slower)")
     ("trainTXT", po::value<std::string>()->default_value(""),
       "Raw text file of the training corpus")
     ("devTSV", po::value<std::string>()->default_value(""),
@@ -75,6 +76,7 @@ int main(int argc, char * argv[])
   auto nbEpoch = variables["nbEpochs"].as<int>();
   bool debug = variables.count("debug") == 0 ? false : true;
   bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
+  bool computeDevScore = variables.count("devScore") == 0 ? false : true;

   fmt::print(stderr, "Training using device : {}\n", NeuralNetworkImpl::device.str());
@@ -84,38 +86,58 @@ int main(int argc, char * argv[])
   ReadingMachine machine(machinePath.string());

   BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
+  BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile);
   SubConfig config(goldConfig);

   Trainer trainer(machine);
   trainer.createDataset(config, debug);
+  if (!computeDevScore)
+  {
+    SubConfig devConfig(devGoldConfig);
+    trainer.createDevDataset(devConfig, debug);
+  }

   Decoder decoder(machine);
-  BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile);

-  float bestDevScore = 0;
+  float bestDevScore = computeDevScore ? 0 : 100;

   for (int i = 0; i < nbEpoch; i++)
   {
     float loss = trainer.epoch(printAdvancement);
     machine.getStrategy().reset();
-    auto devConfig = devGoldConfig;
-    if (debug)
-      fmt::print(stderr, "Decoding dev :\n");
-    decoder.decode(devConfig, 1, debug, printAdvancement);
-    machine.getStrategy().reset();
-    decoder.evaluate(devConfig, modelPath, devTsvFile);
-    std::vector<std::pair<float,std::string>> devScores = decoder.getF1Scores(machine.getPredicted());
+    std::vector<std::pair<float,std::string>> devScores;
+    if (computeDevScore)
+    {
+      auto devConfig = devGoldConfig;
+      decoder.decode(devConfig, 1, debug, printAdvancement);
+      machine.getStrategy().reset();
+      decoder.evaluate(devConfig, modelPath, devTsvFile);
+      devScores = decoder.getF1Scores(machine.getPredicted());
+    }
+    else
+    {
+      float devLoss = trainer.evalOnDev(printAdvancement);
+      devScores.emplace_back(std::make_pair(devLoss, "Loss"));
+    }

     std::string devScoresStr = "";
     float devScoreMean = 0;
     for (auto & score : devScores)
     {
-      devScoresStr += fmt::format("{}({:5.2f}%),", score.second, score.first);
+      if (computeDevScore)
+        devScoresStr += fmt::format("{}({:5.2f}{}),", score.second, score.first, computeDevScore ? "%" : "");
+      else
+        devScoresStr += fmt::format("{}({:6.1f}{}),", score.second, score.first, computeDevScore ? "%" : "");
       devScoreMean += score.first;
     }
     if (!devScoresStr.empty())
       devScoresStr.pop_back();
     devScoreMean /= devScores.size();
     bool saved = devScoreMean > bestDevScore;
+    if (!computeDevScore)
+      saved = devScoreMean < bestDevScore;
     if (saved)
     {
       bestDevScore = devScoreMean;
......
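Editor's note: the elided tail of this hunk records the new best when saved is true. With the new option the meaning of "improves" flips: F1-style dev scores are maximized (bestDevScore starts at 0), while dev loss is minimized (it starts at 100, a sentinel that assumes the first dev loss stays below 100). The rule condenses to this sketch (hypothetical helper, not in the commit):

// Hypothetical condensation of the saved/bestDevScore logic above.
bool improved(float devScoreMean, float bestDevScore, bool computeDevScore)
{
  return computeDevScore ? devScoreMean > bestDevScore   // F1: higher is better
                         : devScoreMean < bestDevScore;  // loss: lower is better
}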