Commit ac654e14 authored by Franck Dary

Added MLP topology to .cla files

parent b0f2f464
@@ -51,7 +51,7 @@ class MLP
   std::map< Dict*, std::pair<dynet::LookupParameter, std::map<void*, unsigned int> > > lookupParameters;
   dynet::ParameterCollection model;
-  dynet::AmsgradTrainer trainer;
+  std::unique_ptr<dynet::AmsgradTrainer> trainer;
   bool trainMode;
   bool dropoutActive;
@@ -74,6 +74,7 @@ class MLP
   public :
   MLP(std::vector<Layer> layers);
+  MLP(int nbInputs, const std::string & topology, int nbOutputs);
   MLP(const std::string & filename);
   std::vector<float> predict(FeatureModel::FeatureDescription & fd);
...
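The header change above swaps the AmsgradTrainer value member for a std::unique_ptr and declares the new topology-string constructor; the trainer is now created with reset() inside each constructor body (see the hunks below) rather than in every initializer list.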
@@ -78,9 +78,53 @@ void MLP::initDynet()
   dynet::initialize(getDefaultParams());
 }
 
+MLP::MLP(int nbInputs, const std::string & topology, int nbOutputs)
+{
+  std::string topo = topology;
+  std::replace(topo.begin(), topo.end(), '(', ' ');
+  std::replace(topo.begin(), topo.end(), ')', ' ');
+
+  auto groups = split(topo);
+  for (auto group : groups)
+  {
+    if(group.empty())
+      continue;
+
+    std::replace(group.begin(), group.end(), ',', ' ');
+    auto layer = split(group);
+
+    if (layer.size() != 3)
+    {
+      fprintf(stderr, "ERROR (%s) : invalid topology \'%s\'. Aborting.\n", ERRINFO, topology.c_str());
+      exit(1);
+    }
+
+    int input = layers.empty() ? nbInputs : layers.back().output_dim;
+    int output = std::stoi(layer[0]);
+    float dropout = std::stof(layer[2]);
+    layers.emplace_back(input, output, dropout, str2activation(layer[1]));
+  }
+
+  layers.emplace_back(layers.back().output_dim, nbOutputs, 0.0, Activation::LINEAR);
+
+  trainer.reset(new dynet::AmsgradTrainer(model, 0.001, 0.9, 0.999, 1e-8));
+  initDynet();
+
+  trainMode = true;
+  dropoutActive = true;
+
+  checkLayersCompatibility();
+
+  for(Layer layer : layers)
+    addLayerToModel(layer);
+}
+
 MLP::MLP(std::vector<Layer> layers)
-  : layers(layers), trainer(model, 0.001, 0.9, 0.999, 1e-8)
+  : layers(layers)
 {
+  trainer.reset(new dynet::AmsgradTrainer(model, 0.001, 0.9, 0.999, 1e-8));
   initDynet();
 
   trainMode = true;
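The new constructor turns a compact topology string into hidden layers and then appends a linear output layer of size nbOutputs. As a rough illustration of the format it accepts, here is a minimal standalone sketch (not the project's code; the split() helper and the example dimensions are assumptions) that parses a string the same way and prints the resulting layer shapes:

// Standalone sketch of the topology-string format, e.g. "(300,RELU,0.3)(200,RELU,0.3)":
// each parenthesized group is (output_dim, activation, dropout) for one hidden layer,
// and a final linear layer of size nbOutputs is appended.
#include <algorithm>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical whitespace split, standing in for the project's own utility.
static std::vector<std::string> split(const std::string & s)
{
  std::vector<std::string> parts;
  std::istringstream iss(s);
  for (std::string tok; iss >> tok; )
    parts.push_back(tok);
  return parts;
}

int main()
{
  int nbInputs = 1000, nbOutputs = 50;              // hypothetical dimensions
  std::string topo = "(300,RELU,0.3)(200,RELU,0.3)";

  // Parentheses become separators, so each group is one whitespace token.
  std::replace(topo.begin(), topo.end(), '(', ' ');
  std::replace(topo.begin(), topo.end(), ')', ' ');

  int prevDim = nbInputs;
  for (const auto & group : split(topo))
  {
    std::string g = group;
    std::replace(g.begin(), g.end(), ',', ' ');
    auto fields = split(g);                         // {outDim, activation, dropout}
    if (fields.size() != 3)
      return 1;
    std::cout << prevDim << " -> " << fields[0] << " (" << fields[1]
              << ", dropout " << fields[2] << ")\n";
    prevDim = std::stoi(fields[0]);
  }
  std::cout << prevDim << " -> " << nbOutputs << " (LINEAR, dropout 0.0)\n";
  return 0;
}

With the string above this prints 1000 -> 300 (RELU, dropout 0.3), 300 -> 200 (RELU, dropout 0.3), 200 -> 50 (LINEAR, dropout 0.0).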
@@ -308,7 +352,7 @@ int MLP::trainOnBatch(Examples & examples, int start, int end)
     dynet::Expression batchedLoss = pickneglogsoftmax(output, goldClasses);
     dynet::Expression loss = sum_batches(batchedLoss);
     cg.backward(loss);
-    trainer.update();
+    trainer->update();
   }
 
   int nbCorrect = 0;
@@ -450,8 +494,8 @@ void MLP::loadParameters(const std::string & filename)
 }
 
 MLP::MLP(const std::string & filename)
-  : trainer(model, 0.001, 0.9, 0.999, 1e-8)
 {
+  trainer.reset(new dynet::AmsgradTrainer(model, 0.001, 0.9, 0.999, 1e-8));
   initDynet();
 
   trainMode = false;
...
@@ -30,6 +30,7 @@ class Classifier
   std::unique_ptr<FeatureModel> fm;
   std::unique_ptr<ActionSet> as;
   std::unique_ptr<MLP> mlp;
+  std::string topology;
   Oracle * oracle;
 
   public :
...
@@ -61,6 +61,11 @@ Classifier::Classifier(const std::string & filename, bool trainMode, const std::
     badFormatAndAbort(ERRINFO);
 
   as.reset(new ActionSet(expPath + buffer, false));
 
+  if(fscanf(fd, "Topology : %s\n", buffer) != 1)
+    badFormatAndAbort(ERRINFO);
+
+  topology = buffer;
 }
 
 Classifier::Type Classifier::str2type(const std::string & s)
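Given the fscanf format above, the new entry in a .cla file would presumably look like

Topology : (300,RELU,0.3)(200,RELU,0.3)

where each parenthesized group is output-dimension,activation,dropout for one hidden layer; since the value is read with %s, the whole string must contain no whitespace.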
@@ -120,7 +125,6 @@ void Classifier::initClassifier(Config & config)
   }
 
   int nbInputs = 0;
-  int nbHidden = 300;
   int nbOutputs = as->actions.size();
 
   auto fd = fm->getFeatureDescription(config);
@@ -128,8 +132,7 @@ void Classifier::initClassifier(Config & config)
   for (auto feat : fd.values)
     nbInputs += feat.vec->size();
 
-  mlp.reset(new MLP({{nbInputs, nbHidden, 0.3, MLP::Activation::RELU},
-                     {nbHidden, nbOutputs, 0.0, MLP::Activation::LINEAR}}));
+  mlp.reset(new MLP(nbInputs, topology, nbOutputs));
 }
 
 FeatureModel::FeatureDescription Classifier::getFeatureDescription(Config & config)
...
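For reference, the hard-coded network removed here (one hidden layer of 300 ReLU units with 0.3 dropout, followed by a linear output layer) corresponds to the topology string (300,RELU,0.3), so a .cla file carrying that value should presumably reproduce the previous behaviour.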