diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 147396e0bd5eccb60c4740bca66f58eedc1a56a5..cb9937edbf78a1bb657a81c544468a8257f3c708 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -7,7 +7,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement) { - machine.getClassifier()->getNN()->train(false); + torch::AutoGradMode useGrad(false); config.addPredicted(machine.getPredicted()); constexpr int printInterval = 50; @@ -88,8 +88,6 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool if (debug) fmt::print(stderr, "Forcing EOS transition\n"); } - - machine.getClassifier()->getNN()->train(true); } float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index 6e889171c5f1fa461f333605446fb8545d287270..e04f3e37dcee7bd29ea47dd226d144d9574d1e40 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -16,16 +16,24 @@ class Trainer ReadingMachine & machine; DataLoader dataLoader{nullptr}; + DataLoader devDataLoader{nullptr}; std::unique_ptr<torch::optim::Adam> optimizer; std::size_t epochNumber{0}; int batchSize{50}; int nbExamples{0}; + private : + + void extractExamples(SubConfig & config, bool debug, std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes); + float processDataset(DataLoader & loader, bool train, bool printAdvancement); + public : Trainer(ReadingMachine & machine); void createDataset(SubConfig & goldConfig, bool debug); + void createDevDataset(SubConfig & goldConfig, bool debug); float epoch(bool printAdvancement); + float evalOnDev(bool printAdvancement); }; #endif diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 106df3ffbdf3d90ab397d78a6fbff11619e77dc9..e2572872814364cb643ace0a882f8ec984eb4be9 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -7,12 +7,33 @@ Trainer::Trainer(ReadingMachine & machine) : machine(machine) void Trainer::createDataset(SubConfig & config, bool debug) { - config.addPredicted(machine.getPredicted()); - config.setState(machine.getStrategy().getInitialState()); + std::vector<torch::Tensor> contexts; + std::vector<torch::Tensor> classes; + + extractExamples(config, debug, contexts, classes); + nbExamples = classes.size(); + + dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); + + optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.001).amsgrad(true).beta1(0.9).beta2(0.999))); +} + +void Trainer::createDevDataset(SubConfig & config, bool debug) +{ std::vector<torch::Tensor> contexts; std::vector<torch::Tensor> classes; + extractExamples(config, debug, contexts, classes); + + devDataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); +} + +void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch::Tensor> & contexts, std::vector<torch::Tensor> & classes) +{ + config.addPredicted(machine.getPredicted()); + config.setState(machine.getStrategy().getInitialState()); + while (true) { if (debug) @@ -59,15 +80,9 @@ void Trainer::createDataset(SubConfig & config, bool debug) if (config.needsUpdate()) config.update(); } - - nbExamples = classes.size(); - - dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); - - optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.001).amsgrad(true).beta1(0.9).beta2(0.999))); } -float Trainer::epoch(bool printAdvancement) +float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement) { constexpr int printInterval = 50; int nbExamplesProcessed = 0; @@ -75,13 +90,16 @@ float Trainer::epoch(bool printAdvancement) float lossSoFar = 0.0; int currentBatchNumber = 0; + torch::AutoGradMode useGrad(train); + auto lossFct = torch::nn::CrossEntropyLoss(); auto pastTime = std::chrono::high_resolution_clock::now(); - for (auto & batch : *dataLoader) + for (auto & batch : *loader) { - optimizer->zero_grad(); + if (train) + optimizer->zero_grad(); auto data = batch.data; auto labels = batch.target.squeeze(); @@ -99,8 +117,11 @@ float Trainer::epoch(bool printAdvancement) lossSoFar += loss.item<float>(); } catch(std::exception & e) {util::myThrow(e.what());} - loss.backward(); - optimizer->step(); + if (train) + { + loss.backward(); + optimizer->step(); + } if (printAdvancement) { @@ -122,3 +143,13 @@ float Trainer::epoch(bool printAdvancement) return totalLoss; } +float Trainer::epoch(bool printAdvancement) +{ + return processDataset(dataLoader, true, printAdvancement); +} + +float Trainer::evalOnDev(bool printAdvancement) +{ + return processDataset(devDataLoader, false, printAdvancement); +} + diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp index dc48e459b3357c20e07ad772acc7af839ea262b5..602553786a7c2d9b199914113ca45ad12e89d735 100644 --- a/trainer/src/macaon_train.cpp +++ b/trainer/src/macaon_train.cpp @@ -24,6 +24,7 @@ po::options_description getOptionsDescription() opt.add_options() ("debug,d", "Print debuging infos on stderr") ("silent", "Don't print speed and progress") + ("devScore", "Compute score on dev instead of loss (slower)") ("trainTXT", po::value<std::string>()->default_value(""), "Raw text file of the training corpus") ("devTSV", po::value<std::string>()->default_value(""), @@ -75,6 +76,7 @@ int main(int argc, char * argv[]) auto nbEpoch = variables["nbEpochs"].as<int>(); bool debug = variables.count("debug") == 0 ? false : true; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; + bool computeDevScore = variables.count("devScore") == 0 ? false : true; fmt::print(stderr, "Training using device : {}\n", NeuralNetworkImpl::device.str()); @@ -84,38 +86,58 @@ int main(int argc, char * argv[]) ReadingMachine machine(machinePath.string()); BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile); + BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile); SubConfig config(goldConfig); Trainer trainer(machine); trainer.createDataset(config, debug); + if (!computeDevScore) + { + SubConfig devConfig(devGoldConfig); + trainer.createDevDataset(devConfig, debug); + } Decoder decoder(machine); - BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile); - float bestDevScore = 0; + float bestDevScore = computeDevScore ? 0 : 100; for (int i = 0; i < nbEpoch; i++) { float loss = trainer.epoch(printAdvancement); machine.getStrategy().reset(); - auto devConfig = devGoldConfig; if (debug) fmt::print(stderr, "Decoding dev :\n"); - decoder.decode(devConfig, 1, debug, printAdvancement); - machine.getStrategy().reset(); - decoder.evaluate(devConfig, modelPath, devTsvFile); - std::vector<std::pair<float,std::string>> devScores = decoder.getF1Scores(machine.getPredicted()); + std::vector<std::pair<float,std::string>> devScores; + if (computeDevScore) + { + auto devConfig = devGoldConfig; + decoder.decode(devConfig, 1, debug, printAdvancement); + machine.getStrategy().reset(); + decoder.evaluate(devConfig, modelPath, devTsvFile); + devScores = decoder.getF1Scores(machine.getPredicted()); + } + else + { + float devLoss = trainer.evalOnDev(printAdvancement); + devScores.emplace_back(std::make_pair(devLoss, "Loss")); + } + std::string devScoresStr = ""; float devScoreMean = 0; for (auto & score : devScores) { - devScoresStr += fmt::format("{}({:5.2f}%),", score.second, score.first); + if (computeDevScore) + devScoresStr += fmt::format("{}({:5.2f}{}),", score.second, score.first, computeDevScore ? "%" : ""); + else + devScoresStr += fmt::format("{}({:6.1f}{}),", score.second, score.first, computeDevScore ? "%" : ""); devScoreMean += score.first; } if (!devScoresStr.empty()) devScoresStr.pop_back(); devScoreMean /= devScores.size(); bool saved = devScoreMean > bestDevScore; + if (!computeDevScore) + saved = devScoreMean < bestDevScore; if (saved) { bestDevScore = devScoreMean;