Skip to content
Snippets Groups Projects
Classifier.cpp 6.13 KiB
Newer Older
#include "util.hpp"
Franck Dary's avatar
Franck Dary committed
#include "RandomNetwork.hpp"
#include "ModularNetwork.hpp"
// Builds a classifier from its textual definition.
//
// definition[0] : "(Transitions :) {group1 group2 ...}" where each space-separated group is a
//                 comma-separated list mixing state names (entries without a file extension) and
//                 transition-set files (entries with an extension, resolved relative to the
//                 directory containing `path`). Every listed state gets its own TransitionSet.
// definition[1] : "(LossMultiplier :) {state1,multiplier1 state2,multiplier2 ...}", overriding the
//                 default multiplier of 1.0 for the given states.
// definition[2..] are consumed by initNeuralNetwork().
//
// Throws (through util::myThrow) on any malformed or missing line.
Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition) : name(name)
{
  // Both mandatory header lines are indexed directly below; reject short definitions up front
  // instead of reading past the end of the vector.
  if (definition.size() < 2)
    util::myThrow(fmt::format("Invalid classifier definition, expected at least 2 lines but got {}", definition.size()));

  if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Transitions :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[0], [this,&path](auto sm)
        {
          auto splited = util::split(sm.str(1), ' ');
          for (auto & ss : splited)
          {
            std::vector<std::string> tsFiles;
            std::vector<std::string> states;
            for (auto & elem : util::split(ss, ','))
              if (std::filesystem::path(elem).extension().empty())
                states.emplace_back(elem);
              else
                tsFiles.emplace_back(path.parent_path() / elem);
            if (tsFiles.empty())
              util::myThrow(fmt::format("invalid '{}' no .ts files specified", ss));
            if (states.empty())
              util::myThrow(fmt::format("invalid '{}' no states specified", ss));

            for (auto & stateName : states)
            {
              // A state may belong to a single group only.
              if (transitionSets.count(stateName))
                util::myThrow(fmt::format("state '{}' already assigned", stateName));

              this->transitionSets.emplace(stateName, new TransitionSet(tsFiles));
            }
          }
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[0], "(Transitions :) {tsFile1.ts tsFile2.ts...}"));

  // Every known state starts with a neutral multiplier; line 1 may override some of them.
  for (auto & it : this->transitionSets)
    lossMultipliers[it.first] = 1.0;

  if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LossMultiplier :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[1], [this](auto sm)
        {
          auto pairs = util::split(sm.str(1), ' ');
          for (auto & it : pairs)
          {
            auto splited = util::split(it, ',');
            if (splited.size() != 2)
              util::myThrow(fmt::format("invalid '{}' must have 2 elements", it));

            // Only states declared on the Transitions line can receive a multiplier. Checked
            // outside the try block so this error is not re-wrapped by the catch below.
            auto multiplier = lossMultipliers.find(splited[0]);
            if (multiplier == lossMultipliers.end())
              util::myThrow(fmt::format("invalid '{}' unknown state '{}'", it, splited[0]));

            try
            {
              multiplier->second = std::stof(splited[1]);
            } catch (std::exception & e)
            {
              util::myThrow(fmt::format("caught '{}' in '{}'", e.what(), it));
            }
          }
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[1], "(LossMultiplier :) {state1,multiplier1 state2,multiplier2...}"));

  initNeuralNetwork(definition);
}

int Classifier::getNbParameters() const
{
  // Total number of scalar elements across all parameter tensors of the network.
  int total = 0;
  auto tensors = nn->parameters();
  for (const auto & tensor : tensors)
    total += torch::numel(tensor);
  return total;
}

// Returns the transition set associated with the current state (see setState()).
// Throws (through util::myThrow) when no transition set was declared for that state.
TransitionSet & Classifier::getTransitionSet()
{
  // Single lookup instead of count() followed by operator[].
  auto found = transitionSets.find(state);
  if (found == transitionSets.end())
    util::myThrow(fmt::format("cannot find transition set for state '{}'", state));

  return *found->second;
}

// Gives access to the classifier's network as a NeuralNetwork handle.
NeuralNetwork & Classifier::getNN()
{
  // NOTE(review): this reinterprets the `nn` member as a NeuralNetwork; it only works if both
  // types share the same layout — confirm against the declarations in Classifier.hpp.
  auto & network = reinterpret_cast<NeuralNetwork&>(nn);
  return network;
}

