Skip to content
Snippets Groups Projects
Commit cfae9a32 authored by Franck Dary's avatar Franck Dary
Browse files

Added an option to chose between random parameters and 1-initialized parameters

parent 0ca30a51
No related branches found
No related tags found
No related merge requests found
...@@ -149,6 +149,11 @@ void MLP::addLayerToModel(Layer & layer) ...@@ -149,6 +149,11 @@ void MLP::addLayerToModel(Layer & layer)
{ {
dynet::Parameter W = model.add_parameters({(unsigned)layer.output_dim, (unsigned)layer.input_dim}); dynet::Parameter W = model.add_parameters({(unsigned)layer.output_dim, (unsigned)layer.input_dim});
dynet::Parameter b = model.add_parameters({(unsigned)layer.output_dim}); dynet::Parameter b = model.add_parameters({(unsigned)layer.output_dim});
if (!ProgramParameters::randomParameters)
{
W.set_value(std::vector<float>((unsigned)layer.output_dim * (unsigned)layer.input_dim, 1.0));
b.set_value(std::vector<float>((unsigned)layer.output_dim, 1.0));
}
parameters.push_back({W,b}); parameters.push_back({W,b});
} }
......
...@@ -67,6 +67,8 @@ po::options_description getOptionsDescription() ...@@ -67,6 +67,8 @@ po::options_description getOptionsDescription()
"Is the shell interactive ? Display advancement informations") "Is the shell interactive ? Display advancement informations")
("randomEmbeddings", po::value<bool>()->default_value(false), ("randomEmbeddings", po::value<bool>()->default_value(false),
"When activated, the embeddings will be randomly initialized") "When activated, the embeddings will be randomly initialized")
("randomParameters", po::value<bool>()->default_value(true),
"When activated, the parameters will be randomly initialized")
("sequenceDelimiterTape", po::value<std::string>()->default_value("EOS"), ("sequenceDelimiterTape", po::value<std::string>()->default_value("EOS"),
"The name of the buffer's tape that contains the delimiter token for a sequence") "The name of the buffer's tape that contains the delimiter token for a sequence")
("sequenceDelimiter", po::value<std::string>()->default_value("1"), ("sequenceDelimiter", po::value<std::string>()->default_value("1"),
...@@ -510,6 +512,7 @@ int main(int argc, char * argv[]) ...@@ -510,6 +512,7 @@ int main(int argc, char * argv[])
ProgramParameters::interactive = vm["interactive"].as<bool>(); ProgramParameters::interactive = vm["interactive"].as<bool>();
ProgramParameters::shuffleExamples = vm["shuffle"].as<bool>(); ProgramParameters::shuffleExamples = vm["shuffle"].as<bool>();
ProgramParameters::randomEmbeddings = vm["randomEmbeddings"].as<bool>(); ProgramParameters::randomEmbeddings = vm["randomEmbeddings"].as<bool>();
ProgramParameters::randomParameters = vm["randomParameters"].as<bool>();
ProgramParameters::sequenceDelimiterTape = vm["sequenceDelimiterTape"].as<std::string>(); ProgramParameters::sequenceDelimiterTape = vm["sequenceDelimiterTape"].as<std::string>();
ProgramParameters::sequenceDelimiter = vm["sequenceDelimiter"].as<std::string>(); ProgramParameters::sequenceDelimiter = vm["sequenceDelimiter"].as<std::string>();
ProgramParameters::learningRate = vm["lr"].as<float>(); ProgramParameters::learningRate = vm["lr"].as<float>();
......
...@@ -45,6 +45,7 @@ struct ProgramParameters ...@@ -45,6 +45,7 @@ struct ProgramParameters
static int iterationSize; static int iterationSize;
static int nbTrain; static int nbTrain;
static bool randomEmbeddings; static bool randomEmbeddings;
static bool randomParameters;
static bool printEntropy; static bool printEntropy;
static bool printTime; static bool printTime;
static std::string sequenceDelimiterTape; static std::string sequenceDelimiterTape;
......
...@@ -38,6 +38,7 @@ int ProgramParameters::dynamicEpoch; ...@@ -38,6 +38,7 @@ int ProgramParameters::dynamicEpoch;
float ProgramParameters::dynamicProbability; float ProgramParameters::dynamicProbability;
int ProgramParameters::showFeatureRepresentation; int ProgramParameters::showFeatureRepresentation;
bool ProgramParameters::randomEmbeddings; bool ProgramParameters::randomEmbeddings;
bool ProgramParameters::randomParameters;
bool ProgramParameters::printEntropy; bool ProgramParameters::printEntropy;
bool ProgramParameters::printTime; bool ProgramParameters::printTime;
int ProgramParameters::iterationSize; int ProgramParameters::iterationSize;
......
...@@ -210,6 +210,12 @@ void Trainer::train() ...@@ -210,6 +210,12 @@ void Trainer::train()
if (zeroCostActions.empty()) if (zeroCostActions.empty())
{ {
if (trainConfig.head >= (int)trainConfig.tapes[0].ref.size()-1)
{
while (!trainConfig.stackEmpty())
trainConfig.stackPop();
break;
}
fprintf(stderr, "ERROR (%s) : Unable to find any zero cost action. Aborting.\n", ERRINFO); fprintf(stderr, "ERROR (%s) : Unable to find any zero cost action. Aborting.\n", ERRINFO);
fprintf(stderr, "State : %s\n", currentState->name.c_str()); fprintf(stderr, "State : %s\n", currentState->name.c_str());
trainConfig.printForDebug(stderr); trainConfig.printForDebug(stderr);
......
...@@ -65,6 +65,8 @@ po::options_description getOptionsDescription() ...@@ -65,6 +65,8 @@ po::options_description getOptionsDescription()
"Is the shell interactive ? Display advancement informations") "Is the shell interactive ? Display advancement informations")
("randomEmbeddings", po::value<bool>()->default_value(false), ("randomEmbeddings", po::value<bool>()->default_value(false),
"When activated, the embeddings will be randomly initialized") "When activated, the embeddings will be randomly initialized")
("randomParameters", po::value<bool>()->default_value(true),
"When activated, the parameters will be randomly initialized")
("sequenceDelimiterTape", po::value<std::string>()->default_value("EOS"), ("sequenceDelimiterTape", po::value<std::string>()->default_value("EOS"),
"The name of the buffer's tape that contains the delimiter token for a sequence") "The name of the buffer's tape that contains the delimiter token for a sequence")
("sequenceDelimiter", po::value<std::string>()->default_value("1"), ("sequenceDelimiter", po::value<std::string>()->default_value("1"),
...@@ -253,6 +255,7 @@ int main(int argc, char * argv[]) ...@@ -253,6 +255,7 @@ int main(int argc, char * argv[])
ProgramParameters::interactive = vm["interactive"].as<bool>(); ProgramParameters::interactive = vm["interactive"].as<bool>();
ProgramParameters::shuffleExamples = vm["shuffle"].as<bool>(); ProgramParameters::shuffleExamples = vm["shuffle"].as<bool>();
ProgramParameters::randomEmbeddings = vm["randomEmbeddings"].as<bool>(); ProgramParameters::randomEmbeddings = vm["randomEmbeddings"].as<bool>();
ProgramParameters::randomParameters = vm["randomParameters"].as<bool>();
ProgramParameters::sequenceDelimiterTape = vm["sequenceDelimiterTape"].as<std::string>(); ProgramParameters::sequenceDelimiterTape = vm["sequenceDelimiterTape"].as<std::string>();
ProgramParameters::sequenceDelimiter = vm["sequenceDelimiter"].as<std::string>(); ProgramParameters::sequenceDelimiter = vm["sequenceDelimiter"].as<std::string>();
ProgramParameters::learningRate = vm["lr"].as<float>(); ProgramParameters::learningRate = vm["lr"].as<float>();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment