Skip to content
Snippets Groups Projects
Trainer.hpp 1017 B
Newer Older
  • Learn to ignore specific revisions
  • Franck Dary's avatar
    Franck Dary committed
    #ifndef TRAINER__H
    #define TRAINER__H
    
    #include "ReadingMachine.hpp"
    #include "ConfigDataset.hpp"
    #include "SubConfig.hpp"
    
    class Trainer
    {
      private :
    
    
    Franck Dary's avatar
    Franck Dary committed
      using Dataset = ConfigDataset;
      using DataLoader = std::unique_ptr<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler>, std::default_delete<torch::data::StatelessDataLoader<torch::data::datasets::MapDataset<Dataset, torch::data::transforms::Stack<torch::data::Example<> > >, torch::data::samplers::RandomSampler> > >;
    
      private :
    
    
    Franck Dary's avatar
    Franck Dary committed
      ReadingMachine & machine;
    
    Franck Dary's avatar
    Franck Dary committed
      DataLoader dataLoader{nullptr};
    
    Franck Dary's avatar
    Franck Dary committed
      std::unique_ptr<torch::optim::Adam> denseOptimizer;
      std::unique_ptr<torch::optim::SparseAdam> sparseOptimizer;
    
    Franck Dary's avatar
    Franck Dary committed
      std::size_t epochNumber{0};
      int batchSize{100};
      int nbExamples{0};
    
    Franck Dary's avatar
    Franck Dary committed
    
      public :
    
      Trainer(ReadingMachine & machine);
    
    Franck Dary's avatar
    Franck Dary committed
      void createDataset(SubConfig & goldConfig, bool debug);
      float epoch(bool printAdvancement);
    
    Franck Dary's avatar
    Franck Dary committed
    
    
    Franck Dary's avatar
    Franck Dary committed
    };
    
    #endif