// Classifier.cpp
#include "Classifier.hpp"
#include "util.hpp"
#include "RandomNetwork.hpp"
#include "ModularNetwork.hpp"

// Builds a classifier from its textual definition.
//
// name       : name of the classifier, also used to build the model filenames.
// path       : directory against which .ts files, dictionaries and model files are resolved.
// definition : lines describing the classifier :
//              [0]  "(Transitions :) {state1,state2,file.ts ...}"
//              [1]  "(LossMultiplier :) {state1,multiplier1 ...}"
//              [2+] network type and optimizer, parsed by initNeuralNetwork.
// train      : if false, the best saved model is loaded. If true, the last saved
//              model (and its optimizer) is loaded only when that file exists.
//
// Any malformed line is reported through util::myThrow.
Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train) : path(path)
{
  this->name = name;

  // definition[0] and definition[1] are indexed directly below, so refuse
  // truncated definitions explicitly instead of reading out of bounds.
  if (definition.size() < 2)
    util::myThrow(fmt::format("Invalid definition for classifier '{}' : got {} line(s), expected at least '{}' and '{}'", name, definition.size(), "(Transitions :) {tsFile1.ts tsFile2.ts...}", "(LossMultiplier :) {state1,multiplier1 state2,multiplier2...}"));

  if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Transitions :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[0], [this,&path](auto sm)
        {
          // Each space separated group is a comma separated mix of state names
          // and transition set files.
          auto splited = util::split(sm.str(1), ' ');

          for (auto & ss : splited)
          {
            std::vector<std::string> tsFiles;
            std::vector<std::string> curStates;
            // An element without file extension is a state name, anything else
            // is a file resolved relatively to path.
            for (auto & elem : util::split(ss, ','))
              if (std::filesystem::path(elem).extension().empty())
              {
                states.emplace_back(elem);
                curStates.emplace_back(elem);
              }
              else
                tsFiles.emplace_back(path / elem);
            if (tsFiles.empty())
              util::myThrow(fmt::format("invalid '{}' no .ts files specified", ss));
            if (curStates.empty())
              util::myThrow(fmt::format("invalid '{}' no states specified", ss));

            // Every state of the group gets its own TransitionSet built from
            // the files of the group.
            for (auto & stateName : curStates)
            {
              if (transitionSets.count(stateName))
                util::myThrow(fmt::format("state '{}' already assigned", stateName));

              this->transitionSets.emplace(stateName, new TransitionSet(tsFiles));
            }
          }

        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[0], "(Transitions :) {tsFile1.ts tsFile2.ts...}"));

  // Default loss multiplier of every known state, possibly overridden below.
  for (auto & it : this->transitionSets)
    lossMultipliers[it.first] = 1.0;

  if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LossMultiplier :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[1], [this](auto sm)
        {
          auto pairs = util::split(sm.str(1), ' ');
          for (auto & it : pairs)
          {
            auto splited = util::split(it, ',');
            if (splited.size() != 2)
              util::myThrow(fmt::format("invalid '{}' must have 2 elements", it));
            try
            {
              // at() throws on a state that has no transition set, stof on a
              // multiplier that is not a number : both are reported the same way.
              lossMultipliers.at(splited[0]) = std::stof(splited[1]);
            } catch (const std::exception & e)
            {
              util::myThrow(fmt::format("caugh '{}' in '{}'", e.what(), it));
            }
          }
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[1], "(LossMultiplier :) {state1,multiplier1 state2,multiplier2...}"));

  initNeuralNetwork(definition);

  getNN()->loadDicts(path);
  getNN()->registerEmbeddings();

  getNN()->to(NeuralNetworkImpl::device);
  if (!train)
    torch::load(getNN(), getBestFilename());
  else if (std::filesystem::exists(getLastFilename()))
  {
    // Resume training : restore the last weights, then rebuild the optimizer
    // and restore its saved state if there is one.
    torch::load(getNN(), getLastFilename());
    resetOptimizer();
    loadOptimizer();
  }
}

// Returns the total number of scalar elements over all parameter tensors of the network.
int Classifier::getNbParameters() const
{
  int total = 0;

  for (auto & parameter : nn->parameters())
  {
    total += torch::numel(parameter);
  }

  return total;
}

// Returns the transition set associated with the current state.
// Reports an error through util::myThrow when the current state has none.
TransitionSet & Classifier::getTransitionSet()
{
  const auto found = transitionSets.find(state);

  if (found == transitionSets.end())
    util::myThrow(fmt::format("cannot find transition set for state '{}'", state));

  return *found->second;
}

// Gives access to the network through the NeuralNetwork type.
// NOTE(review): this reinterprets the `nn` member itself (not the object it
// points to) as a NeuralNetwork. That is only valid if NeuralNetwork has the
// same layout as the declared type of `nn` — presumably NeuralNetwork is a
// module holder wrapping the same kind of pointer; confirm against the
// declarations in Classifier.hpp.
NeuralNetwork & Classifier::getNN()
{
  return reinterpret_cast<NeuralNetwork&>(nn);
}

const std::string & Classifier::getName() const
{
  return name;
}

