Commit 4876271c authored by Franck Dary

Added the nbTrain option to train the same model many times with different seeds

parent 0344090e
@@ -37,6 +37,8 @@ class MLP
   /// @brief The seed that will be used by RNG (srand and dynet)
   static int randomSeed;
+  static bool dynetIsInit;
   /// @brief Get the string corresponding to an Activation.
   ///
   /// @param a The activation.
...
@@ -7,6 +7,7 @@
 #include <dynet/io.h>
 int MLP::randomSeed = 0;
+bool MLP::dynetIsInit = false;
 std::string MLP::activation2str(Activation a)
 {
@@ -72,12 +73,10 @@ MLP::Activation MLP::str2activation(std::string s)
 void MLP::initDynet()
 {
-  static bool init = false;
-  if(init)
+  if(dynetIsInit)
     return;
-  init = true;
+  dynetIsInit = true;
   dynet::initialize(getDefaultParams());
 }
...
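The two hunks above move MLP's one-time dynet initialization guard from a function-local static inside initDynet() to the class-level dynetIsInit flag, so the initialization state lives with the class rather than inside a single function. A minimal standalone sketch of this guard pattern (class and function names here are illustrative, not taken from the repository):

#include <cstdio>

class Engine
{
  public :

  // Class-level guard flag, analogous to MLP::dynetIsInit.
  static bool isInit;

  // Runs the expensive one-time initialization only once,
  // no matter how many times it is called.
  static void init()
  {
    if (isInit)
      return;

    isInit = true;
    std::printf("initializing backend once\n");
  }
};

bool Engine::isInit = false;

int main()
{
  Engine::init(); // performs the initialization
  Engine::init(); // second call is a no-op
  return 0;
}

A class-level flag behaves the same on repeated calls, but unlike a function-local static it can be read by the rest of the class, which is convenient when a single process launches several trainings in a row.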
@@ -203,6 +203,8 @@ class Dict
   /// @param directory The directory in which we will save every Dict.
   /// @param namePrefix The prefix of the name of the dicts we need to save.
   static void saveDicts(const std::string & directory, const std::string & namePrefix);
+  /// @brief Delete all Dicts.
+  static void deleteDicts();
   /// @brief Save the current Dict in the corresponding file.
   void save();
   /// @brief Get the vector value of an entry.
...
@@ -12,6 +12,7 @@ struct ProgramParameters
   static std::string input;
   static std::string expName;
   static std::string expPath;
+  static std::string baseExpName;
   static std::string langPath;
   static std::string templatePath;
   static std::string templateName;
@@ -41,6 +42,7 @@ struct ProgramParameters
   static float dynamicProbability;
   static bool showFeatureRepresentation;
   static int iterationSize;
+  static int nbTrain;
   private :
...
@@ -416,3 +416,8 @@ void Dict::printForDebug(FILE * output)
   fprintf(output, "Dict name \'%s\' nbElems = %lu\n", name.c_str(), str2vec.size());
 }
+
+void Dict::deleteDicts()
+{
+  str2dict.clear();
+}
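Dict::deleteDicts() simply clears str2dict, which from its use here looks like Dict's static registry of named dictionaries; its exact type and ownership model are not visible in this diff. A rough sketch of such a registry-reset pattern, with hypothetical types and a toy main():

#include <map>
#include <string>

class Dict
{
  public :

  std::string name;

  // Hypothetical registry mapping a dictionary name to its Dict;
  // the real type of str2dict is not shown in this diff.
  static std::map<std::string, Dict> str2dict;

  // Forgets every registered Dict so that the next training run
  // rebuilds its dictionaries from scratch.
  static void deleteDicts()
  {
    str2dict.clear();
  }
};

std::map<std::string, Dict> Dict::str2dict;

int main()
{
  Dict::str2dict["form"] = Dict{"form"}; // register a toy dictionary
  Dict::deleteDicts();                   // reset global state between trainings
  return 0;
}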
@@ -6,6 +6,7 @@ ProgramParameters::ProgramParameters()
 std::string ProgramParameters::input;
 std::string ProgramParameters::expName;
+std::string ProgramParameters::baseExpName;
 std::string ProgramParameters::expPath;
 std::string ProgramParameters::langPath;
 std::string ProgramParameters::templatePath;
@@ -36,3 +37,4 @@ int ProgramParameters::dynamicEpoch;
 float ProgramParameters::dynamicProbability;
 bool ProgramParameters::showFeatureRepresentation;
 int ProgramParameters::iterationSize;
+int ProgramParameters::nbTrain;
@@ -54,6 +54,8 @@ po::options_description getOptionsDescription()
       "Learning rate of the optimizer")
     ("seed,s", po::value<int>()->default_value(100),
       "The random seed that will initialize RNG")
+    ("nbTrain", po::value<int>()->default_value(0),
+      "The number of models that will be trained, with only the random seed changing")
     ("duplicates", po::value<bool>()->default_value(true),
       "Remove identical training examples")
     ("showFeatureRepresentation", po::value<bool>()->default_value(false),
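The new nbTrain option is declared with Boost.Program_options like the surrounding options, with a default of 0 that leaves the single-training behavior untouched. As a standalone illustration of how such an integer option is declared and read back (trimmed to this one option, outside the project's real getOptionsDescription/checkOptions helpers):

#include <boost/program_options.hpp>
#include <cstdio>

namespace po = boost::program_options;

int main(int argc, char * argv[])
{
  po::options_description od("Options");
  od.add_options()
    ("nbTrain", po::value<int>()->default_value(0),
      "The number of models that will be trained, with only the random seed changing");

  po::variables_map vm;
  po::store(po::parse_command_line(argc, argv, od), vm);
  po::notify(vm);

  // 0 (the default) keeps the single-training behavior.
  int nbTrain = vm["nbTrain"].as<int>();
  std::printf("nbTrain = %d\n", nbTrain);

  return 0;
}

Running this toy program with --nbTrain 5 prints 5; omitting the flag falls back to the default of 0.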
@@ -209,6 +211,7 @@ int main(int argc, char * argv[])
   po::variables_map vm = checkOptions(od, argc, argv);
   ProgramParameters::expName = vm["expName"].as<std::string>();
+  ProgramParameters::baseExpName = ProgramParameters::expName;
   ProgramParameters::templateName = vm["templateName"].as<std::string>();
   ProgramParameters::tmName = vm["tm"].as<std::string>();
   ProgramParameters::bdName = vm["bd"].as<std::string>();
@@ -219,6 +222,7 @@ int main(int argc, char * argv[])
   ProgramParameters::lang = vm["lang"].as<std::string>();
   ProgramParameters::nbIter = vm["nbiter"].as<int>();
   ProgramParameters::seed = vm["seed"].as<int>();
+  ProgramParameters::nbTrain = vm["nbTrain"].as<int>();
   ProgramParameters::removeDuplicates = vm["duplicates"].as<bool>();
   ProgramParameters::interactive = vm["interactive"].as<bool>();
   ProgramParameters::shuffleExamples = vm["shuffle"].as<bool>();
@@ -232,11 +236,23 @@ int main(int argc, char * argv[])
   ProgramParameters::showFeatureRepresentation = vm["showFeatureRepresentation"].as<bool>();
   ProgramParameters::iterationSize = vm["iterationSize"].as<int>();
-  for (int i = 0; i < 10; i++)
+  if (ProgramParameters::nbTrain)
+  {
+    for (int i = 0; i < ProgramParameters::nbTrain; i++)
+    {
+      fprintf(stderr, "Training number %d / %d :\n", i+1, ProgramParameters::nbTrain);
+      ProgramParameters::expName = ProgramParameters::baseExpName + "_" + std::to_string(i);
+      updatePaths();
+      createExpPath();
+      Dict::deleteDicts();
+      launchTraining();
+    }
+  }
+  else
   {
-    ProgramParameters::expName += "_" + std::to_string(i);
     updatePaths();
     createExpPath();
+    Dict::deleteDicts();
     launchTraining();
   }
...
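Note that the loop above varies only the experiment name per run (baseExpName_0, baseExpName_1, ...); how the random seed itself differs between runs is not shown in this excerpt. Purely as an illustration of the idea named in the commit message, and not as what this commit actually does, one common scheme is to offset a base seed by the run index:

#include <cstdio>
#include <cstdlib>
#include <string>

int main()
{
  // Illustrative stand-ins for ProgramParameters::seed / nbTrain / baseExpName.
  int baseSeed = 100;
  int nbTrain = 3;
  std::string baseExpName = "exp";

  for (int i = 0; i < nbTrain; i++)
  {
    int runSeed = baseSeed + i;             // one distinct seed per run
    std::string expName = baseExpName + "_" + std::to_string(i);

    std::srand(runSeed);                    // seed the RNG for this run
    std::fprintf(stderr, "run %d/%d : expName=%s seed=%d\n",
      i + 1, nbTrain, expName.c_str(), runSeed);
    // ... train one model here ...
  }

  return 0;
}

Whatever the derivation, the point is that each run gets a distinct, reproducible seed while the rest of the configuration stays identical.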