diff --git a/MLP/include/MLP.hpp b/MLP/include/MLP.hpp
index 484d88100e481a5e53ab8d098c5578d7bd2c7946..14b011c5b1bc0696b4b6f3d71076bd288854e93d 100644
--- a/MLP/include/MLP.hpp
+++ b/MLP/include/MLP.hpp
@@ -92,7 +92,7 @@ class MLP
   /// @brief The dynet model containing the parameters to be trained.
   dynet::ParameterCollection model;
   /// @brief The training algorithm that will be used.
-  std::unique_ptr<dynet::AmsgradTrainer> trainer;
+  std::unique_ptr<dynet::Trainer> trainer;
   /// @brief Whether the program is in train mode or not (only in train mode the parameters will be updated).
   bool trainMode;
   /// @brief Must the Layer dropout rate be taken into account during the computations ? Usually it is only during the training step.
@@ -205,6 +205,10 @@ class MLP
   ///
   /// @param output Where the topology will be printed.
   void printTopology(FILE * output);
+  /// @brief Allocate the correct trainer type depending on the program parameters.
+  ///
+  /// @return A pointer to the newly allocated trainer.
+  dynet::Trainer * createTrainer();
 };
 
 #endif
diff --git a/MLP/src/MLP.cpp b/MLP/src/MLP.cpp
index 666d4018d421715735cc8ecfec1bdbe600f8a3ca..2bf7e2d80a440be68f99d86130a1f4e7d9db9798 100644
--- a/MLP/src/MLP.cpp
+++ b/MLP/src/MLP.cpp
@@ -111,7 +111,7 @@ MLP::MLP(int nbInputs, const std::string & topology, int nbOutputs)
 
   layers.emplace_back(layers.back().output_dim, nbOutputs, 0.0, Activation::LINEAR);
 
-  trainer.reset(new dynet::AmsgradTrainer(model, ProgramParameters::learningRate, ProgramParameters::beta1, ProgramParameters::beta2, ProgramParameters::bias));
+  trainer.reset(createTrainer());
 
   initDynet();
 
@@ -124,6 +124,27 @@ MLP::MLP(int nbInputs, const std::string & topology, int nbOutputs)
     addLayerToModel(layer);
 }
 
+dynet::Trainer * MLP::createTrainer()
+{
+  if (!trainMode)
+    return nullptr;
+
+  auto optimizer = noAccentLower(ProgramParameters::optimizer);
+
+  if (optimizer == "amsgrad")
+    return new dynet::AmsgradTrainer(model, ProgramParameters::learningRate, ProgramParameters::beta1, ProgramParameters::beta2, ProgramParameters::bias);
+  else if (optimizer == "adam")
+    return new dynet::AdamTrainer(model, ProgramParameters::learningRate, ProgramParameters::beta1, ProgramParameters::beta2, ProgramParameters::bias);
+  else if (optimizer == "sgd")
+    return new dynet::SimpleSGDTrainer(model, ProgramParameters::learningRate);
+
+  fprintf(stderr, "ERROR (%s) : unknown optimizer \'%s\'. Aborting.\n", ERRINFO, optimizer.c_str());
+
+  exit(1);
+
+  return nullptr;
+}
+
 void MLP::addLayerToModel(Layer & layer)
 {
   dynet::Parameter W = model.add_parameters({(unsigned)layer.output_dim, (unsigned)layer.input_dim});
@@ -389,7 +410,7 @@ void MLP::loadParameters(const std::string & filename)
 MLP::MLP(const std::string & filename)
 {
   randomSeed = ProgramParameters::seed;
-  trainer.reset(new dynet::AmsgradTrainer(model, ProgramParameters::learningRate, ProgramParameters::beta1, ProgramParameters::beta2, ProgramParameters::bias));
+  trainer.reset(createTrainer());
 
   initDynet();
 
   trainMode = false;
diff --git a/maca_common/include/ProgramParameters.hpp b/maca_common/include/ProgramParameters.hpp
index 80a94bc2f86c7ac7342ea21a816f1244ab07e56e..061163004797e766c901b8323ca9f8750ca7cf18 100644
--- a/maca_common/include/ProgramParameters.hpp
+++ b/maca_common/include/ProgramParameters.hpp
@@ -24,6 +24,7 @@ struct ProgramParameters
   static std::string devFilename;
   static std::string devName;
   static std::string lang;
+  static std::string optimizer;
   static int nbIter;
   static int seed;
   static bool removeDuplicates;
diff --git a/maca_common/src/ProgramParameters.cpp b/maca_common/src/ProgramParameters.cpp
index b649b1616b142054cf5d7029e4b182b0670dd64d..272da94f8fe5bdd5e3cf533493e32704f522b0c2 100644
--- a/maca_common/src/ProgramParameters.cpp
+++ b/maca_common/src/ProgramParameters.cpp
@@ -19,6 +19,7 @@ std::string ProgramParameters::trainName;
 std::string ProgramParameters::devFilename;
 std::string ProgramParameters::devName;
 std::string ProgramParameters::lang;
+std::string ProgramParameters::optimizer;
 int ProgramParameters::nbIter;
 int ProgramParameters::seed;
 bool ProgramParameters::removeDuplicates;
diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp
index 0154530ed90ed558132212b6e62daa2073c2e6c8..d88f11c57eec7be47214da6da4914a7938abb374 100644
--- a/trainer/src/macaon_train.cpp
+++ b/trainer/src/macaon_train.cpp
@@ -38,6 +38,8 @@ po::options_description getOptionsDescription()
   opt.add_options()
     ("help,h", "Produce this help message")
     ("debug,d", "Print infos on stderr")
+    ("optimizer", po::value<std::string>()->default_value("amsgrad"),
+      "The learning algorithm to use : amsgrad | adam | sgd")
     ("dev", po::value<std::string>()->default_value(""),
       "Development corpus formated according to the MCD")
     ("lang", po::value<std::string>()->default_value("fr"),
@@ -46,12 +48,6 @@ po::options_description getOptionsDescription()
       "Number of training epochs (iterations)")
     ("lr", po::value<float>()->default_value(0.001),
       "Learning rate of the optimizer")
-    ("b1", po::value<float>()->default_value(0.9),
-      "beta1 parameter for the Amsgtad or Adam optimizer")
-    ("b2", po::value<float>()->default_value(0.999),
-      "beta2 parameter for the Amsgtad or Adam optimizer")
-    ("bias", po::value<float>()->default_value(1e-8),
-      "bias parameter for the Amsgtad or Adam or Adagrad optimizer")
     ("seed,s", po::value<int>()->default_value(100),
       "The random seed that will initialize RNG")
     ("duplicates", po::value<bool>()->default_value(true),
@@ -59,7 +55,16 @@ po::options_description getOptionsDescription()
     ("shuffle", po::value<bool>()->default_value(true),
       "Shuffle examples after each iteration");
 
-  desc.add(req).add(opt);
+  po::options_description ams("Amsgrad family optimizers");
+  ams.add_options()
+    ("b1", po::value<float>()->default_value(0.9),
+      "beta1 parameter for the Amsgtad or Adam optimizer")
+    ("b2", po::value<float>()->default_value(0.999),
+      "beta2 parameter for the Amsgtad or Adam optimizer")
+    ("bias", po::value<float>()->default_value(1e-8),
+      "bias parameter for the Amsgtad or Adam or Adagrad optimizer");
+
+  desc.add(req).add(opt).add(ams);
 
   return desc;
 }
@@ -128,6 +133,7 @@ int main(int argc, char * argv[])
   ProgramParameters::beta1 = vm["b1"].as<float>();
   ProgramParameters::beta2 = vm["b2"].as<float>();
   ProgramParameters::bias = vm["bias"].as<float>();
+  ProgramParameters::optimizer = vm["optimizer"].as<std::string>();
 
   const char * MACAON_DIR = std::getenv("MACAON_DIR");
   std::string slash = "/";