Commit a75a3935 authored by Franck Dary

MCD error messages and MLP predict

parent d2013f81
...
@@ -43,6 +43,7 @@ class MLP
   std::vector<Layer> layers;
   std::vector< std::vector<dynet::Parameter> > parameters;
+  std::map<void*,dynet::Parameter> ptr2parameter;
   dynet::ParameterCollection model;
   dynet::AmsgradTrainer trainer;
...
@@ -53,6 +54,9 @@ class MLP
   void addLayerToModel(Layer & layer);
   void checkLayersCompatibility();
   dynet::DynetParams & getDefaultParams();
+  dynet::Parameter & featValue2parameter(const FeatureModel::FeatureValue & fv);
+  dynet::Expression run(dynet::ComputationGraph & cg, dynet::Expression x);
+  inline dynet::Expression activate(dynet::Expression h, Activation f);
   public :
...
#include "MLP.hpp" #include "MLP.hpp"
#include "util.hpp" #include "util.hpp"
#include <dynet/param-init.h>
std::string MLP::activation2str(Activation a) std::string MLP::activation2str(Activation a)
{ {
switch(a) switch(a)
...
@@ -107,11 +109,19 @@ std::vector<float> MLP::predict(FeatureModel::FeatureDescription & fd, int goldC
 {
   dynet::ComputationGraph cg;
+  std::vector<dynet::Expression> expressions;

   for (auto & featValue : fd.values)
   {
-    dynet::Parameter p(*featValue.vec);
+    if(featValue.policy == FeatureModel::Policy::Final)
+      expressions.emplace_back(dynet::const_parameter(cg, featValue2parameter(featValue)));
+    else
+      expressions.emplace_back(dynet::parameter(cg, featValue2parameter(featValue)));
   }

+  dynet::Expression input = dynet::concatenate(expressions);
+  dynet::Expression output = run(cg, input);

   /*
   int nbInputs = layers[0].input_dim;
...
@@ -126,7 +136,7 @@ std::vector<float> MLP::predict(FeatureModel::FeatureDescription & fd, int goldC
   trainer.update();
   */
-  std::vector<float> res;
+  std::vector<float> res = as_vector(cg.forward(output));

   return res;
 }
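
The rewritten predict builds a fresh computation graph on every call: each feature vector enters the graph through featValue2parameter, features whose policy is Final are wrapped with dynet::const_parameter (they contribute to the forward pass but receive no gradient), the others with dynet::parameter, and the concatenated input is pushed through the network. A minimal sketch of the read-back at the end of the function, assuming the graph and output expression are already built:

#include <dynet/dynet.h>
#include <dynet/expr.h>
#include <dynet/tensor.h>
#include <vector>

// Sketch: evaluate a graph and read the result back as floats,
// mirroring the new tail of MLP::predict.
std::vector<float> valueOf(dynet::ComputationGraph & cg, dynet::Expression output)
{
  // forward() computes the value of the expression; as_vector
  // flattens the resulting tensor into a std::vector<float>.
  return dynet::as_vector(cg.forward(output));
}
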
...
@@ -141,3 +151,82 @@ dynet::DynetParams & MLP::getDefaultParams()
   return params;
 }
+
+dynet::Parameter & MLP::featValue2parameter(const FeatureModel::FeatureValue & fv)
+{
+  auto it = ptr2parameter.find(fv.vec);
+
+  if(it != ptr2parameter.end())
+    return it->second;
+
+  ptr2parameter[fv.vec] = model.add_parameters({1,fv.vec->size()}, dynet::ParameterInitFromVector(*fv.vec));
+
+  it = ptr2parameter.find(fv.vec);
+  it->second.set_updated(fv.policy == FeatureModel::Policy::Final ? false : true);
+
+  return it->second;
+}
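
featValue2parameter is a lazy cache keyed by the address of the feature's value vector: the first request creates a dynet parameter initialised from the vector's contents, later requests reuse it, and set_updated(false) freezes parameters whose policy is Final so the trainer leaves them alone. The same idea in isolation; the helper below is illustrative, not part of the commit:

#include <dynet/dynet.h>
#include <dynet/model.h>
#include <dynet/param-init.h>
#include <map>
#include <vector>

// Sketch: one dynet::Parameter per distinct float vector, created on
// first request and reused afterwards.
dynet::Parameter & cachedParameter(dynet::ParameterCollection & model,
                                   std::map<void*, dynet::Parameter> & cache,
                                   std::vector<float> * vec, bool trainable)
{
  auto it = cache.find(vec);
  if(it == cache.end())
  {
    dynet::Parameter p = model.add_parameters({1, (unsigned)vec->size()},
                                              dynet::ParameterInitFromVector(*vec));
    p.set_updated(trainable); // frozen when the feature policy is Final
    it = cache.emplace(vec, p).first;
  }
  return it->second;
}
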
+
+dynet::Expression MLP::run(dynet::ComputationGraph & cg, dynet::Expression x)
+{
+  // Expression for the current hidden state
+  dynet::Expression h_cur = x;
+
+  for(unsigned int l = 0; l < layers.size(); l++)
+  {
+    // Initialize parameters in computation graph
+    dynet::Expression W = parameter(cg, parameters[l][0]);
+    dynet::Expression b = parameter(cg, parameters[l][1]);
+
+    // Apply affine transform
+    dynet::Expression a = dynet::affine_transform({b, W, h_cur});
+    // Apply activation function
+    dynet::Expression h = activate(a, layers[l].activation);
+    h_cur = h;
+
+    // Take care of dropout
+    /*
+    dynet::Expression h_dropped;
+    if(layers[l].dropout_rate > 0){
+      if(dropout_active){
+        dynet::Expression mask = random_bernoulli(cg, {layers[l].output_dim}, 1 - layers[l].dropout_rate);
+        h_dropped = cmult(h, mask);
+      }
+      else{
+        h_dropped = h * (1 - layers[l].dropout_rate);
+      }
+    }
+    else{
+      h_dropped = h;
+    }
+    h_cur = h_dropped;
+    */
+  }
+
+  return h_cur;
+}
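
Note the argument order of dynet::affine_transform in run: {b, W, h_cur} evaluates b + W * h_cur, so each layer's weight matrix must have shape {output_dim, input_dim} and its bias {output_dim}. A self-contained single-layer sketch, with 128 inputs and 64 outputs as made-up dimensions rather than values from the commit:

#include <dynet/dynet.h>
#include <dynet/expr.h>
#include <dynet/model.h>

// Sketch: one hidden layer by hand, following the same pattern as run().
dynet::Expression oneLayer(dynet::ComputationGraph & cg, dynet::ParameterCollection & model)
{
  dynet::Parameter W = model.add_parameters({64, 128});
  dynet::Parameter b = model.add_parameters({64});
  dynet::Expression x = dynet::zeros(cg, {128});
  // b + W * x, then ReLU: shape {128} in, shape {64} out.
  return dynet::rectify(dynet::affine_transform({dynet::parameter(cg, b),
                                                 dynet::parameter(cg, W), x}));
}
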
+
+inline dynet::Expression MLP::activate(dynet::Expression h, Activation f)
+{
+  switch(f)
+  {
+    case LINEAR :
+      return h;
+    case RELU :
+      return rectify(h);
+    case SIGMOID :
+      return logistic(h);
+    case TANH :
+      return tanh(h);
+    case SOFTMAX :
+      return softmax(h);
+    default :
+      break;
+  }
+
+  return h;
+}
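
Since the last layer can use SOFTMAX, the vector that predict returns is already a distribution over classes, and a caller only needs its argmax. A sketch of that reduction; the call site itself is not part of this diff:

#include <algorithm>
#include <iterator>
#include <vector>

// Pick the index of the highest score returned by MLP::predict.
int bestClass(const std::vector<float> & scores)
{
  return (int)std::distance(scores.begin(),
                            std::max_element(scores.begin(), scores.end()));
}
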
...
@@ -39,27 +39,67 @@ MCD::MCD(const std::string & filename)
 Dict * MCD::getDictOfLine(int num)
 {
-  return num2line[num]->dict;
+  auto it = num2line.find(num);
+
+  if(it == num2line.end())
+  {
+    fprintf(stderr, "ERROR (%s) : requested line number %d in MCD. Aborting.\n", ERRINFO, num);
+    exit(1);
+  }
+
+  return it->second->dict;
 }
 Dict * MCD::getDictOfLine(const std::string & name)
 {
-  return name2line[name]->dict;
+  auto it = name2line.find(name);
+
+  if(it == name2line.end())
+  {
+    fprintf(stderr, "ERROR (%s) : requested line '%s' in MCD. Aborting.\n", ERRINFO, name.c_str());
+    exit(1);
+  }
+
+  return it->second->dict;
 }
 Dict * MCD::getDictOfInputCol(int col)
 {
-  return col2line[col]->dict;
+  auto it = col2line.find(col);
+
+  if(it == col2line.end())
+  {
+    fprintf(stderr, "ERROR (%s) : requested line of input column %d in MCD. Aborting.\n", ERRINFO, col);
+    exit(1);
+  }
+
+  return it->second->dict;
 }
 int MCD::getLineOfName(const std::string & name)
 {
-  return name2line[name]->num;
+  auto it = name2line.find(name);
+
+  if(it == name2line.end())
+  {
+    fprintf(stderr, "ERROR (%s) : requested line '%s' in MCD. Aborting.\n", ERRINFO, name.c_str());
+    exit(1);
+  }
+
+  return it->second->num;
 }
 int MCD::getLineOfInputCol(int col)
 {
-  return col2line[col]->num;
+  auto it = col2line.find(col);
+
+  if(it == col2line.end())
+  {
+    fprintf(stderr, "ERROR (%s) : requested line in MCD corresponding to input col %d. Aborting.\n", ERRINFO, col);
+    exit(1);
+  }
+
+  return it->second->num;
 }
 int MCD::getNbInputColumns()
...
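
All five MCD getters now follow the same shape: look the key up, print an ERRINFO-tagged message and abort on a miss, otherwise dereference the hit. If the duplication grows, a small helper template could factor it out; the sketch below is hypothetical, not in the commit, and simplified to avoid the ERRINFO macro:

#include <cstdio>
#include <cstdlib>
#include <map>

// Sketch: generic checked lookup that aborts on a missing key,
// mirroring the pattern shared by the MCD getters.
template <typename Map, typename Key>
typename Map::mapped_type & findOrAbort(Map & m, const Key & key, const char * what)
{
  auto it = m.find(key);
  if(it == m.end())
  {
    fprintf(stderr, "ERROR : no entry for %s in MCD. Aborting.\n", what);
    exit(1);
  }
  return it->second;
}

With such a helper, getDictOfLine(int) would reduce to a one-liner like return findOrAbort(num2line, num, "line number")->dict;.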