Commit 6ba54eae authored by Franck Dary

Added dropout support

parent 374bc8d7
@@ -52,6 +52,7 @@ class MLP
   dynet::ParameterCollection model;
   dynet::AmsgradTrainer trainer;
   bool trainMode;
+  bool dropoutActive;
   private :
...
@@ -68,6 +68,7 @@ MLP::MLP(std::vector<Layer> layers)
   dynet::initialize(getDefaultParams());
   trainMode = true;
+  dropoutActive = true;
   checkLayersCompatibility();
@@ -109,6 +110,8 @@ MLP::Layer::Layer(int input_dim, int output_dim,
 std::vector<float> MLP::predict(FeatureModel::FeatureDescription & fd)
 {
+  bool currentDropoutActive = dropoutActive;
+  dropoutActive = false;
   dynet::ComputationGraph cg;
   std::vector<dynet::Expression> expressions;
@@ -120,6 +123,8 @@ std::vector<float> MLP::predict(FeatureModel::FeatureDescription & fd)
   dynet::Expression output = run(cg, input);
+  dropoutActive = currentDropoutActive;
   return as_vector(cg.forward(output));
 }
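The predict() hunks above save the dropoutActive flag, switch it off while the computation graph is built and evaluated, and restore it before returning, so dropout only perturbs activations during training. A minimal standalone sketch of that pattern, assuming a toy TinyNet class whose names are illustrative and not taken from this repository:

#include <iostream>

// Toy illustration of the flag handling above (TinyNet, run and predict are
// hypothetical names, not code from this repository).
struct TinyNet
{
  bool dropoutActive = true;

  float run(float x)
  {
    // The real run() builds a DyNet computation graph; here the flag simply
    // halves the activation so its effect is visible.
    return dropoutActive ? 0.5f * x : x;
  }

  float predict(float x)
  {
    bool currentDropoutActive = dropoutActive; // remember the caller's setting
    dropoutActive = false;                     // never drop units at inference
    float out = run(x);
    dropoutActive = currentDropoutActive;      // restore it for later training
    return out;
  }
};

int main()
{
  TinyNet net;
  std::cout << net.predict(2.0f) << '\n'; // 2: dropout bypassed during prediction
  std::cout << net.run(2.0f) << '\n';     // 1: dropout active again afterwards
}

An RAII-style guard that restores the flag in its destructor would also make the restore exception-safe; the explicit save/restore used in the commit keeps the change minimal.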
@@ -193,14 +198,12 @@ dynet::Expression MLP::run(dynet::ComputationGraph & cg, dynet::Expression x)
     dynet::Expression a = dynet::affine_transform({b, W, h_cur});
     // Apply activation function
     dynet::Expression h = activate(a, layers[l].activation);
-    h_cur = h;
     // Take care of dropout
-    /*
     dynet::Expression h_dropped;
     if(layers[l].dropout_rate > 0){
-      if(dropout_active){
+      if(dropoutActive){
         dynet::Expression mask = random_bernoulli(cg,
-          {layers[l].output_dim}, 1 - layers[l].dropout_rate);
+          {(unsigned int)layers[l].output_dim}, 1 - layers[l].dropout_rate);
         h_dropped = cmult(h, mask);
       }
       else{
@@ -212,7 +215,6 @@ dynet::Expression MLP::run(dynet::ComputationGraph & cg, dynet::Expression x)
     }
     h_cur = h_dropped;
-    */
   }
   return h_cur;
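The uncommented block is standard (non-inverted) dropout: during training each hidden unit is multiplied by a Bernoulli mask drawn with keep probability 1 - dropout_rate, and the else branch, which lies outside the visible context, presumably rescales activations by that keep probability at evaluation time. A plain C++ sketch of that arithmetic, using a hypothetical applyDropout helper instead of DyNet expressions:

#include <random>
#include <vector>

// Sketch of the dropout arithmetic above, outside DyNet (applyDropout is hypothetical).
std::vector<float> applyDropout(std::vector<float> h, float rate, bool training)
{
  if(rate <= 0.0f)
    return h;
  if(training)
  {
    // Same effect as cmult(h, random_bernoulli(cg, dim, 1 - rate)):
    // each unit survives with probability 1 - rate, otherwise becomes 0.
    static std::mt19937 gen{42};
    std::bernoulli_distribution keep(1.0 - rate);
    for(float & v : h)
      v = keep(gen) ? v : 0.0f;
  }
  else
  {
    // Assumed inference behaviour: scale by the keep probability so the
    // expected activation matches what the next layer saw during training.
    for(float & v : h)
      v *= 1.0f - rate;
  }
  return h;
}

With the hidden layer's new rate of 0.3 (see the Classifier hunk below), roughly 30% of hidden activations are zeroed on each training step.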
@@ -309,6 +311,9 @@ int MLP::trainOnBatch(Examples & examples, int start, int end)
 int MLP::getScoreOnBatch(Examples & examples, int start, int end)
 {
+  bool currentDropoutActive = dropoutActive;
+  dropoutActive = false;
   dynet::ComputationGraph cg;
   std::vector<dynet::Expression> inputs;
   std::vector<unsigned int> goldClasses;
@@ -354,6 +359,8 @@ int MLP::getScoreOnBatch(Examples & examples, int start, int end)
       nbCorrect++;
   }
+  dropoutActive = currentDropoutActive;
   return nbCorrect;
 }
@@ -429,6 +436,7 @@ MLP::MLP(const std::string & filename)
   dynet::initialize(getDefaultParams());
   trainMode = false;
+  dropoutActive = false;
   load(filename);
 }
...
@@ -133,7 +133,7 @@ void Classifier::initClassifier(Config & config)
   for (auto feat : fd.values)
     nbInputs += feat.vec->size();
-  mlp.reset(new MLP({{nbInputs, nbHidden, 0.0, MLP::Activation::RELU},
+  mlp.reset(new MLP({{nbInputs, nbHidden, 0.3, MLP::Activation::RELU},
     {nbHidden, nbOutputs, 0.0, MLP::Activation::LINEAR}}));
 }
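The only change on the classifier side is the hidden layer's dropout rate going from 0.0 to 0.3; the output layer keeps 0.0. Reading the field order off the MLP::Layer constructor above, the configuration amounts to something like the following (the Layer struct here is a guessed stand-in, not the repository's definition):

#include <vector>

// Guessed field order, matching MLP::Layer::Layer(int input_dim, int output_dim, ...)
// and the brace-initialisers in the call above.
struct Layer
{
  int input_dim;
  int output_dim;
  float dropout_rate;
  enum Activation { RELU, LINEAR } activation;
};

std::vector<Layer> hypotheticalSpec(int nbInputs, int nbHidden, int nbOutputs)
{
  return {
    {nbInputs, nbHidden, 0.3f, Layer::RELU},    // ~30% of hidden units dropped per training step
    {nbHidden, nbOutputs, 0.0f, Layer::LINEAR}  // output scores are never dropped
  };
}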