Skip to content
Snippets Groups Projects
Commit d66aeb6b authored by Franck Dary's avatar Franck Dary
Browse files

Training now print MLP topology

parent 1799b40a
No related branches found
No related tags found
No related merge requests found
...@@ -78,6 +78,7 @@ class MLP ...@@ -78,6 +78,7 @@ class MLP
int trainOnBatch(Examples & examples, int start, int end); int trainOnBatch(Examples & examples, int start, int end);
int getScoreOnBatch(Examples & examples, int start, int end); int getScoreOnBatch(Examples & examples, int start, int end);
void save(const std::string & filename); void save(const std::string & filename);
void printTopology(FILE * output);
}; };
#endif #endif
...@@ -441,3 +441,18 @@ MLP::MLP(const std::string & filename) ...@@ -441,3 +441,18 @@ MLP::MLP(const std::string & filename)
load(filename); load(filename);
} }
void MLP::printTopology(FILE * output)
{
fprintf(output, "(");
for(unsigned int i = 0; i < layers.size(); i++)
{
auto & layer = layers[i];
if(i == 0)
fprintf(output, "%d", layer.input_dim);
fprintf(output, "->%d", layer.output_dim);
}
fprintf(output, ")\n");
}
...@@ -54,6 +54,7 @@ class Classifier ...@@ -54,6 +54,7 @@ class Classifier
void initClassifier(Config & config); void initClassifier(Config & config);
void save(); void save();
bool needsTrain(); bool needsTrain();
void printTopology(FILE * output);
}; };
#endif #endif
...@@ -125,7 +125,7 @@ void Classifier::initClassifier(Config & config) ...@@ -125,7 +125,7 @@ void Classifier::initClassifier(Config & config)
} }
int nbInputs = 0; int nbInputs = 0;
int nbHidden = 200; int nbHidden = 500;
int nbOutputs = as->actions.size(); int nbOutputs = as->actions.size();
auto fd = fm->getFeatureDescription(config); auto fd = fm->getFeatureDescription(config);
...@@ -226,3 +226,9 @@ bool Classifier::needsTrain() ...@@ -226,3 +226,9 @@ bool Classifier::needsTrain()
return type == Type::Prediction; return type == Type::Prediction;
} }
/// @brief Print this classifier's name followed by its MLP topology.
///
/// Emits "<name> topology : " and then delegates the layer-dimension
/// listing to the underlying MLP.
///
/// @param output the stream the description is written to.
void Classifier::printTopology(FILE * output)
{
  const char * header = name.c_str();
  fprintf(output, "%s topology : ", header);
  mlp->printTopology(output);
}
...@@ -127,6 +127,11 @@ void Trainer::trainBatched(int nbIter, int batchSize, bool mustShuffle) ...@@ -127,6 +127,11 @@ void Trainer::trainBatched(int nbIter, int batchSize, bool mustShuffle)
if(devMcd && devConfig) if(devMcd && devConfig)
getExamplesByClassifier(devExamples, *devConfig); getExamplesByClassifier(devExamples, *devConfig);
auto & classifiers = tm.getClassifiers();
for(Classifier * cla : classifiers)
if(cla->needsTrain())
cla->printTopology(stderr);
std::map< std::string, std::vector<float> > trainScores; std::map< std::string, std::vector<float> > trainScores;
std::map< std::string, std::vector<float> > devScores; std::map< std::string, std::vector<float> > devScores;
std::map<std::string, int> bestIter; std::map<std::string, int> bestIter;
...@@ -154,7 +159,6 @@ void Trainer::trainBatched(int nbIter, int batchSize, bool mustShuffle) ...@@ -154,7 +159,6 @@ void Trainer::trainBatched(int nbIter, int batchSize, bool mustShuffle)
printIterationScores(stderr, nbExamplesTrain, nbExamplesDev, printIterationScores(stderr, nbExamplesTrain, nbExamplesDev,
trainScores, devScores, bestIter, nbIter, i); trainScores, devScores, bestIter, nbIter, i);
auto & classifiers = tm.getClassifiers();
for(Classifier * cla : classifiers) for(Classifier * cla : classifiers)
if(cla->needsTrain()) if(cla->needsTrain())
if(bestIter[cla->name] == i) if(bestIter[cla->name] == i)
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment