From e350f836f250184782e9667d1de28491d96aec2b Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 16 Apr 2020 18:41:04 +0200
Subject: [PATCH] Optimizer is now a member of Classifier, and its parameters
 can be given in the rm file

---
 reading_machine/include/Classifier.hpp |  4 ++++
 reading_machine/src/Classifier.cpp     | 32 ++++++++++++++++++++++++++
 trainer/include/Trainer.hpp            |  3 ---
 trainer/src/MacaonTrain.cpp            |  4 ++--
 trainer/src/Trainer.cpp                | 17 ++------------
 5 files changed, 40 insertions(+), 20 deletions(-)

diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index 7e423a3..2bb3af9 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -12,6 +12,7 @@ class Classifier
   std::string name;
   std::unique_ptr<TransitionSet> transitionSet;
   std::shared_ptr<NeuralNetworkImpl> nn;
+  std::unique_ptr<torch::optim::Adam> optimizer;
 
   private :
 
@@ -25,6 +26,9 @@ class Classifier
   NeuralNetwork & getNN();
   const std::string & getName() const;
   int getNbParameters() const;
+  void loadOptimizer(std::filesystem::path path);
+  void saveOptimizer(std::filesystem::path path);
+  torch::optim::Adam & getOptimizer();
 };
 
 #endif
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 4e7212f..957e21f 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -65,6 +65,22 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
     util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, LSTM'", networkType));
 
   this->nn->to(NeuralNetworkImpl::device);
+
+  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Optimizer :|)(?:(?:\\s|\\t)*)(.*) \\{(.*)\\}"), definition[curIndex], [&curIndex,this](auto sm)
+  {
+    std::string expected = "expected '(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}'";
+    if (sm.str(1) == "Adam")
+    {
+      auto splited = util::split(sm.str(2), ' ');
+      if (splited.size() != 6 or (splited.back() != "false" and splited.back() != "true"))
+        util::myThrow(expected);
+
+      optimizer.reset(new torch::optim::Adam(getNN()->parameters(), torch::optim::AdamOptions(std::stof(splited[0])).amsgrad(splited.back() == "true").beta1(std::stof(splited[1])).beta2(std::stof(splited[2])).eps(std::stof(splited[3])).weight_decay(std::stof(splited[4]))));
+    }
+    else
+      util::myThrow(expected);
+  }))
+    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}"));
 }
 
 void Classifier::initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex)
@@ -272,3 +288,19 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::size
   this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout));
 }
 
+void Classifier::loadOptimizer(std::filesystem::path path)
+{
+  torch::load(*optimizer, path);
+}
+
+void Classifier::saveOptimizer(std::filesystem::path path)
+{
+  torch::save(*optimizer, path);
+}
+
+torch::optim::Adam & Classifier::getOptimizer()
+{
+  return *optimizer;
+}
+
+
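For reference, an optimizer line that the parsing code above accepts would look like this in an rm file. The values are illustrative, not prescribed by the patch: lr, beta1, beta2 and amsgrad mirror the constants previously hardcoded in Trainer.cpp, while the eps and decay values are hypothetical, since the old code left them at LibTorch's defaults:

    Optimizer : Adam {0.0005 0.9 0.999 1e-08 0 true}

Per the regex and the checks above, the "Optimizer :" prefix is optional, exactly six whitespace-separated values are required, and the last one must be the literal "true" or "false".
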
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index a69b9b9..7994a2f 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -19,7 +19,6 @@ class Trainer
   std::unique_ptr<Dataset> devDataset{nullptr};
   DataLoader dataLoader{nullptr};
   DataLoader devDataLoader{nullptr};
-  std::unique_ptr<torch::optim::Adam> optimizer;
   std::size_t epochNumber{0};
   int batchSize;
 
@@ -36,8 +35,6 @@ class Trainer
   void createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
   float epoch(bool printAdvancement);
   float evalOnDev(bool printAdvancement);
-  void loadOptimizer(std::filesystem::path path);
-  void saveOptimizer(std::filesystem::path path);
 };
 
 #endif
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 1278b6d..72ff743 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -157,7 +157,7 @@ int MacaonTrain::main()
   auto optimizerCheckpoint = machinePath.parent_path() / "optimizer.pt";
 
   if (std::filesystem::exists(trainInfos))
-    trainer.loadOptimizer(optimizerCheckpoint);
+    machine.getClassifier()->loadOptimizer(optimizerCheckpoint);
 
   for (; currentEpoch < nbEpoch; currentEpoch++)
   {
@@ -204,7 +204,7 @@ int MacaonTrain::main()
       machine.saveBest();
     }
     machine.saveLast();
-    trainer.saveOptimizer(optimizerCheckpoint);
+    machine.getClassifier()->saveOptimizer(optimizerCheckpoint);
     if (printAdvancement)
       fmt::print(stderr, "\r{:80}\r", "");
     std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.4f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
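A note on the checkpointing above: torch::save and torch::load on an optimizer persist Adam's internal state (step counts and the running first/second moment estimates), so resuming from optimizer.pt continues training exactly where it stopped instead of restarting the moments from zero. A minimal self-contained sketch of that round trip, assuming a LibTorch contemporary with this patch (newer releases replace the beta1()/beta2() setters used in Classifier.cpp with a single betas() setter, but the save/load calls are unchanged):

    #include <torch/torch.h>

    int main()
    {
      // Stand-in module; the real code passes the classifier network's parameters.
      torch::nn::Linear model(4, 2);
      torch::optim::Adam optimizer(model->parameters(), torch::optim::AdamOptions(0.0005).amsgrad(true));

      // Serialize the optimizer state to disk...
      torch::save(optimizer, "optimizer.pt");

      // ...and restore it into an optimizer built over the same parameters.
      torch::load(optimizer, "optimizer.pt");
    }
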
"SAVED" : ""); diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 16918e2..d610122 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -14,9 +14,6 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem trainDataset.reset(new Dataset(dir)); dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); - - if (optimizer.get() == nullptr) - optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.0005).amsgrad(true).beta1(0.9).beta2(0.999))); } void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval) @@ -184,7 +181,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance for (auto & batch : *loader) { if (train) - optimizer->zero_grad(); + machine.getClassifier()->getOptimizer().zero_grad(); auto data = batch.first; auto labels = batch.second; @@ -205,7 +202,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance if (train) { loss.backward(); - optimizer->step(); + machine.getClassifier()->getOptimizer().step(); } totalNbExamplesProcessed += torch::numel(labels); @@ -245,13 +242,3 @@ float Trainer::evalOnDev(bool printAdvancement) return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value()); } -void Trainer::loadOptimizer(std::filesystem::path path) -{ - torch::load(*optimizer, path); -} - -void Trainer::saveOptimizer(std::filesystem::path path) -{ - torch::save(*optimizer, path); -} - -- GitLab