Franck Dary's avatar
Franck Dary committed
const std::string & Classifier::getName() const
{
  return name;
}

// Parses the network part of the definition (lines 0 and 1 are handled by the constructor):
//   "(Network type :) Random|Modular"  followed, for Modular, by module lines up to "End",
//   "(Optimizer :) <type> {<parameters>}"  stored for a later resetOptimizer().
// Throws (through util::myThrow) on a missing/malformed line or an unknown network type.
void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
{
  // Each state's output layer is sized by the number of transitions in its transition set.
  std::map<std::string,std::size_t> nbOutputsPerState;
  for (auto & it : this->transitionSets)
    nbOutputsPerState[it.first] = it.second->size();

  std::size_t curIndex = 2;

  std::string networkType;
  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Network type :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&networkType](auto sm)
        {
          networkType = sm.str(1);
          curIndex++;
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Network type :) networkType"));

  if (networkType == "Random")
    this->nn.reset(new RandomNetworkImpl(this->name, nbOutputsPerState));
  else if (networkType == "Modular")
    initModular(definition, curIndex, nbOutputsPerState); // advances curIndex past the "End" line
  else
    util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, Modular'", networkType));

  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Optimizer :|)(?:(?:\\s|\\t)*)(.*) \\{(.*)\\}"), definition[curIndex], [this](auto sm)
        {
          optimizerType = sm.str(1);
          optimizerParameters = sm.str(2);
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}"));
}

void Classifier::loadOptimizer(std::filesystem::path path)
{
  torch::load(*optimizer, path);
}

void Classifier::saveOptimizer(std::filesystem::path path)
{
  torch::save(*optimizer, path);
}

// Returns the classifier's optimizer.
// Throws (through util::myThrow) if it has not been created yet by resetOptimizer(), instead of
// dereferencing an empty pointer.
torch::optim::Adam & Classifier::getOptimizer()
{
  if (!optimizer)
    util::myThrow(fmt::format("optimizer of classifier '{}' is not initialized, call resetOptimizer() first", name));

  return *optimizer;
}

// Switches the active state: remembered locally (used by getTransitionSet / getLossMultiplier)
// and forwarded to the network.
void Classifier::setState(const std::string & state)
{
  this->state = state;
  nn->setState(this->state);
}

// Collects the module definition lines that start at `curIndex` and end with an "End" line,
// then builds a ModularNetworkImpl from them.
// On return `curIndex` points just past the "End" line.
// Throws (through util::myThrow) if the "End" line is missing, since the following definition
// lines (e.g. the optimizer) would otherwise be silently taken as module definitions.
void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState)
{
  std::string anyBlanks = "(?:(?:\\s|\\t)*)";
  std::regex endRegex(fmt::format("{}End{}",anyBlanks,anyBlanks));
  std::vector<std::string> modulesDefinitions;
  bool foundEnd = false;

  for (; curIndex < definition.size(); curIndex++)
  {
    if (util::doIfNameMatch(endRegex,definition[curIndex],[](auto){}))
    {
      curIndex++;
      foundEnd = true;
      break;
    }
    modulesDefinitions.emplace_back(definition[curIndex]);
  }

  if (!foundEnd)
    util::myThrow("Modular network definition is missing its closing 'End' line");

  this->nn.reset(new ModularNetworkImpl(this->name, nbOutputsPerState, modulesDefinitions));
}

void Classifier::resetOptimizer()
{
  std::string expected = "expected '(Optimizer :) Adam {lr beta1 beta2 eps decay amsgrad}'";
  if (optimizerType == "Adam")
  {
    auto splited = util::split(optimizerParameters, ' ');
    if (splited.size() != 6 or (splited.back() != "false" and splited.back() != "true"))
      util::myThrow(expected);
 
    optimizer.reset(new torch::optim::Adam(getNN()->parameters(), torch::optim::AdamOptions(std::stof(splited[0])).amsgrad(splited.back() == "true").betas({std::stof(splited[1]),std::stof(splited[2])}).eps(std::stof(splited[3])).weight_decay(std::stof(splited[4]))));
  }
  else
    util::myThrow(expected);
}

Franck Dary's avatar
Franck Dary committed
float Classifier::getLossMultiplier()
{
  return lossMultipliers.at(state);
}