void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
{
  std::map<std::string,std::size_t> nbOutputsPerState;
  for (auto & it : this->transitionSets)
    nbOutputsPerState[it.first] = it.second->size();

  std::size_t curIndex = 2;

  std::string networkType;
  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Network type :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&networkType](auto sm)
        {
          networkType = sm.str(1);
          curIndex++;
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Network type :) networkType"));

  if (networkType == "Random")
    this->nn.reset(new RandomNetworkImpl(this->name, nbOutputsPerState));
  else if (networkType == "Modular")
    initModular(definition, curIndex, nbOutputsPerState);
  else
    util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, Modular'", networkType));

  if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Optimizer :|)(?:(?:\\s|\\t)*)(.*) \\{(.*)\\}"), definition[curIndex], [&curIndex,this](auto sm)
        {
          optimizerType = sm.str(1);
          optimizerParameters = sm.str(2);
        }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) " + util::join("|", knownOptimizers)));
}

void Classifier::loadOptimizer()
{
  auto optimizerPath = std::filesystem::path(fmt::format("{}/{}_optimizer.pt", path.string(), name));
  if (std::filesystem::exists(optimizerPath))
    torch::load(*optimizer, optimizerPath);
}

// Writes the optimizer state to '<path>/<name>_optimizer.pt'.
void Classifier::saveOptimizer()
{
  const auto optimizerFile = fmt::format("{}/{}_optimizer.pt", path.string(), name);
  torch::save(*optimizer, optimizerFile);
}

// Returns the optimizer currently held by this classifier.
// The `optimizer` member is dereferenced without any check.
torch::optim::Optimizer & Classifier::getOptimizer()
{
  return *(this->optimizer);
}

// Records the new current state and forwards it to the network.
void Classifier::setState(const std::string & state)
{
  // The parameter shadows the member, hence the explicit this->.
  this->state = state;
  this->nn->setState(state);
}

// Collects the module definition lines starting at curIndex, up to (and
// excluding) the first 'End' line or the end of the definition, then builds
// the modular network from them.
// On return curIndex designates the line following 'End' (or definition.size()
// when no 'End' line was found).
void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState)
{
  std::string anyBlanks = "(?:(?:\\s|\\t)*)";
  std::regex endRegex(fmt::format("{}End{}",anyBlanks,anyBlanks));
  std::vector<std::string> modulesDefinitions;

  while (curIndex < definition.size())
  {
    const auto & line = definition[curIndex];
    const bool isEnd = util::doIfNameMatch(endRegex, line, [](auto){});

    // The current line is consumed whether it is the terminator or a module.
    ++curIndex;

    if (isEnd)
      break;

    modulesDefinitions.emplace_back(line);
  }

  this->nn.reset(new ModularNetworkImpl(this->name, nbOutputsPerState, modulesDefinitions, path));
}

void Classifier::resetOptimizer()
{
  std::string expected = "expected '(Optimizer :) " + util::join("| ", knownOptimizers);
  if (optimizerType == "Adam")
  {
    auto splited = util::split(optimizerParameters, ' ');
    if (splited.size() != 6 or (splited.back() != "false" and splited.back() != "true"))
      util::myThrow(expected);
 
    optimizer.reset(new torch::optim::Adam(getNN()->parameters(), torch::optim::AdamOptions(std::stof(splited[0])).amsgrad(splited.back() == "true").betas({std::stof(splited[1]),std::stof(splited[2])}).eps(std::stof(splited[3])).weight_decay(std::stof(splited[4]))));
  }
  else if (optimizerType == "Adagrad")
  {
    auto splited = util::split(optimizerParameters, ' ');
    if (splited.size() != 4)
      util::myThrow(expected);
 
    optimizer.reset(new torch::optim::Adagrad(getNN()->parameters(), torch::optim::AdagradOptions(std::stof(splited[0])).lr_decay(std::stof(splited[1])).eps(std::stof(splited[3])).weight_decay(std::stof(splited[2]))));
  }
  else
    util::myThrow(expected);
}

float Classifier::getLossMultiplier()
{
  return lossMultipliers.at(state);
}

const std::vector<std::string> & Classifier::getStates() const
{
  return states;
}

// Asks the network to save its dictionaries, passing it the classifier path.
void Classifier::saveDicts()
{
  auto & network = getNN();
  network->saveDicts(path);
}

// Returns '<path>/<name>_best.pt', the file holding the best saved model.
std::string Classifier::getBestFilename() const
{
  const std::string directory = path.string();
  return fmt::format("{}/{}_best.pt", directory, name);
}

// Returns '<path>/<name>_last.pt', the file holding the last saved model.
std::string Classifier::getLastFilename() const
{
  const std::string directory = path.string();
  return fmt::format("{}/{}_last.pt", directory, name);
}

// Saves the network to the 'best' model file.
// The network is moved to the CPU for the save, then moved back to NeuralNetworkImpl::device.
void Classifier::saveBest()
{
  auto & network = getNN();

  network->to(torch::kCPU);
  torch::save(network, getBestFilename());
  network->to(NeuralNetworkImpl::device);
}

// Saves the network to the 'last' model file, then saves the optimizer state.
// The network is moved to the CPU for the save, then moved back to NeuralNetworkImpl::device.
void Classifier::saveLast()
{
  auto & network = getNN();

  network->to(torch::kCPU);
  torch::save(network, getLastFilename());
  network->to(NeuralNetworkImpl::device);

  saveOptimizer();
}