Skip to content
Snippets Groups Projects
Commit 344bd63a authored by Franck Dary's avatar Franck Dary
Browse files

Added Adagrad optimizer

parent aa79c9cf
No related branches found
No related tags found
No related merge requests found
...@@ -9,11 +9,16 @@ class Classifier ...@@ -9,11 +9,16 @@ class Classifier
{ {
private : private :
std::vector<std::string> knownOptimizers{
"Adam {lr beta1 beta2 eps decay amsgrad}",
"Adagrad {lr lr_decay weight_decay eps}",
};
std::string name; std::string name;
std::map<std::string, std::unique_ptr<TransitionSet>> transitionSets; std::map<std::string, std::unique_ptr<TransitionSet>> transitionSets;
std::map<std::string, float> lossMultipliers; std::map<std::string, float> lossMultipliers;
std::shared_ptr<NeuralNetworkImpl> nn; std::shared_ptr<NeuralNetworkImpl> nn;
std::unique_ptr<torch::optim::Adam> optimizer; std::unique_ptr<torch::optim::Optimizer> optimizer;
std::string optimizerType, optimizerParameters; std::string optimizerType, optimizerParameters;
std::string state; std::string state;
...@@ -32,7 +37,7 @@ class Classifier ...@@ -32,7 +37,7 @@ class Classifier
void resetOptimizer(); void resetOptimizer();
void loadOptimizer(std::filesystem::path path); void loadOptimizer(std::filesystem::path path);
void saveOptimizer(std::filesystem::path path); void saveOptimizer(std::filesystem::path path);
torch::optim::Adam & getOptimizer(); torch::optim::Optimizer & getOptimizer();
void setState(const std::string & state); void setState(const std::string & state);
float getLossMultiplier(); float getLossMultiplier();
}; };
......
...@@ -117,7 +117,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) ...@@ -117,7 +117,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
optimizerType = sm.str(1); optimizerType = sm.str(1);
optimizerParameters = sm.str(2); optimizerParameters = sm.str(2);
})) }))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}")); util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) " + util::join("|", knownOptimizers)));
} }
void Classifier::loadOptimizer(std::filesystem::path path) void Classifier::loadOptimizer(std::filesystem::path path)
...@@ -130,7 +130,7 @@ void Classifier::saveOptimizer(std::filesystem::path path) ...@@ -130,7 +130,7 @@ void Classifier::saveOptimizer(std::filesystem::path path)
torch::save(*optimizer, path); torch::save(*optimizer, path);
} }
torch::optim::Adam & Classifier::getOptimizer() torch::optim::Optimizer & Classifier::getOptimizer()
{ {
return *optimizer; return *optimizer;
} }
...@@ -162,7 +162,7 @@ void Classifier::initModular(const std::vector<std::string> & definition, std::s ...@@ -162,7 +162,7 @@ void Classifier::initModular(const std::vector<std::string> & definition, std::s
void Classifier::resetOptimizer() void Classifier::resetOptimizer()
{ {
std::string expected = "expected '(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}'"; std::string expected = "expected '(Optimizer :) " + util::join("| ", knownOptimizers);
if (optimizerType == "Adam") if (optimizerType == "Adam")
{ {
auto splited = util::split(optimizerParameters, ' '); auto splited = util::split(optimizerParameters, ' ');
...@@ -171,6 +171,14 @@ void Classifier::resetOptimizer() ...@@ -171,6 +171,14 @@ void Classifier::resetOptimizer()
optimizer.reset(new torch::optim::Adam(getNN()->parameters(), torch::optim::AdamOptions(std::stof(splited[0])).amsgrad(splited.back() == "true").betas({std::stof(splited[1]),std::stof(splited[2])}).eps(std::stof(splited[3])).weight_decay(std::stof(splited[4])))); optimizer.reset(new torch::optim::Adam(getNN()->parameters(), torch::optim::AdamOptions(std::stof(splited[0])).amsgrad(splited.back() == "true").betas({std::stof(splited[1]),std::stof(splited[2])}).eps(std::stof(splited[3])).weight_decay(std::stof(splited[4]))));
} }
else if (optimizerType == "Adagrad")
{
auto splited = util::split(optimizerParameters, ' ');
if (splited.size() != 4)
util::myThrow(expected);
optimizer.reset(new torch::optim::Adagrad(getNN()->parameters(), torch::optim::AdagradOptions(std::stof(splited[0])).lr_decay(std::stof(splited[1])).eps(std::stof(splited[3])).weight_decay(std::stof(splited[2]))));
}
else else
util::myThrow(expected); util::myThrow(expected);
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment