Commit 344bd63a authored by Franck Dary

Added Adagrad optimizer

parent aa79c9cf
@@ -9,11 +9,16 @@ class Classifier
 {
   private :
+  std::vector<std::string> knownOptimizers{
+    "Adam {lr beta1 beta2 eps decay amsgrad}",
+    "Adagrad {lr lr_decay weight_decay eps}",
+  };
   std::string name;
   std::map<std::string, std::unique_ptr<TransitionSet>> transitionSets;
   std::map<std::string, float> lossMultipliers;
   std::shared_ptr<NeuralNetworkImpl> nn;
-  std::unique_ptr<torch::optim::Adam> optimizer;
+  std::unique_ptr<torch::optim::Optimizer> optimizer;
   std::string optimizerType, optimizerParameters;
   std::string state;
@@ -32,7 +37,7 @@ class Classifier
   void resetOptimizer();
   void loadOptimizer(std::filesystem::path path);
   void saveOptimizer(std::filesystem::path path);
-  torch::optim::Adam & getOptimizer();
+  torch::optim::Optimizer & getOptimizer();
   void setState(const std::string & state);
   float getLossMultiplier();
 };
@@ -117,7 +117,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
     optimizerType = sm.str(1);
     optimizerParameters = sm.str(2);
   }))
-    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}"));
+    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) " + util::join("|", knownOptimizers)));
 }

 void Classifier::loadOptimizer(std::filesystem::path path)
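The error message is now assembled from knownOptimizers with util::join instead of repeating a hard-coded Adam signature, so a future optimizer only needs to be listed once. util::join itself is not part of this diff; a plausible sketch of the behavior the call site relies on (the real implementation lives in this repository's util module and may differ):

#include <string>
#include <vector>

// Concatenates the parts with the given separator between consecutive entries.
std::string join(const std::string & separator, const std::vector<std::string> & parts)
{
  std::string result;
  for (std::size_t i = 0; i < parts.size(); ++i)
  {
    if (i != 0)
      result += separator;
    result += parts[i];
  }
  return result;
}

// join("|", knownOptimizers) would yield
// "Adam {lr beta1 beta2 eps decay amsgrad}|Adagrad {lr lr_decay weight_decay eps}"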
@@ -130,7 +130,7 @@ void Classifier::saveOptimizer(std::filesystem::path path)
   torch::save(*optimizer, path);
 }

-torch::optim::Adam & Classifier::getOptimizer()
+torch::optim::Optimizer & Classifier::getOptimizer()
 {
   return *optimizer;
 }
@@ -162,7 +162,7 @@ void Classifier::initModular(const std::vector<std::string> & definition, std::s
 void Classifier::resetOptimizer()
 {
-  std::string expected = "expected '(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}'";
std::string expected = "expected '(Optimizer :) " + util::join("| ", knownOptimizers);
   if (optimizerType == "Adam")
   {
     auto splited = util::split(optimizerParameters, ' ');
@@ -171,6 +171,14 @@
     optimizer.reset(new torch::optim::Adam(getNN()->parameters(), torch::optim::AdamOptions(std::stof(splited[0])).amsgrad(splited.back() == "true").betas({std::stof(splited[1]),std::stof(splited[2])}).eps(std::stof(splited[3])).weight_decay(std::stof(splited[4]))));
   }
+  else if (optimizerType == "Adagrad")
+  {
+    auto splited = util::split(optimizerParameters, ' ');
+    if (splited.size() != 4)
+      util::myThrow(expected);
+    optimizer.reset(new torch::optim::Adagrad(getNN()->parameters(), torch::optim::AdagradOptions(std::stof(splited[0])).lr_decay(std::stof(splited[1])).eps(std::stof(splited[3])).weight_decay(std::stof(splited[2]))));
+  }
+  else
+    util::myThrow(expected);
 }
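After this change a model definition can select either optimizer. Judging from the expected-format message, accepted lines presumably look like the following (numeric values are illustrative, not defaults from this commit):

Optimizer : Adam {0.001 0.9 0.999 1e-8 0 false}
Optimizer : Adagrad {0.01 0 0 1e-10}

For Adagrad, note that the parser passes splited[2] to weight_decay and splited[3] to eps, matching the declared order lr lr_decay weight_decay eps.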