Commit 7ed12f20 authored by Franck Dary

Added a way to change the optimizer

parent 6bbc03e0
@@ -92,7 +92,7 @@ class MLP
   /// @brief The dynet model containing the parameters to be trained.
   dynet::ParameterCollection model;
   /// @brief The training algorithm that will be used.
-  std::unique_ptr<dynet::AmsgradTrainer> trainer;
+  std::unique_ptr<dynet::Trainer> trainer;
   /// @brief Whether the program is in train mode or not (only in train mode the parameters will be updated).
   bool trainMode;
   /// @brief Must the Layer dropout rate be taken into account during the computations ? Usually it is only during the training step.
@@ -205,6 +205,10 @@ class MLP
   ///
   /// @param output Where the topology will be printed.
   void printTopology(FILE * output);
+  /// @brief Allocate the correct trainer type depending on the program parameters.
+  ///
+  /// @return A pointer to the newly allocated trainer.
+  dynet::Trainer * createTrainer();
 };
 #endif
@@ -111,7 +111,7 @@ MLP::MLP(int nbInputs, const std::string & topology, int nbOutputs)
   layers.emplace_back(layers.back().output_dim, nbOutputs, 0.0, Activation::LINEAR);
-  trainer.reset(new dynet::AmsgradTrainer(model, ProgramParameters::learningRate, ProgramParameters::beta1, ProgramParameters::beta2, ProgramParameters::bias));
+  trainer.reset(createTrainer());
   initDynet();
@@ -124,6 +124,27 @@ MLP::MLP(int nbInputs, const std::string & topology, int nbOutputs)
     addLayerToModel(layer);
 }
 
+dynet::Trainer * MLP::createTrainer()
+{
+  if (!trainMode)
+    return nullptr;
+
+  auto optimizer = noAccentLower(ProgramParameters::optimizer);
+
+  if (optimizer == "amsgrad")
+    return new dynet::AmsgradTrainer(model, ProgramParameters::learningRate, ProgramParameters::beta1, ProgramParameters::beta2, ProgramParameters::bias);
+  else if (optimizer == "adam")
+    return new dynet::AdamTrainer(model, ProgramParameters::learningRate, ProgramParameters::beta1, ProgramParameters::beta2, ProgramParameters::bias);
+  else if (optimizer == "sgd")
+    return new dynet::SimpleSGDTrainer(model, ProgramParameters::learningRate);
+
+  fprintf(stderr, "ERROR (%s) : unknown optimizer \'%s\'. Aborting.\n", ERRINFO, optimizer.c_str());
+  exit(1);
+
+  return nullptr;
+}
+
 void MLP::addLayerToModel(Layer & layer)
 {
   dynet::Parameter W = model.add_parameters({(unsigned)layer.output_dim, (unsigned)layer.input_dim});
@@ -389,7 +410,7 @@ void MLP::loadParameters(const std::string & filename)
 MLP::MLP(const std::string & filename)
 {
   randomSeed = ProgramParameters::seed;
-  trainer.reset(new dynet::AmsgradTrainer(model, ProgramParameters::learningRate, ProgramParameters::beta1, ProgramParameters::beta2, ProgramParameters::bias));
+  trainer.reset(createTrainer());
   initDynet();
   trainMode = false;
...
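The net effect of this file is that MLP now owns its optimizer through the dynet::Trainer base class, so any concrete dynet trainer can be selected at construction time from the program parameters. Below is a minimal standalone sketch of the same idea, not part of the commit: the function name pickTrainer, the smart-pointer return, and the exception-based error handling are illustrative assumptions; the dynet trainer types themselves are the real ones used above.

#include <dynet/dynet.h>
#include <dynet/training.h>
#include <memory>
#include <string>
#include <stdexcept>

// Sketch only: choose a concrete dynet trainer from a string, handing it back
// through the Trainer base class, as the new MLP::trainer member does.
std::unique_ptr<dynet::Trainer> pickTrainer(dynet::ParameterCollection & model,
                                            const std::string & name, float lr)
{
  if (name == "amsgrad")
    return std::make_unique<dynet::AmsgradTrainer>(model, lr);
  if (name == "adam")
    return std::make_unique<dynet::AdamTrainer>(model, lr);
  if (name == "sgd")
    return std::make_unique<dynet::SimpleSGDTrainer>(model, lr);
  throw std::invalid_argument("unknown optimizer '" + name + "'");
}

Whichever branch is taken, the caller only sees a dynet::Trainer and keeps calling trainer->update() as before, which is why the rest of the MLP training code does not need to change.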
@@ -24,6 +24,7 @@ struct ProgramParameters
   static std::string devFilename;
   static std::string devName;
   static std::string lang;
+  static std::string optimizer;
   static int nbIter;
   static int seed;
   static bool removeDuplicates;
...
@@ -19,6 +19,7 @@ std::string ProgramParameters::trainName;
 std::string ProgramParameters::devFilename;
 std::string ProgramParameters::devName;
 std::string ProgramParameters::lang;
+std::string ProgramParameters::optimizer;
 int ProgramParameters::nbIter;
 int ProgramParameters::seed;
 bool ProgramParameters::removeDuplicates;
...
@@ -38,6 +38,8 @@ po::options_description getOptionsDescription()
   opt.add_options()
     ("help,h", "Produce this help message")
     ("debug,d", "Print infos on stderr")
+    ("optimizer", po::value<std::string>()->default_value("amsgrad"),
+      "The learning algorithm to use : amsgrad | adam | sgd")
     ("dev", po::value<std::string>()->default_value(""),
      "Development corpus formated according to the MCD")
     ("lang", po::value<std::string>()->default_value("fr"),
@@ -46,12 +48,6 @@ po::options_description getOptionsDescription()
      "Number of training epochs (iterations)")
     ("lr", po::value<float>()->default_value(0.001),
      "Learning rate of the optimizer")
-    ("b1", po::value<float>()->default_value(0.9),
-     "beta1 parameter for the Amsgtad or Adam optimizer")
-    ("b2", po::value<float>()->default_value(0.999),
-     "beta2 parameter for the Amsgtad or Adam optimizer")
-    ("bias", po::value<float>()->default_value(1e-8),
-     "bias parameter for the Amsgtad or Adam or Adagrad optimizer")
     ("seed,s", po::value<int>()->default_value(100),
      "The random seed that will initialize RNG")
     ("duplicates", po::value<bool>()->default_value(true),
@@ -59,7 +55,16 @@ po::options_description getOptionsDescription()
     ("shuffle", po::value<bool>()->default_value(true),
      "Shuffle examples after each iteration");
 
-  desc.add(req).add(opt);
+  po::options_description ams("Amsgrad family optimizers");
+  ams.add_options()
+    ("b1", po::value<float>()->default_value(0.9),
+     "beta1 parameter for the Amsgtad or Adam optimizer")
+    ("b2", po::value<float>()->default_value(0.999),
+     "beta2 parameter for the Amsgtad or Adam optimizer")
+    ("bias", po::value<float>()->default_value(1e-8),
+     "bias parameter for the Amsgtad or Adam or Adagrad optimizer");
+
+  desc.add(req).add(opt).add(ams);
 
   return desc;
 }
@@ -128,6 +133,7 @@ int main(int argc, char * argv[])
   ProgramParameters::beta1 = vm["b1"].as<float>();
   ProgramParameters::beta2 = vm["b2"].as<float>();
   ProgramParameters::bias = vm["bias"].as<float>();
+  ProgramParameters::optimizer = vm["optimizer"].as<std::string>();
   const char * MACAON_DIR = std::getenv("MACAON_DIR");
   std::string slash = "/";
...
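From the command line, the optimizer is now chosen with the new --optimizer option (defaulting to amsgrad), while --b1, --b2 and --bias move into their own "Amsgrad family optimizers" group. A hypothetical invocation is shown below; the binary name and corpus path are placeholders, not part of this commit, and any required options from the req group are omitted because they do not appear in this diff.

# placeholders: binary name and dev corpus path are assumptions
macaon_train --dev corpus/dev.mcd --lang fr --optimizer adam --lr 0.001 --b1 0.9 --b2 0.999 --bias 1e-8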