// Classifier.hpp
// Author: Franck Dary
#ifndef CLASSIFIER__H
#define CLASSIFIER__H
#include <filesystem>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "TransitionSet.hpp"
#include "NeuralNetwork.hpp"
/// Bundles a named neural network, its optimizer, and a set of named
/// TransitionSets, together with a current state string.
///
/// NOTE(review): this header only declares the interface. Descriptions
/// marked "presumably" are inferred from names and must be confirmed
/// against the out-of-line definitions.
class Classifier
{
  private :

  /// Optimizer specifications recognised by this class, one entry per
  /// optimizer type, formatted as "<Type> {<parameter names>}".
  /// This is a constant table. It is a single shared, immutable object
  /// rather than a mutable copy inside every Classifier instance.
  /// NOTE(review): presumably used when parsing/validating the optimizer
  /// definition or when reporting an unknown optimizer — confirm.
  static inline const std::vector<std::string> knownOptimizers{
    "Adam {lr beta1 beta2 eps decay amsgrad}",
    "Adagrad {lr lr_decay weight_decay eps}",
  };

  // Non-static data members. The declaration order is kept exactly as
  // before, so member initialization order is unchanged.

  /// Name of this classifier (returned by getName()).
  std::string name;
  /// TransitionSets keyed by string (presumably by state name — TODO confirm).
  std::map<std::string, std::unique_ptr<TransitionSet>> transitionSets;
  /// Loss multipliers keyed by string (presumably keyed like transitionSets — TODO confirm).
  std::map<std::string, float> lossMultipliers;
  /// The underlying network implementation (shared ownership).
  std::shared_ptr<NeuralNetworkImpl> nn;
  /// Optimizer exclusively owned by this classifier.
  std::unique_ptr<torch::optim::Optimizer> optimizer;
  /// Optimizer type name (presumably one of the types in knownOptimizers).
  std::string optimizerType;
  /// Raw optimizer parameter string (presumably parsed by resetOptimizer() — TODO confirm).
  std::string optimizerParameters;
  /// Current state (presumably updated by setState() — TODO confirm).
  std::string state;

  /// Builds the network from its textual definition (presumably fills `nn`).
  void initNeuralNetwork(const std::vector<std::string> & definition);
  /// Builds a modular network from `definition`, starting at `curIndex`
  /// (passed by reference, so presumably advanced past the consumed lines).
  /// `nbOutputsPerState` gives an output count per state name.
  void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);

  public :

  /// Creates a classifier called `name` from its textual `definition`.
  /// NOTE(review): the role of `path` (model directory?) is not visible here — confirm.
  Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition);

  /// Returns a TransitionSet held by this classifier
  /// (presumably the one for the current state — TODO confirm).
  TransitionSet & getTransitionSet();
  /// Returns the neural network.
  NeuralNetwork & getNN();
  /// Returns the classifier's name.
  const std::string & getName() const;
  /// Returns a parameter count (presumably of the network). The result is a
  /// plain `int`, so very large models could exceed its range — TODO confirm.
  int getNbParameters() const;

  /// Presumably (re)creates `optimizer` from optimizerType/optimizerParameters.
  void resetOptimizer();
  /// Loads optimizer state from `path`.
  void loadOptimizer(std::filesystem::path path);
  /// Saves optimizer state to `path`.
  void saveOptimizer(std::filesystem::path path);
  /// Returns the optimizer (presumably requires a prior resetOptimizer() — TODO confirm).
  torch::optim::Optimizer & getOptimizer();

  /// Selects the current state.
  void setState(const std::string & state);
  /// Returns a loss multiplier (presumably for the current state — TODO confirm).
  float getLossMultiplier();
};
#endif