Skip to content
Snippets Groups Projects
Classifier.hpp 1.31 KiB
#ifndef CLASSIFIER__H
#define CLASSIFIER__H

#include <filesystem>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "TransitionSet.hpp"
#include "NeuralNetwork.hpp"

/// A named classifier that bundles a neural network, its optimizer, a set of
/// transition sets and per-key loss multipliers.
///
/// NOTE(review): `transitionSets` and `lossMultipliers` are both keyed by
/// std::string and there is a `state` string selected via setState(); this
/// suggests one entry per state, but the lookup logic lives in the .cpp —
/// confirm there before relying on it.
class Classifier
{
  private :

  // Data members are initialized in declaration order, regardless of the
  // order used in the constructor's initializer list (in the .cpp). Do not
  // reorder them without checking that constructor.

  /// Textual specs of the optimizers known to this class: the optimizer name
  /// followed by its parameter names in braces.
  /// Presumably used to validate/parse `optimizerType`/`optimizerParameters`
  /// or to build error messages — TODO confirm in the .cpp.
  std::vector<std::string> knownOptimizers{
    "Adam {lr beta1 beta2 eps decay amsgrad}",
    "Adagrad {lr lr_decay weight_decay eps}",
  };

  /// Classifier name, exposed through getName().
  std::string name;
  /// Owned transition sets, keyed by string (presumably a state name — confirm).
  std::map<std::string, std::unique_ptr<TransitionSet>> transitionSets;
  /// Loss multiplier per key (presumably per state — confirm).
  std::map<std::string, float> lossMultipliers;
  /// The network implementation; getNN() exposes it as a `NeuralNetwork &`.
  std::shared_ptr<NeuralNetworkImpl> nn;
  /// Exclusively owned optimizer; see resetOptimizer()/loadOptimizer()/getOptimizer().
  std::unique_ptr<torch::optim::Optimizer> optimizer;
  /// Optimizer name (presumably one of the names listed in knownOptimizers — confirm).
  std::string optimizerType;
  /// Raw optimizer parameter text (format defined by the .cpp — confirm).
  std::string optimizerParameters;
  /// Current state, set by setState().
  std::string state;

  /// Presumably builds `nn` from the textual network definition.
  /// @param definition definition lines; exact format is parsed in the .cpp.
  void initNeuralNetwork(const std::vector<std::string> & definition);
  /// Presumably initializes a modular network from `definition`, starting at
  /// line `curIndex`.
  /// @param definition definition lines.
  /// @param curIndex   in/out index into `definition` (non-const reference, so
  ///                   presumably advanced past consumed lines — confirm).
  /// @param nbOutputsPerState number of outputs keyed by string (presumably by state).
  void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);

  public :

  /// @param name       classifier name, returned by getName().
  /// @param path       filesystem path (presumably the model directory — confirm).
  /// @param definition textual definition lines (format defined by the .cpp).
  Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition);
  /// Returns a transition set; presumably the one for the current `state` — confirm.
  TransitionSet & getTransitionSet();
  /// Returns the underlying network.
  NeuralNetwork & getNN();
  /// Returns the classifier name.
  const std::string & getName() const;
  /// Returns a parameter count of the network (presumably trainable parameters — confirm).
  int getNbParameters() const;
  /// Presumably (re)creates `optimizer` from optimizerType/optimizerParameters — confirm.
  void resetOptimizer();
  /// Loads optimizer state from `path`.
  void loadOptimizer(std::filesystem::path path);
  /// Saves optimizer state to `path`.
  void saveOptimizer(std::filesystem::path path);
  /// Returns the owned optimizer; presumably requires it to have been created first.
  torch::optim::Optimizer & getOptimizer();
  /// Sets the current state string.
  void setState(const std::string & state);
  /// Returns a loss multiplier; presumably the one for the current `state` — confirm.
  /// NOTE(review): not `const`; making it const requires changing the .cpp
  /// definition at the same time.
  float getLossMultiplier();
};

#endif