diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index 7e423a3812ae3145759733a7feb1de080b3dbf71..2bb3af9347d2e25f66ac094738cec22978556928 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -12,6 +12,7 @@ class Classifier
   std::string name;
   std::unique_ptr<TransitionSet> transitionSet;
   std::shared_ptr<NeuralNetworkImpl> nn;
+  std::unique_ptr<torch::optim::Adam> optimizer;
 
   private :
 
@@ -25,6 +26,9 @@ class Classifier
   NeuralNetwork & getNN();
   const std::string & getName() const;
   int getNbParameters() const;
+  void loadOptimizer(std::filesystem::path path);
+  void saveOptimizer(std::filesystem::path path);
+  torch::optim::Adam & getOptimizer();
 };
 
 #endif
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 4e7212fa85e63ddaefe145238068bb0ddf888920..957e21f164d573519029274ff3dba0eab1e7709d 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -65,6 +65,22 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
     util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, LSTM'", networkType));
 
   this->nn->to(NeuralNetworkImpl::device);
+
+  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Optimizer :|)(?:(?:\\s|\\t)*)(.*) \\{(.*)\\}"), definition[curIndex], [&curIndex,this](auto sm)
+        {
+          std::string expected = "expected '(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}'";
+          if (sm.str(1) == "Adam")
+          {
+            auto splited = util::split(sm.str(2), ' ');
+            if (splited.size() != 6 or (splited.back() != "false" and splited.back() != "true"))
+              util::myThrow(expected);
+
+            optimizer.reset(new torch::optim::Adam(getNN()->parameters(), torch::optim::AdamOptions(std::stof(splited[0])).amsgrad(splited.back() == "true").beta1(std::stof(splited[1])).beta2(std::stof(splited[2])).eps(std::stof(splited[3])).weight_decay(std::stof(splited[4]))));
+          }
+          else
+            util::myThrow(expected);
+        }))
+    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}"));
 }
 
 void Classifier::initLSTM(const std::vector<std::string> & definition, std::size_t & curIndex)
@@ -272,3 +288,19 @@ void Classifier::initLSTM(const std::vector<std::string> & definition, std::siz
   this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), unknownValueThreshold, bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, rawInputLeftWindow, rawInputRightWindow, embeddingsSize, mlp, contextLSTMSize, focusedLSTMSize, rawInputLSTMSize, splitTransLSTMSize, nbLayers, bilstm, lstmDropout, treeEmbeddingColumns, treeEmbeddingBuffer, treeEmbeddingStack, treeEmbeddingNbElems, treeEmbeddingSize, embeddingsDropout));
 }
 
+void Classifier::loadOptimizer(std::filesystem::path path)
+{
+  torch::load(*optimizer, path);
+}
+
+void Classifier::saveOptimizer(std::filesystem::path path)
+{
+  torch::save(*optimizer, path);
+}
+
+torch::optim::Adam & Classifier::getOptimizer()
+{
+  return *optimizer;
+}
+
+
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index a69b9b9a3a0f926b1117ab02929a43436e67b2f6..7994a2f200b323bd23eeab8a08859c258dbf89a1 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -19,7 +19,6 @@ class Trainer
   std::unique_ptr<Dataset> devDataset{nullptr};
   DataLoader dataLoader{nullptr};
   DataLoader devDataLoader{nullptr};
-  std::unique_ptr<torch::optim::Adam> optimizer;
   std::size_t epochNumber{0};
   int batchSize;
 
@@ -36,8 +35,6 @@
   void createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
   float epoch(bool printAdvancement);
   float evalOnDev(bool printAdvancement);
-  void loadOptimizer(std::filesystem::path path);
-  void saveOptimizer(std::filesystem::path path);
 };
 
 #endif
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 1278b6d8835fe8a80c4f0bc5fbcb7426786f542d..72ff7432e2e354b7adb9e4c5dd7b9477cb27e868 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -157,7 +157,7 @@ int MacaonTrain::main()
   auto optimizerCheckpoint = machinePath.parent_path() / "optimizer.pt";
 
   if (std::filesystem::exists(trainInfos))
-    trainer.loadOptimizer(optimizerCheckpoint);
+    machine.getClassifier()->loadOptimizer(optimizerCheckpoint);
 
   for (; currentEpoch < nbEpoch; currentEpoch++)
   {
@@ -204,7 +204,7 @@
       machine.saveBest();
     }
     machine.saveLast();
-    trainer.saveOptimizer(optimizerCheckpoint);
+    machine.getClassifier()->saveOptimizer(optimizerCheckpoint);
     if (printAdvancement)
       fmt::print(stderr, "\r{:80}\r", "");
     std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.4f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 16918e25469d7396c05d1067bc71d0bff7063b2f..d6101221408b3795260a23e08a1e3c81de3b2deb 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -14,9 +14,6 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem
   trainDataset.reset(new Dataset(dir));
 
   dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
-
-  if (optimizer.get() == nullptr)
-    optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.0005).amsgrad(true).beta1(0.9).beta2(0.999)));
 }
 
 void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
@@ -184,7 +181,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
   for (auto & batch : *loader)
   {
     if (train)
-      optimizer->zero_grad();
+      machine.getClassifier()->getOptimizer().zero_grad();
 
     auto data = batch.first;
     auto labels = batch.second;
@@ -205,7 +202,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
     if (train)
     {
       loss.backward();
-      optimizer->step();
+      machine.getClassifier()->getOptimizer().step();
     }
 
     totalNbExamplesProcessed += torch::numel(labels);
@@ -245,13 +242,3 @@ float Trainer::evalOnDev(bool printAdvancement)
   return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
 }
 
-void Trainer::loadOptimizer(std::filesystem::path path)
-{
-  torch::load(*optimizer, path);
-}
-
-void Trainer::saveOptimizer(std::filesystem::path path)
-{
-  torch::save(*optimizer, path);
-}
-