Commit 781db29e authored by Franck Dary

GeneticAlgorithm now uses the N best MLPs, where N is a program parameter

parent 7d98b2f2
@@ -60,6 +60,7 @@ struct ProgramParameters
   static bool onlyPrefixes;
   static std::map<std::string,std::string> featureModelByClassifier;
   static int nbErrorsToShow;
+  static int nbIndividuals;
   private :
@@ -54,3 +54,5 @@ int ProgramParameters::batchSize;
 std::string ProgramParameters::loss;
 std::map<std::string,std::string> ProgramParameters::featureModelByClassifier;
 int ProgramParameters::nbErrorsToShow;
+int ProgramParameters::nbIndividuals;
@@ -39,7 +39,26 @@ void GeneticAlgorithm::init(int nbInputs, const std::string & topology, int nbOutputs)
 std::vector<float> GeneticAlgorithm::predict(FeatureModel::FeatureDescription & fd)
 {
-  return generation[0]->mlp.predict(fd);
+  int toAsk = ProgramParameters::nbIndividuals;
+
+  if (toAsk <= 0 || toAsk > (int)generation.size())
+  {
+    fprintf(stderr, "ERROR (%s) : trying to use \'%d\' individuals out of a population of \'%lu\'. Aborting.\n", ERRINFO, toAsk, generation.size());
+    exit(1);
+  }
+
+  auto prediction = generation[0]->mlp.predict(fd);
+
+  for (int i = 1; i < toAsk; i++)
+  {
+    auto otherPrediction = generation[i]->mlp.predict(fd);
+    for (unsigned int j = 0; j < prediction.size(); j++)
+      prediction[j] += otherPrediction[j];
+  }
+
+  for (unsigned int j = 0; j < prediction.size(); j++)
+    prediction[j] /= toAsk;
+
+  return prediction;
 }

 float GeneticAlgorithm::update(FeatureModel::FeatureDescription & fd, int gold)
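
The new predict() above no longer returns only the best individual's scores: it averages the score vectors of the first nbIndividuals MLPs. A minimal standalone sketch of that averaging step; the container types here are assumptions for illustration, not taken from the repository:

```cpp
// Sketch of the component-wise averaging done by the new predict():
// sum the score vectors of the N selected individuals, then divide by N.
// Assumes scores is non-empty and all inner vectors have the same size.
#include <cstddef>
#include <vector>

std::vector<float> averageScores(const std::vector<std::vector<float>> & scores)
{
  std::vector<float> result = scores[0];

  for (std::size_t i = 1; i < scores.size(); i++)
    for (std::size_t j = 0; j < result.size(); j++)
      result[j] += scores[i][j];

  for (std::size_t j = 0; j < result.size(); j++)
    result[j] /= scores.size();

  return result;
}
```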
@@ -115,11 +134,24 @@ float GeneticAlgorithm::update(FeatureModel::FeatureDescription & fd, int gold)
 void GeneticAlgorithm::save(const std::string & filename)
 {
+  int toSave = ProgramParameters::nbIndividuals;
+
+  if (toSave <= 0 || toSave > (int)generation.size())
+  {
+    fprintf(stderr, "ERROR (%s) : trying to save \'%d\' individuals out of a population of \'%lu\'. Aborting.\n", ERRINFO, toSave, generation.size());
+    exit(1);
+  }
+
   File * file = new File(filename, "w");
-  fprintf(file->getDescriptor(), "%u\n", generation[0]->id);
+
+  for (int i = 0; i < toSave; i++)
+    fprintf(file->getDescriptor(), "%u\n", generation[i]->id);
+
   delete file;

-  generation[0]->mlp.saveStruct(filename);
-  generation[0]->mlp.saveParameters(filename);
+  for (int i = 0; i < toSave; i++)
+  {
+    generation[i]->mlp.saveStruct(filename);
+    generation[i]->mlp.saveParameters(filename);
+  }
 }

 void GeneticAlgorithm::printTopology(FILE * output)
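
Both predict() and save() now perform the same bounds check on ProgramParameters::nbIndividuals. A hypothetical helper, not present in the repository, shown only to illustrate the shared validation logic the two call sites could factor out:

```cpp
// Hypothetical helper (not in the repository): validates the requested number
// of individuals against the current population size and aborts on misuse.
#include <cstddef>
#include <cstdio>
#include <cstdlib>

int checkedNbIndividuals(int requested, std::size_t populationSize, const char * where)
{
  if (requested <= 0 || requested > (int)populationSize)
  {
    std::fprintf(stderr, "ERROR (%s) : invalid number of individuals %d for a population of %zu. Aborting.\n",
                 where, requested, populationSize);
    std::exit(1);
  }

  return requested;
}
```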
@@ -137,18 +169,28 @@ void GeneticAlgorithm::printTopology(FILE * output)
 void GeneticAlgorithm::load(const std::string & filename)
 {
+  std::vector<unsigned int> ids;
+
   File * file = new File(filename, "r");

-  unsigned int bestId;
-  if (fscanf(file->getDescriptor(), "%u\n", &bestId) != 1)
+  unsigned int id;
+  while (fscanf(file->getDescriptor(), "%u\n", &id) == 1)
+    ids.emplace_back(id);
+
+  delete file;
+
+  if (ids.empty())
   {
-    fprintf(stderr, "ERROR (%s) : expected best id when reading file \'%s\'. Aborting.\n", ERRINFO, filename.c_str());
+    fprintf(stderr, "ERROR (%s) : missing MLP ids in file \'%s\'. Aborting.\n", ERRINFO, filename.c_str());
     exit(1);
   }

-  delete file;
-
-  generation.emplace_back(new Individual(bestId));
-  generation[0]->mlp.loadStruct(model, filename, 0);
-  generation[0]->mlp.loadParameters(model, filename);
+  ProgramParameters::nbIndividuals = ids.size();
+
+  for (auto & id : ids)
+  {
+    generation.emplace_back(new Individual(id));
+    generation.back()->mlp.loadStruct(model, filename, generation.size()-1);
+    generation.back()->mlp.loadParameters(model, filename);
+  }
 }

 GeneticAlgorithm::Individual::Individual(dynet::ParameterCollection & model, int nbInputs, const std::string & topology, int nbOutputs) : mlp("MLP_" + std::to_string(idCount))
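
load() now recovers the ensemble size from the model file itself: it reads ids until fscanf stops matching, then overrides ProgramParameters::nbIndividuals with the number of ids found, so prediction at load time uses exactly the saved ensemble. A minimal sketch of that read-until-no-match pattern in isolation, using plain C I/O rather than the project's File wrapper:

```cpp
// Sketch of the id-reading loop used by the new load(): keep reading unsigned
// integers until the next token no longer parses as one (end of the id list).
#include <cstdio>
#include <vector>

std::vector<unsigned int> readIds(const char * path)
{
  std::vector<unsigned int> ids;
  FILE * file = std::fopen(path, "r");

  if (file == nullptr)
    return ids; // caller decides how to handle a missing file

  unsigned int id;
  while (std::fscanf(file, "%u\n", &id) == 1)
    ids.push_back(id);

  std::fclose(file);
  return ids;
}
```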
@@ -244,7 +286,7 @@ void GeneticAlgorithm::Individual::mutate(float probability)
     for (unsigned int k = 0; k < nbValues; k++)
       if (choiceWithProbability(probability))
-        thisValues[k] = getRandomValueInRange(1);
+        thisValues[k] = getRandomValueInRange(3);
   }
 }
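
Mutation now draws replacement weights from a wider range (3 instead of 1). A rough sketch of the mutation step, assuming getRandomValueInRange(r) samples uniformly from [-r, r] and choiceWithProbability(p) is a Bernoulli draw; both helpers are reimplemented here purely for illustration and are not taken from the repository:

```cpp
// Illustrative reimplementation of the mutation step; the helper semantics
// ([-r, r] uniform sampling, Bernoulli(p) choice) are assumptions.
#include <random>
#include <vector>

static std::mt19937 rng{std::random_device{}()};

bool choiceWithProbability(float p)
{
  return std::bernoulli_distribution(p)(rng);
}

float getRandomValueInRange(float r)
{
  return std::uniform_real_distribution<float>(-r, r)(rng);
}

void mutateWeights(std::vector<float> & weights, float probability)
{
  for (auto & w : weights)
    if (choiceWithProbability(probability))
      w = getRandomValueInRange(3); // previously 1: mutated weights now span a wider range
}
```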
@@ -99,7 +99,12 @@ po::options_description getOptionsDescription()
   ("bias", po::value<float>()->default_value(1e-8),
     "bias parameter for the Amsgrad or Adam or Adagrad optimizer");

-  desc.add(req).add(opt).add(oracle).add(ams);
+  po::options_description ga("Genetic algorithm related options");
+  ga.add_options()
+  ("nbIndividuals", po::value<int>()->default_value(5),
+    "Number of individuals that will be used to take decisions");
+
+  desc.add(req).add(opt).add(oracle).add(ams).add(ga);

   return desc;
 }
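
The new option group is registered with boost::program_options like the existing ones. A self-contained sketch of how the nbIndividuals flag is declared with its default of 5 and read back from the variables map, as main() does below; this is a standalone program, not the project's actual getOptionsDescription():

```cpp
// Standalone sketch: declare the new "nbIndividuals" option, parse the
// command line, and read the value back through the variables map.
#include <boost/program_options.hpp>
#include <iostream>

namespace po = boost::program_options;

int main(int argc, char * argv[])
{
  po::options_description ga("Genetic algorithm related options");
  ga.add_options()
    ("nbIndividuals", po::value<int>()->default_value(5),
     "Number of individuals that will be used to take decisions");

  po::variables_map vm;
  po::store(po::parse_command_line(argc, argv, ga), vm);
  po::notify(vm);

  std::cout << "nbIndividuals = " << vm["nbIndividuals"].as<int>() << "\n";

  return 0;
}
```

Passing, for example, --nbIndividuals 10 on the command line makes predictions and saved models use the ten best individuals; omitting the flag falls back to the default of 5.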
@@ -273,6 +278,7 @@ int main(int argc, char * argv[])
   ProgramParameters::beta1 = vm["b1"].as<float>();
   ProgramParameters::beta2 = vm["b2"].as<float>();
   ProgramParameters::bias = vm["bias"].as<float>();
+  ProgramParameters::nbIndividuals = vm["nbIndividuals"].as<int>();
   ProgramParameters::optimizer = vm["optimizer"].as<std::string>();
   ProgramParameters::dicts = vm["dicts"].as<std::string>();
   ProgramParameters::loss = vm["loss"].as<std::string>();