diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index 4108d3d310c43931e7125b3163a1316d62cfa281..ed41b2937a07322912f9516283bb19937f8d097b 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -9,11 +9,16 @@ class Classifier
 {
   private :
 
+  std::vector<std::string> knownOptimizers{
+    "Adam {lr beta1 beta2 eps decay amsgrad}",
+    "Adagrad {lr lr_decay weight_decay eps}",
+  };
+
   std::string name;
   std::map<std::string, std::unique_ptr<TransitionSet>> transitionSets;
   std::map<std::string, float> lossMultipliers;
   std::shared_ptr<NeuralNetworkImpl> nn;
-  std::unique_ptr<torch::optim::Adam> optimizer;
+  std::unique_ptr<torch::optim::Optimizer> optimizer;
   std::string optimizerType, optimizerParameters;
   std::string state;
 
@@ -32,7 +37,7 @@ class Classifier
   void resetOptimizer();
   void loadOptimizer(std::filesystem::path path);
   void saveOptimizer(std::filesystem::path path);
-  torch::optim::Adam & getOptimizer();
+  torch::optim::Optimizer & getOptimizer();
   void setState(const std::string & state);
   float getLossMultiplier();
 };
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index b7cc0c4da15e8aca82656eb3dbeab26e47279a95..753c40e35bed8ce60b2245004027feae05e69a11 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -117,7 +117,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
         optimizerType = sm.str(1);
         optimizerParameters = sm.str(2);
       }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}"));
+    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) " + util::join("|", knownOptimizers)));
 }
 
 void Classifier::loadOptimizer(std::filesystem::path path)
@@ -130,7 +130,7 @@ void Classifier::saveOptimizer(std::filesystem::path path)
   torch::save(*optimizer, path);
 }
 
-torch::optim::Adam & Classifier::getOptimizer()
+torch::optim::Optimizer & Classifier::getOptimizer()
 {
   return *optimizer;
 }
@@ -162,7 +162,7 @@ void Classifier::initModular(const std::vector<std::string> & definition, std::s
 
 void Classifier::resetOptimizer()
 {
-  std::string expected = "expected '(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}'";
+  std::string expected = "expected '(Optimizer :) " + util::join("|", knownOptimizers) + "'";
   if (optimizerType == "Adam")
   {
     auto splited = util::split(optimizerParameters, ' ');
@@ -171,6 +171,14 @@ void Classifier::resetOptimizer()
 
     optimizer.reset(new torch::optim::Adam(getNN()->parameters(), torch::optim::AdamOptions(std::stof(splited[0])).amsgrad(splited.back() == "true").betas({std::stof(splited[1]),std::stof(splited[2])}).eps(std::stof(splited[3])).weight_decay(std::stof(splited[4]))));
   }
+  else if (optimizerType == "Adagrad")
+  {
+    auto splited = util::split(optimizerParameters, ' ');
+    if (splited.size() != 4)
+      util::myThrow(expected);
+
+    optimizer.reset(new torch::optim::Adagrad(getNN()->parameters(), torch::optim::AdagradOptions(std::stof(splited[0])).lr_decay(std::stof(splited[1])).eps(std::stof(splited[3])).weight_decay(std::stof(splited[2]))));
+  }
   else
     util::myThrow(expected);
 }
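Not part of the patch itself: a minimal standalone sketch of the design choice above. Because `torch::optim::Adam` and `torch::optim::Adagrad` both derive from `torch::optim::Optimizer`, one base-class pointer can own whichever optimizer the configuration names, which is what lets `resetOptimizer()` simply branch on `optimizerType`. The `torch::nn::Linear` module and the hyper-parameter values below are hypothetical placeholders (the real code passes `getNN()->parameters()`), assuming the LibTorch C++ API:

```cpp
#include <torch/torch.h>
#include <memory>

int main()
{
  // Hypothetical stand-in for the Classifier's network.
  torch::nn::Linear model(4, 2);

  // Both concrete optimizers derive from torch::optim::Optimizer, so a
  // single base-class pointer can own either one, as the Classifier's
  // `optimizer` member now does.
  std::unique_ptr<torch::optim::Optimizer> optimizer;

  bool useAdagrad = true; // would come from optimizerType in the patch
  if (useAdagrad)
    optimizer.reset(new torch::optim::Adagrad(
        model->parameters(),
        torch::optim::AdagradOptions(0.01).lr_decay(0.0).weight_decay(0.0).eps(1e-10)));
  else
    optimizer.reset(new torch::optim::Adam(
        model->parameters(),
        torch::optim::AdamOptions(0.001).betas({0.9, 0.999}).eps(1e-8).amsgrad(false)));

  // The training loop only needs the base-class interface.
  optimizer->zero_grad();
  auto loss = model->forward(torch::randn({8, 4})).sum();
  loss.backward();
  optimizer->step();
}
```

With the patch applied, a machine definition would presumably select the new branch with a line of the form `Optimizer : Adagrad {lr lr_decay weight_decay eps}`, matching the new entry in `knownOptimizers`.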