Skip to content
Snippets Groups Projects
Commit 4876271c authored by Franck Dary's avatar Franck Dary
Browse files

Added the nbTrain option to train the same model many times with different seeds

parent 0344090e
No related branches found
No related tags found
No related merge requests found
......@@ -37,6 +37,8 @@ class MLP
/// @brief The seed that will be used by RNG (srand and dynet)
static int randomSeed;
static bool dynetIsInit;
/// @brief Get the string corresponding to an Activation.
///
/// @param a The activation.
......
......@@ -7,6 +7,7 @@
#include <dynet/io.h>
int MLP::randomSeed = 0;
bool MLP::dynetIsInit = false;
std::string MLP::activation2str(Activation a)
{
......@@ -72,12 +73,10 @@ MLP::Activation MLP::str2activation(std::string s)
void MLP::initDynet()
{
static bool init = false;
if(init)
if(dynetIsInit)
return;
init = true;
dynetIsInit = true;
dynet::initialize(getDefaultParams());
}
......
......@@ -203,6 +203,8 @@ class Dict
/// @param directory The directory in which we will save every Dict.
/// @param namePrefix The prefix of the name of the dicts we need to save.
static void saveDicts(const std::string & directory, const std::string & namePrefix);
/// @brief Delete all Dicts.
static void deleteDicts();
/// @brief Save the current Dict in the corresponding file.
void save();
/// @brief Get the vector value of an entry.
......
......@@ -12,6 +12,7 @@ struct ProgramParameters
static std::string input;
static std::string expName;
static std::string expPath;
static std::string baseExpName;
static std::string langPath;
static std::string templatePath;
static std::string templateName;
......@@ -41,6 +42,7 @@ struct ProgramParameters
static float dynamicProbability;
static bool showFeatureRepresentation;
static int iterationSize;
static int nbTrain;
private :
......
......@@ -416,3 +416,8 @@ void Dict::printForDebug(FILE * output)
fprintf(output, "Dict name \'%s\' nbElems = %lu\n", name.c_str(), str2vec.size());
}
void Dict::deleteDicts()
{
str2dict.clear();
}
......@@ -6,6 +6,7 @@ ProgramParameters::ProgramParameters()
std::string ProgramParameters::input;
std::string ProgramParameters::expName;
std::string ProgramParameters::baseExpName;
std::string ProgramParameters::expPath;
std::string ProgramParameters::langPath;
std::string ProgramParameters::templatePath;
......@@ -36,3 +37,4 @@ int ProgramParameters::dynamicEpoch;
float ProgramParameters::dynamicProbability;
bool ProgramParameters::showFeatureRepresentation;
int ProgramParameters::iterationSize;
int ProgramParameters::nbTrain;
......@@ -54,6 +54,8 @@ po::options_description getOptionsDescription()
"Learning rate of the optimizer")
("seed,s", po::value<int>()->default_value(100),
"The random seed that will initialize RNG")
("nbTrain", po::value<int>()->default_value(0),
"The number of models that will be trained, with only the random seed changing")
("duplicates", po::value<bool>()->default_value(true),
"Remove identical training examples")
("showFeatureRepresentation", po::value<bool>()->default_value(false),
......@@ -209,6 +211,7 @@ int main(int argc, char * argv[])
po::variables_map vm = checkOptions(od, argc, argv);
ProgramParameters::expName = vm["expName"].as<std::string>();
ProgramParameters::baseExpName = ProgramParameters::expName;
ProgramParameters::templateName = vm["templateName"].as<std::string>();
ProgramParameters::tmName = vm["tm"].as<std::string>();
ProgramParameters::bdName = vm["bd"].as<std::string>();
......@@ -219,6 +222,7 @@ int main(int argc, char * argv[])
ProgramParameters::lang = vm["lang"].as<std::string>();
ProgramParameters::nbIter = vm["nbiter"].as<int>();
ProgramParameters::seed = vm["seed"].as<int>();
ProgramParameters::nbTrain = vm["nbTrain"].as<int>();
ProgramParameters::removeDuplicates = vm["duplicates"].as<bool>();
ProgramParameters::interactive = vm["interactive"].as<bool>();
ProgramParameters::shuffleExamples = vm["shuffle"].as<bool>();
......@@ -232,11 +236,23 @@ int main(int argc, char * argv[])
ProgramParameters::showFeatureRepresentation = vm["showFeatureRepresentation"].as<bool>();
ProgramParameters::iterationSize = vm["iterationSize"].as<int>();
for (int i = 0; i < 10; i++)
if (ProgramParameters::nbTrain)
{
for (int i = 0; i < ProgramParameters::nbTrain; i++)
{
fprintf(stderr, "Training number %d / %d :\n", i+1, ProgramParameters::nbTrain);
ProgramParameters::expName = ProgramParameters::baseExpName + "_" + std::to_string(i);
updatePaths();
createExpPath();
Dict::deleteDicts();
launchTraining();
}
}
else
{
ProgramParameters::expName += "_" + std::to_string(i);
updatePaths();
createExpPath();
Dict::deleteDicts();
launchTraining();
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment