Skip to content
Snippets Groups Projects
Commit 007ae6d7 authored by Franck Dary's avatar Franck Dary
Browse files

Made the random seed a parameter

parent c7a4d3e0
No related branches found
No related tags found
No related merge requests found
...@@ -37,6 +37,9 @@ class MLP ...@@ -37,6 +37,9 @@ class MLP
SOFTMAX SOFTMAX
}; };
/// @brief The seed that will be used by RNG (srand and dynet)
static int randomSeed;
/// @brief Get the string corresponding to an Activation. /// @brief Get the string corresponding to an Activation.
/// ///
/// @param a The activation. /// @param a The activation.
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
#include <dynet/param-init.h> #include <dynet/param-init.h>
#include <dynet/io.h> #include <dynet/io.h>
int MLP::randomSeed = 0;
std::string MLP::activation2str(Activation a) std::string MLP::activation2str(Activation a)
{ {
switch(a) switch(a)
...@@ -175,7 +177,7 @@ std::vector<float> MLP::predict(FeatureModel::FeatureDescription & fd) ...@@ -175,7 +177,7 @@ std::vector<float> MLP::predict(FeatureModel::FeatureDescription & fd)
dynet::DynetParams & MLP::getDefaultParams() dynet::DynetParams & MLP::getDefaultParams()
{ {
static dynet::DynetParams params; static dynet::DynetParams params;
params.random_seed = 100; params.random_seed = randomSeed;
std::srand(params.random_seed); std::srand(params.random_seed);
......
...@@ -44,6 +44,8 @@ po::options_description getOptionsDescription() ...@@ -44,6 +44,8 @@ po::options_description getOptionsDescription()
"Number of training epochs (iterations)") "Number of training epochs (iterations)")
("batchsize,b", po::value<int>()->default_value(256), ("batchsize,b", po::value<int>()->default_value(256),
"Size of each training batch (in number of examples)") "Size of each training batch (in number of examples)")
("seed,s", po::value<int>()->default_value(100),
"The random seed that will initialize RNG")
("shuffle", po::value<bool>()->default_value(true), ("shuffle", po::value<bool>()->default_value(true),
"Shuffle examples after each iteration"); "Shuffle examples after each iteration");
...@@ -109,6 +111,7 @@ int main(int argc, char * argv[]) ...@@ -109,6 +111,7 @@ int main(int argc, char * argv[])
std::string lang = vm["lang"].as<std::string>(); std::string lang = vm["lang"].as<std::string>();
int nbIter = vm["nbiter"].as<int>(); int nbIter = vm["nbiter"].as<int>();
int batchSize = vm["batchsize"].as<int>(); int batchSize = vm["batchsize"].as<int>();
int randomSeed = vm["seed"].as<int>();
bool mustShuffle = vm["shuffle"].as<bool>(); bool mustShuffle = vm["shuffle"].as<bool>();
const char * MACAON_DIR = std::getenv("MACAON_DIR"); const char * MACAON_DIR = std::getenv("MACAON_DIR");
...@@ -121,6 +124,9 @@ int main(int argc, char * argv[]) ...@@ -121,6 +124,9 @@ int main(int argc, char * argv[])
trainFilename = expPath + trainFilename; trainFilename = expPath + trainFilename;
devFilename = expPath + devFilename; devFilename = expPath + devFilename;
// Setting the random seed
MLP::randomSeed = randomSeed;
TransitionMachine tapeMachine(tmFilename, true, expPath); TransitionMachine tapeMachine(tmFilename, true, expPath);
BD trainBD(BDfilename, MCDfilename); BD trainBD(BDfilename, MCDfilename);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment