Skip to content
Snippets Groups Projects
Select Git revision
  • 3a0f7d4f7b7376504ebb87f6277540062b33f56b
  • master default protected
  • fullUD
  • movementInAction
4 results

TrainInfos.hpp

Blame
  • TrainInfos.hpp 1.72 KiB
    /// @file TrainInfos.hpp
    /// @author Franck Dary
    /// @version 1.0
    /// @date 2018-12-20
    
    #ifndef TRAININFOS__H
    #define TRAININFOS__H
    
    #include <cstdio>
    #include <map>
    #include <string>
    #include <vector>
    #include "ProgramParameters.hpp"
    #include "Config.hpp"
    
    /// @brief Bookkeeping for training: per-classifier losses, scores and
    /// save decisions, tracked epoch by epoch.
    ///
    /// Per-classifier containers are keyed by the classifier name string
    /// (the `classifier` argument of the methods). Method bodies live in the
    /// implementation file; the method descriptions below are inferred from
    /// their declarations and names, so confirm them against the definitions.
    class TrainInfos
    {
      private :

      /// Presumably the file read by readFromFilename() and written by
      /// saveToFilename() — TODO confirm.
      std::string filename;
      /// Last epoch reached. Default-initialized so it is never read
      /// indeterminate. NOTE(review): confirm the intended starting value
      /// against the constructor.
      int lastEpoch = 0;
      /// Presumably the epoch of the most recent save — TODO confirm.
      /// Default-initialized for the same reason as lastEpoch.
      int lastSaved = 0;
      /// Training losses: classifier name -> history (assumed one entry per epoch).
      std::map< std::string, std::vector<float> > trainLossesPerClassifierPerEpoch;
      /// Dev losses: classifier name -> history (assumed one entry per epoch).
      std::map< std::string, std::vector<float> > devLossesPerClassifierPerEpoch;
      /// Training scores: classifier name -> history (assumed one entry per epoch).
      std::map< std::string, std::vector<float> > trainScoresPerClassifierPerEpoch;
      /// Dev scores: classifier name -> history (assumed one entry per epoch).
      std::map< std::string, std::vector<float> > devScoresPerClassifierPerEpoch;
      /// "Must save" decisions: classifier name -> history, presumably filled
      /// by computeMustSaves().
      /// NOTE(review): std::vector<bool> is a packed proxy container; it is
      /// kept as-is because the implementation file depends on this type.
      std::map< std::string, std::vector<bool> > mustSavePerClassifierPerEpoch;

      /// Running training-loss total per classifier, presumably fed by
      /// addTrainLoss() during the current epoch.
      std::map<std::string, float> trainLossCounter;
      /// Running dev-loss total per classifier, presumably fed by
      /// addDevLoss() during the current epoch.
      std::map<std::string, float> devLossCounter;

      /// Per-classifier flag, presumably the storage behind
      /// isTopologyPrinted() / setTopologyPrinted().
      std::map<std::string, bool> topologyPrinted;

      /// Loads saved training information (presumably from `filename`).
      void readFromFilename();
      /// Persists training information (presumably to `filename`).
      void saveToFilename();
      /// Appends a training score for `classifier`.
      void addTrainScore(const std::string & classifier, float score);
      /// Appends a dev score for `classifier`.
      void addDevScore(const std::string & classifier, float score);
      /// Computes a score from the given `tapes` of `c` between indices
      /// `from` and `to` (bound inclusivity not visible here).
      /// NOTE(review): `tapes` is taken by value; switching to a const
      /// reference requires changing the definition as well.
      float computeScoreOnTapes(Config & c, std::vector<std::string> tapes, int from, int to);

      public :

      /// Per-classifier flag exposed directly to callers; its meaning is set
      /// by code outside this header.
      std::map<std::string, bool> lastActionWasPredicted;

      TrainInfos();
      /// Records a training loss value for `classifier`.
      void addTrainLoss(const std::string & classifier, float loss);
      /// Records a dev loss value for `classifier`.
      void addDevLoss(const std::string & classifier, float loss);
      /// Computes training scores from configuration `c`.
      void computeTrainScores(Config & c);
      /// Computes dev scores from configuration `c`.
      void computeDevScores(Config & c);
      /// Decides, per classifier, whether the model must be saved.
      void computeMustSaves();
      /// Returns the current epoch number.
      int getEpoch();
      /// Whether the topology of `classifier` has already been printed.
      bool isTopologyPrinted(const std::string & classifier);
      /// Marks the topology of `classifier` as printed.
      void setTopologyPrinted(const std::string & classifier);
      /// Advances bookkeeping to the next epoch.
      void nextEpoch();
      /// Whether `classifier` must be saved (see computeMustSaves()).
      bool mustSave(const std::string & classifier);
      /// Writes the recorded scores to `output`.
      void printScores(FILE * output);
      /// Records the index of the last treated element.
      void setLastIndexTreated(int index);
    };
    
    #endif