diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index 29ff70469d333299dbbc553d44861976c0e800db..f3f8726ba0ae557ed7de2740fa34d6d228143b02 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -13,6 +13,7 @@ class Classifier
   std::map<std::string,std::unique_ptr<TransitionSet>> transitionSets;
   std::shared_ptr<NeuralNetworkImpl> nn;
   std::unique_ptr<torch::optim::Adam> optimizer;
+  std::string optimizerType, optimizerParameters;
   std::string state;
 
   private :
@@ -27,6 +28,7 @@ class Classifier
   NeuralNetwork & getNN();
   const std::string & getName() const;
   int getNbParameters() const;
+  void resetOptimizer();
   void loadOptimizer(std::filesystem::path path);
   void saveOptimizer(std::filesystem::path path);
   torch::optim::Adam & getOptimizer();
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index ee9e1e77801f3db44e72f4e49005d08a9cb7fa18..b33a2dfcf338ce26ba5439e15a8921234c4a86a4 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -92,17 +92,8 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
   if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Optimizer :|)(?:(?:\\s|\\t)*)(.*) \\{(.*)\\}"), definition[curIndex], [&curIndex,this](auto sm)
         {
-          std::string expected = "expected '(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}'";
-          if (sm.str(1) == "Adam")
-          {
-            auto splited = util::split(sm.str(2), ' ');
-            if (splited.size() != 6 or (splited.back() != "false" and splited.back() != "true"))
-              util::myThrow(expected);
-
-            optimizer.reset(new torch::optim::Adam(getNN()->parameters(), torch::optim::AdamOptions(std::stof(splited[0])).amsgrad(splited.back() == "true").betas({std::stof(splited[1]),std::stof(splited[2])}).eps(std::stof(splited[3])).weight_decay(std::stof(splited[4]))));
-          }
-          else
-            util::myThrow(expected);
+          optimizerType = sm.str(1);
+          optimizerParameters = sm.str(2);
         }))
     util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}"));
 }
 
@@ -147,3 +138,18 @@ void Classifier::initModular(const std::vector<std::string> & definition, std::s
   this->nn.reset(new ModularNetworkImpl(nbOutputsPerState, modulesDefinitions));
 }
 
+void Classifier::resetOptimizer()
+{
+  std::string expected = "expected '(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}'";
+  if (optimizerType == "Adam")
+  {
+    auto splited = util::split(optimizerParameters, ' ');
+    if (splited.size() != 6 or (splited.back() != "false" and splited.back() != "true"))
+      util::myThrow(expected);
+
+    optimizer.reset(new torch::optim::Adam(getNN()->parameters(), torch::optim::AdamOptions(std::stof(splited[0])).amsgrad(splited.back() == "true").betas({std::stof(splited[1]),std::stof(splited[2])}).eps(std::stof(splited[3])).weight_decay(std::stof(splited[4]))));
+  }
+  else
+    util::myThrow(expected);
+}
+
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 86acd6b662778321277fbe7e9254fa49147b1d61..6832e3b6c33f310b33d915e758200bbdcc623bcc 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -167,6 +167,7 @@ int MacaonTrain::main()
     if (!computeDevScore)
       trainer.createDevDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, dynamicOracleInterval);
 
+    machine.getClassifier()->resetOptimizer();
     auto optimizerCheckpoint = machinePath.parent_path() / "optimizer.pt";
     if (std::filesystem::exists(trainInfos))
       machine.getClassifier()->loadOptimizer(optimizerCheckpoint);