Skip to content
Snippets Groups Projects
Select Git revision
  • 53e3eaac5ef3869798bee5edda60582d41103122
  • master default protected
  • loss
  • producer
4 results

DistanceModule.cpp

Blame
  • GeneticAlgorithm.cpp 3.36 KiB
    #include "GeneticAlgorithm.hpp"

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>
    #include <cstdlib>
    #include <memory>
    #include <string>
    #include <utility>
    #include <vector>

    #include "ProgramParameters.hpp"
    #include "util.hpp"
    
    /// Builds a genetic algorithm with no individual yet; init() creates them.
    ///
    /// The random seed is copied from the program-wide ProgramParameters::seed
    /// before dynet is set up by initDynet().
    GeneticAlgorithm::GeneticAlgorithm()
    {
      this->randomSeed = ProgramParameters::seed;
      initDynet(); // presumably consumes randomSeed — TODO confirm in initDynet()
    }
    
    /// Builds the algorithm exactly as the default constructor does (seed copy
    /// and dynet setup), then restores its state from @p filename via load().
    GeneticAlgorithm::GeneticAlgorithm(const std::string & filename)
      : GeneticAlgorithm()
    {
      load(filename);
    }
    
    void GeneticAlgorithm::init(int nbInputs, const std::string & topology, int nbOutputs)
    {
      auto splited = split(topology, ' ');
      if (splited.size() != 2 || !isNum(splited[0]))
      {
        fprintf(stderr, "ERROR (%s) : wrong topology \'%s\'. Aborting.\n", ERRINFO, topology.c_str());
        exit(1);
      }
    
      int nbElems = std::stoi(splited[0]);
    
      for (int i = 0; i < nbElems; i++)
        generation.emplace_back(new Individual(model, nbInputs, splited[1], nbOutputs));
    
      fprintf(stderr, "Init is done !\n");
    }
    
    /// Returns the prediction of the first individual of the generation, i.e.
    /// the one update() ranks highest when it sorts by fitness value.
    ///
    /// Prints an error and exits the program if the generation is empty (for
    /// instance when init() was never called), since there is then no
    /// individual to ask.
    std::vector<float> GeneticAlgorithm::predict(FeatureModel::FeatureDescription & fd)
    {
      if (generation.empty())
      {
        fprintf(stderr, "ERROR (%s) : cannot predict with an empty generation. Aborting.\n", __func__);
        exit(1);
      }

      return generation[0]->mlp.predict(fd);
    }
    
    /// Trains every individual on one example, then breeds the generation.
    ///
    /// Steps:
    /// 1. Each individual's MLP is updated on the example.
    /// 2. If at least one reported a non-zero loss, individuals are sorted by
    ///    fitness value, highest first, and the ranking is dumped to stderr.
    /// 3. Every individual except the first becomes a child of the first.
    ///
    /// @param fd   Features of the example.
    /// @param gold Gold label, forwarded unchanged to each MLP's update().
    /// @return The loss reported on this example by the individual ranked first
    ///         after sorting. Returns 0 when no individual reported a non-zero
    ///         loss, in which case no sorting or breeding happens.
    float GeneticAlgorithm::update(FeatureModel::FeatureDescription & fd, int gold)
    {
      bool haveBeenUpdated = false;

      // Loss of each individual on this example, keyed by the individual itself
      // because the sort below reorders `generation`.
      std::vector<std::pair<const Individual *, float>> losses;
      losses.reserve(generation.size());

      for (auto & individual : generation)
      {
        float loss = individual->mlp.update(fd, gold);
        losses.emplace_back(individual.get(), loss);

        // loss2value() divides by the loss: on a zero loss keep the previous value.
        if (loss != 0.0)
        {
          individual->value = loss2value(loss);
          haveBeenUpdated = true;
        }
      }

      if (!haveBeenUpdated)
        return 0.0;

      std::sort(generation.begin(), generation.end(),
                [](const std::unique_ptr<Individual> & a, const std::unique_ptr<Individual> & b)
      {
        return a->value > b->value;
      });

      fprintf(stderr, "-----------------\n");
      for (auto & individual : generation)
        fprintf(stderr, "%d\t%f\n", individual->id, individual->value);
      fprintf(stderr, "-----------------\n");

      for (unsigned int i = 1; i < generation.size(); i++)
      {
        generation[i]->becomeChildOf(generation[0].get());
      }

      const Individual * best = generation[0].get();
      for (const auto & entry : losses)
        if (entry.first == best)
          return entry.second;

      // Not reached: `best` is one of the individuals recorded in `losses`.
      return 0.0;
    }
    
    /// Saving is not implemented yet: nothing is written to @p filename.
    ///
    /// A warning is printed on stderr so that a caller does not silently believe
    /// the model has been saved.
    void GeneticAlgorithm::save(const std::string & filename)
    {
      // TODO: implement serialization of the generation.
      fprintf(stderr, "WARNING (%s) : saving is not implemented, nothing was written to \'%s\'.\n", __func__, filename.c_str());
    }
    
    /// Writes the topology to @p output as "<count> x " followed by the first
    /// individual's MLP topology. init() builds every individual from the same
    /// topology string. An empty generation is written as "0 x ()" plus a
    /// newline.
    void GeneticAlgorithm::printTopology(FILE * output)
    {
      if (generation.empty())
      {
        fprintf(output, "0 x ()\n");
        return;
      }

      // %zu is the portable conversion for size_t; %lu is wrong where size_t is
      // not unsigned long (e.g. 64-bit Windows).
      fprintf(output, "%zu x ", generation.size());
      generation[0]->mlp.printTopology(output);
    }
    
    /// Loading is not implemented yet: @p filename is not read and the
    /// generation is left unchanged.
    ///
    /// A warning is printed on stderr so that a caller (including the filename
    /// constructor) does not silently end up with an unloaded model.
    void GeneticAlgorithm::load(const std::string & filename)
    {
      // TODO: implement deserialization of the generation.
      fprintf(stderr, "WARNING (%s) : loading is not implemented, \'%s\' was not read.\n", __func__, filename.c_str());
    }
    
    /// Creates one individual and initializes its MLP with @p model and the given
    /// input count, topology string and output count.
    ///
    /// Each individual receives a distinct id taken from a counter shared by all
    /// Individual constructions. The counter starts at 0 and increases by one per
    /// construction.
    GeneticAlgorithm::Individual::Individual(dynet::ParameterCollection & model, int nbInputs, const std::string & topology, int nbOutputs)
    {
      static int nextId = 0;
      id = nextId;
      ++nextId;

      mlp.init(model, nbInputs, topology, nbOutputs);
    }
    
    /// Converts a loss into a fitness value, 1000 / loss, computed in double and
    /// narrowed to float. For positive losses, a smaller loss gives a larger
    /// value. A zero loss means a division by zero, so callers must avoid it
    /// (update() only calls this for non-zero losses).
    float GeneticAlgorithm::loss2value(float loss)
    {
      const double scale = 1000.0;
      return static_cast<float>(scale / loss);
    }
    
    /// Crossover with @p other: every weight of this individual's MLP is replaced
    /// by the weight at the same position in @p other's MLP. The replacement
    /// happens with probability about 1/2, decided independently per weight
    /// with rand().
    ///
    /// Both MLPs must have the same shape: the same number of parameter groups,
    /// the same number of parameters per group, and the same number of values
    /// per parameter. Otherwise an error is printed and the program exits.
    void GeneticAlgorithm::Individual::becomeChildOf(Individual * other)
    {
      auto & thisParameters = mlp.parameters;
      auto & otherParameters = other->mlp.parameters;

      if (thisParameters.size() != otherParameters.size())
      {
        fprintf(stderr, "ERROR (%s) : The two individuals are not compatibles. Sizes %zu and %zu. Aborting.\n", ERRINFO, thisParameters.size(), otherParameters.size());
        exit(1);
      }

      for (std::size_t i = 0; i < thisParameters.size(); i++)
      {
        // Indexing otherParameters[i][j] below is only in bounds if the inner
        // sizes match as well.
        if (thisParameters[i].size() != otherParameters[i].size())
        {
          fprintf(stderr, "ERROR (%s) : The two individuals are not compatibles. Group %zu sizes %zu and %zu. Aborting.\n", ERRINFO, i, thisParameters[i].size(), otherParameters[i].size());
          exit(1);
        }

        for (std::size_t j = 0; j < thisParameters[i].size(); j++)
        {
          auto & thisParameter = thisParameters[i][j];
          auto & otherParameter = otherParameters[i][j];
          float * thisValues = thisParameter.values()->v;
          float * otherValues = otherParameter.values()->v;
          unsigned int nbValues = thisParameter.values()->d.size();
          unsigned int otherNbValues = otherParameter.values()->d.size();

          // Copying from a smaller parameter would read past otherValues' end.
          if (nbValues != otherNbValues)
          {
            fprintf(stderr, "ERROR (%s) : The two individuals are not compatibles. Parameter %zu/%zu sizes %u and %u. Aborting.\n", ERRINFO, i, j, nbValues, otherNbValues);
            exit(1);
          }

          // NOTE(review): writes go straight through the raw value pointer,
          // which assumes host (CPU) memory — confirm for GPU dynet builds.
          for (unsigned int k = 0; k < nbValues; k++)
            if (rand() % 1000 >= 500)
              thisValues[k] = otherValues[k];
        }
      }
    }