Skip to content
Snippets Groups Projects
Commit cfae9a32 authored by Franck Dary
Browse files

Added an option to choose between random parameters and 1-initialized parameters

parent 0ca30a51
No related branches found
No related tags found
No related merge requests found
......@@ -149,6 +149,11 @@ void MLP::addLayerToModel(Layer & layer)
{
// Create this layer's two trainable parameters in the model:
// W is requested with dimensions {output_dim, input_dim}, b with {output_dim}.
dynet::Parameter W = model.add_parameters({(unsigned)layer.output_dim, (unsigned)layer.input_dim});
dynet::Parameter b = model.add_parameters({(unsigned)layer.output_dim});
// When the randomParameters option is off, overwrite every entry of W and b
// with 1.0; the vectors hold output_dim * input_dim and output_dim values
// respectively, matching the dimensions requested above.
// NOTE(review): when the option is on, the values produced by add_parameters
// are kept — presumably DyNet's default random initialization; confirm.
if (!ProgramParameters::randomParameters)
{
W.set_value(std::vector<float>((unsigned)layer.output_dim * (unsigned)layer.input_dim, 1.0));
b.set_value(std::vector<float>((unsigned)layer.output_dim, 1.0));
}
// Record the (W, b) pair for this layer.
parameters.push_back({W,b});
}
......
......@@ -67,6 +67,8 @@ po::options_description getOptionsDescription()
"Is the shell interactive ? Display advancement informations")
("randomEmbeddings", po::value<bool>()->default_value(false),
"When activated, the embeddings will be randomly initialized")
("randomParameters", po::value<bool>()->default_value(true),
"When activated, the parameters will be randomly initialized")
("sequenceDelimiterTape", po::value<std::string>()->default_value("EOS"),
"The name of the buffer's tape that contains the delimiter token for a sequence")
("sequenceDelimiter", po::value<std::string>()->default_value("1"),
......@@ -510,6 +512,7 @@ int main(int argc, char * argv[])
ProgramParameters::interactive = vm["interactive"].as<bool>();
ProgramParameters::shuffleExamples = vm["shuffle"].as<bool>();
ProgramParameters::randomEmbeddings = vm["randomEmbeddings"].as<bool>();
ProgramParameters::randomParameters = vm["randomParameters"].as<bool>();
ProgramParameters::sequenceDelimiterTape = vm["sequenceDelimiterTape"].as<std::string>();
ProgramParameters::sequenceDelimiter = vm["sequenceDelimiter"].as<std::string>();
ProgramParameters::learningRate = vm["lr"].as<float>();
......
......@@ -45,6 +45,7 @@ struct ProgramParameters
static int iterationSize;
static int nbTrain;
static bool randomEmbeddings;
static bool randomParameters;
static bool printEntropy;
static bool printTime;
static std::string sequenceDelimiterTape;
......
......@@ -38,6 +38,7 @@ int ProgramParameters::dynamicEpoch;
float ProgramParameters::dynamicProbability;
int ProgramParameters::showFeatureRepresentation;
bool ProgramParameters::randomEmbeddings;
bool ProgramParameters::randomParameters;
bool ProgramParameters::printEntropy;
bool ProgramParameters::printTime;
int ProgramParameters::iterationSize;
......
......@@ -210,6 +210,12 @@ void Trainer::train()
if (zeroCostActions.empty())
{
if (trainConfig.head >= (int)trainConfig.tapes[0].ref.size()-1)
{
while (!trainConfig.stackEmpty())
trainConfig.stackPop();
break;
}
fprintf(stderr, "ERROR (%s) : Unable to find any zero cost action. Aborting.\n", ERRINFO);
fprintf(stderr, "State : %s\n", currentState->name.c_str());
trainConfig.printForDebug(stderr);
......
......@@ -65,6 +65,8 @@ po::options_description getOptionsDescription()
"Is the shell interactive ? Display advancement informations")
("randomEmbeddings", po::value<bool>()->default_value(false),
"When activated, the embeddings will be randomly initialized")
("randomParameters", po::value<bool>()->default_value(true),
"When activated, the parameters will be randomly initialized")
("sequenceDelimiterTape", po::value<std::string>()->default_value("EOS"),
"The name of the buffer's tape that contains the delimiter token for a sequence")
("sequenceDelimiter", po::value<std::string>()->default_value("1"),
......@@ -253,6 +255,7 @@ int main(int argc, char * argv[])
ProgramParameters::interactive = vm["interactive"].as<bool>();
ProgramParameters::shuffleExamples = vm["shuffle"].as<bool>();
ProgramParameters::randomEmbeddings = vm["randomEmbeddings"].as<bool>();
ProgramParameters::randomParameters = vm["randomParameters"].as<bool>();
ProgramParameters::sequenceDelimiterTape = vm["sequenceDelimiterTape"].as<std::string>();
ProgramParameters::sequenceDelimiter = vm["sequenceDelimiter"].as<std::string>();
ProgramParameters::learningRate = vm["lr"].as<float>();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment