Commit 6ba54eae authored by Franck Dary

Added dropout support

parent 374bc8d7
@@ -52,6 +52,7 @@ class MLP
     dynet::ParameterCollection model;
     dynet::AmsgradTrainer trainer;
     bool trainMode;
+    bool dropoutActive;

   private :
...
@@ -68,6 +68,7 @@ MLP::MLP(std::vector<Layer> layers)
   dynet::initialize(getDefaultParams());

   trainMode = true;
+  dropoutActive = true;

   checkLayersCompatibility();
@@ -109,6 +110,8 @@ MLP::Layer::Layer(int input_dim, int output_dim,
 std::vector<float> MLP::predict(FeatureModel::FeatureDescription & fd)
 {
+  bool currentDropoutActive = dropoutActive;
+  dropoutActive = false;
   dynet::ComputationGraph cg;
   std::vector<dynet::Expression> expressions;
@@ -120,6 +123,8 @@ std::vector<float> MLP::predict(FeatureModel::FeatureDescription & fd)
   dynet::Expression output = run(cg, input);

+  dropoutActive = currentDropoutActive;
+
   return as_vector(cg.forward(output));
 }
@@ -193,14 +198,12 @@ dynet::Expression MLP::run(dynet::ComputationGraph & cg, dynet::Expression x)
     dynet::Expression a = dynet::affine_transform({b, W, h_cur});
     // Apply activation function
     dynet::Expression h = activate(a, layers[l].activation);
-    h_cur = h;
     // Take care of dropout
-    /*
     dynet::Expression h_dropped;
     if(layers[l].dropout_rate > 0){
-      if(dropout_active){
+      if(dropoutActive){
         dynet::Expression mask = random_bernoulli(cg,
-          {layers[l].output_dim}, 1 - layers[l].dropout_rate);
+          {(unsigned int)layers[l].output_dim}, 1 - layers[l].dropout_rate);
         h_dropped = cmult(h, mask);
       }
       else{
@@ -212,7 +215,6 @@ dynet::Expression MLP::run(dynet::ComputationGraph & cg, dynet::Expression x)
       }
       h_cur = h_dropped;
-    */
   }

   return h_cur;
@@ -309,6 +311,9 @@ int MLP::trainOnBatch(Examples & examples, int start, int end)
 int MLP::getScoreOnBatch(Examples & examples, int start, int end)
 {
+  bool currentDropoutActive = dropoutActive;
+  dropoutActive = false;
+
   dynet::ComputationGraph cg;
   std::vector<dynet::Expression> inputs;
   std::vector<unsigned int> goldClasses;
@@ -354,6 +359,8 @@ int MLP::getScoreOnBatch(Examples & examples, int start, int end)
       nbCorrect++;
   }

+  dropoutActive = currentDropoutActive;
+
   return nbCorrect;
 }
@@ -429,6 +436,7 @@ MLP::MLP(const std::string & filename)
   dynet::initialize(getDefaultParams());

   trainMode = false;
+  dropoutActive = false;

   load(filename);
 }
...
@@ -133,7 +133,7 @@ void Classifier::initClassifier(Config & config)
   for (auto feat : fd.values)
     nbInputs += feat.vec->size();

-  mlp.reset(new MLP({{nbInputs, nbHidden, 0.0, MLP::Activation::RELU},
+  mlp.reset(new MLP({{nbInputs, nbHidden, 0.3, MLP::Activation::RELU},
                      {nbHidden, nbOutputs, 0.0, MLP::Activation::LINEAR}}));
 }
...
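Note on the re-enabled dropout path: the hunks above restore the per-layer dropout branch in MLP::run and guard it with the new dropoutActive flag, but the inference-time else branch is elided from the diff. The sketch below shows, under the assumption of standard (non-inverted) dropout where inference rescales activations by 1 - dropout_rate, what that per-layer step amounts to; the helper name applyLayerDropout and its parameter list are illustrative and not part of the repository.

#include <dynet/dynet.h>
#include <dynet/expr.h>

// Sketch of the dropout step for one hidden layer's activation h.
// dropoutActive mirrors the member flag toggled in predict() and
// getScoreOnBatch(): true while training, false while scoring.
dynet::Expression applyLayerDropout(dynet::ComputationGraph & cg,
                                    dynet::Expression h,
                                    float dropoutRate,
                                    unsigned int outputDim,
                                    bool dropoutActive)
{
  if (dropoutRate <= 0.0f)
    return h; // layer configured without dropout, pass activation through

  if (dropoutActive)
  {
    // Training: multiply by a Bernoulli(1 - dropoutRate) mask so each unit
    // is zeroed with probability dropoutRate.
    dynet::Expression mask = dynet::random_bernoulli(cg, {outputDim},
                                                     1 - dropoutRate);
    return dynet::cmult(h, mask);
  }

  // Scoring: keep every unit but rescale so the expected activation matches
  // what the next layer saw during training.
  return h * (1 - dropoutRate);
}

With this shape, saving dropoutActive, forcing it to false, and restoring it in predict() and getScoreOnBatch() keeps evaluation deterministic, while the training path leaves the flag at its constructor default of true; the 0.3 rate now passed for the hidden layer in Classifier::initClassifier therefore drops roughly a third of the hidden units on each training example.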