Skip to content
Snippets Groups Projects
Classifier.cpp 5.13 KiB
Newer Older
  • Learn to ignore specific revisions
  • #include "util.hpp"
    
    Franck Dary's avatar
    Franck Dary committed
    #include "RandomNetwork.hpp"
    
    #include "ModularNetwork.hpp"
    
    #include <stdexcept>
    
    /// Builds a classifier from its textual definition.
    ///
    /// definition[0] describes the transition sets, e.g. "Transitions : {a,b,x.ts c,y.ts}":
    /// space-separated groups, each a comma-separated mix of state names (no file extension)
    /// and transition-set files (with an extension, resolved relative to path's parent directory).
    /// Every state of a group gets its own TransitionSet built from that group's files.
    /// The whole definition is then handed to initNeuralNetwork(), which parses the next lines.
    ///
    /// @param name        classifier name, returned by getName().
    /// @param path        path of the file the definition was read from (used to resolve files).
    /// @param definition  definition lines; must contain at least the transitions line.
    /// @throws (via util::myThrow) on an empty/malformed first line, a group without files or
    ///         without states, or a state assigned twice.
    Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition)
    {
      // The `name` member is what getName() reports.
      this->name = name;
    
      // definition[0] would be out of bounds on an empty definition, so treat that as an invalid line.
      if (definition.empty() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Transitions :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[0], [this,&path](auto sm)
            {
              for (auto & group : util::split(sm.str(1), ' '))
              {
                std::vector<std::string> tsFiles;
                std::vector<std::string> states;
                // An element without extension is a state name, anything else is a transition-set file.
                for (auto & elem : util::split(group, ','))
                  if (std::filesystem::path(elem).extension().empty())
                    states.emplace_back(elem);
                  else
                    tsFiles.emplace_back((path.parent_path() / elem).string());
                if (tsFiles.empty())
                  util::myThrow(fmt::format("invalid '{}' no .ts files specified", group));
                if (states.empty())
                  util::myThrow(fmt::format("invalid '{}' no states specified", group));
    
                for (auto & stateName : states)
                {
                  if (this->transitionSets.count(stateName))
                    util::myThrow(fmt::format("state '{}' already assigned", stateName));
    
                  this->transitionSets.emplace(stateName, new TransitionSet(tsFiles));
                }
              }
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition.empty() ? "" : definition[0], "(Transitions :) {tsFile1.ts tsFile2.ts...}"));
    
      initNeuralNetwork(definition);
    }
    
    /// Counts the scalar values held by all parameter tensors of the network
    /// (sum of torch::numel over nn->parameters()).
    int Classifier::getNbParameters() const
    {
      int total = 0;
      const auto params = nn->parameters();
      for (const auto & tensor : params)
        total += torch::numel(tensor);
      return total;
    }
    
    
    /// Returns the transition set registered for the current state (the one set by setState()).
    /// @throws (via util::myThrow) if no transition set was assigned to that state.
    TransitionSet & Classifier::getTransitionSet()
    {
      // A single find() replaces the count() + operator[] double lookup.
      auto it = transitionSets.find(state);
      if (it == transitionSets.end())
        util::myThrow(fmt::format("cannot find transition set for state '{}'", state));
    
      return *it->second;
    }
    /// Returns the classifier's network viewed as a NeuralNetwork.
    NeuralNetwork & Classifier::getNN()
    {
      // NOTE(review): this reinterprets the `nn` member object itself as a NeuralNetwork.
      // It is only sound if NeuralNetwork is a thin holder whose sole data member has the same
      // type and layout as `nn` (presumably a torch ModuleHolder wrapping a shared_ptr) —
      // confirm against the header; otherwise this is undefined behaviour.
      return reinterpret_cast<NeuralNetwork&>(nn);
    }
    
    
    Franck Dary's avatar
    Franck Dary committed
    const std::string & Classifier::getName() const
    {
      return name;
    }
    
    
    /// Parses the network and optimizer lines of the definition and builds the network.
    ///
    /// Expected layout, starting at definition[1] (definition[0] is the transitions line parsed
    /// by the constructor):
    ///   "(Network type :) Random|Modular"   (Modular is followed by module lines up to "End")
    ///   "(Optimizer :) <type> {<parameters>}"  (stored in optimizerType / optimizerParameters)
    /// @throws (via util::myThrow) on a missing/malformed line or an unknown network type.
    void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
    {
      // Number of possible outputs per state = size of that state's transition set.
      std::map<std::string,std::size_t> nbOutputsPerState;
      for (auto & [stateName, transitionSet] : this->transitionSets)
        nbOutputsPerState[stateName] = transitionSet->size();
    
      std::size_t curIndex = 1;
    
      std::string networkType;
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Network type :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&networkType](auto sm)
            {
              networkType = sm.str(1);
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Network type :) networkType"));
    
      // The `else` is essential: without it the unknown-type error fires for every network type.
      if (networkType == "Random")
        this->nn.reset(new RandomNetworkImpl(nbOutputsPerState));
      else if (networkType == "Modular")
        initModular(definition, curIndex, nbOutputsPerState); // advances curIndex past "End"
      else
        util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, Modular'", networkType));
    
      if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Optimizer :|)(?:(?:\\s|\\t)*)(.*) \\{(.*)\\}"), definition[curIndex], [&curIndex,this](auto sm)
            {
              optimizerType = sm.str(1);
              optimizerParameters = sm.str(2);
              // Consume the line, like the network-type parser does.
              curIndex++;
            }))
        util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}"));
    }
    
    void Classifier::loadOptimizer(std::filesystem::path path)
    {
      torch::load(*optimizer, path);
    }
    
    void Classifier::saveOptimizer(std::filesystem::path path)
    {
      torch::save(*optimizer, path);
    }
    
    /// Returns the optimizer created by resetOptimizer().
    /// @throws (via util::myThrow) if no optimizer exists yet.
    torch::optim::Adam & Classifier::getOptimizer()
    {
      // Dereferencing an empty optimizer pointer would be undefined behaviour.
      if (!optimizer)
        util::myThrow("optimizer has not been created yet (call resetOptimizer() first)");
      return *optimizer;
    }
    
    
    /// Makes `state` the active state: it is stored in the classifier (getTransitionSet() reads it)
    /// and forwarded to the network.
    void Classifier::setState(const std::string & state)
    {
      Classifier::state = state;
      nn->setState(Classifier::state);
    }
    
    /// Collects the module definition lines that follow the "Modular" network line, up to a line
    /// containing only "End" (with optional blanks), and builds a ModularNetworkImpl from them.
    /// On return, curIndex points just past the "End" line (or at definition.size() if none).
    void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState)
    {
      const std::string blanks = "(?:(?:\\s|\\t)*)";
      const std::regex terminator(fmt::format("{}End{}", blanks, blanks));
    
      std::vector<std::string> moduleLines;
      bool reachedEnd = false;
      while (!reachedEnd and curIndex < definition.size())
      {
        if (util::doIfNameMatch(terminator, definition[curIndex], [](auto){}))
          reachedEnd = true; // the "End" line itself is consumed but not kept
        else
          moduleLines.emplace_back(definition[curIndex]);
        curIndex++;
      }
    
      this->nn.reset(new ModularNetworkImpl(nbOutputsPerState, moduleLines));
    }
    
    
    /// (Re)creates the optimizer from optimizerType / optimizerParameters.
    /// Only "Adam" is supported; its parameters are "lr beta1 beta2 eps decay amsgrad",
    /// amsgrad being the literal "true" or "false".
    /// @throws (via util::myThrow) with the expected format on any unsupported type or
    ///         malformed parameter list, including non-numeric values.
    void Classifier::resetOptimizer()
    {
      std::string expected = "expected '(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}'";
      if (optimizerType == "Adam")
      {
        auto splited = util::split(optimizerParameters, ' ');
        if (splited.size() != 6 or (splited.back() != "false" and splited.back() != "true"))
          util::myThrow(expected);
    
        float lr = 0.0f;
        float beta1 = 0.0f;
        float beta2 = 0.0f;
        float eps = 0.0f;
        float decay = 0.0f;
        // std::stof reports bad input with bare exceptions; turn them into the format message.
        try
        {
          lr = std::stof(splited[0]);
          beta1 = std::stof(splited[1]);
          beta2 = std::stof(splited[2]);
          eps = std::stof(splited[3]);
          decay = std::stof(splited[4]);
        }
        catch (const std::invalid_argument &)
        {
          util::myThrow(expected);
        }
        catch (const std::out_of_range &)
        {
          util::myThrow(expected);
        }
    
        // Use the `nn` member directly (as getNbParameters does) rather than going through
        // getNN()'s reinterpret_cast.
        optimizer.reset(new torch::optim::Adam(nn->parameters(), torch::optim::AdamOptions(lr).amsgrad(splited.back() == "true").betas({beta1, beta2}).eps(eps).weight_decay(decay)));
      }
      else
        util::myThrow(expected);
    }