// Trainer.hpp
#ifndef TRAINER__H
#define TRAINER__H

#include "ReadingMachine.hpp"
#include "ConfigDataset.hpp"
#include "SubConfig.hpp"

class Trainer
{
  public :

  enum TrainAction
  {
    ExtractGold,
    ExtractDynamic,
    DeleteExamples,
    ResetOptimizer,
    ResetParameters,
    Save
  };
  using TrainStrategy = std::map<std::size_t, std::set<TrainAction>>;
  static TrainAction str2TrainAction(const std::string & s);

  private :

  static constexpr std::size_t safetyNbExamplesMax = 10*1000*1000;

  struct Examples
  {
    std::vector<torch::Tensor> contexts;
    std::vector<torch::Tensor> classes;

    int currentExampleIndex{0};
    int lastSavedIndex{0};

    void saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int currentEpoch, bool dynamicOracle);
    void addContext(std::vector<std::vector<long>> & context);
    void addClass(const LossFunction & lossFct, int nbClasses, const std::vector<long> & goldIndexes);
  };

  private :

  using Dataset = ConfigDataset;
  using DataLoader = std::unique_ptr<torch::data::StatefulDataLoader<Dataset>>;

  private :

  ReadingMachine & machine;
  std::unique_ptr<Dataset> trainDataset{nullptr};
  std::unique_ptr<Dataset> devDataset{nullptr};
  DataLoader dataLoader{nullptr};
  DataLoader devDataLoader{nullptr};
  std::size_t epochNumber{0};
  int batchSize;

  private :

  void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold);
  float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples);

  public :

  Trainer(ReadingMachine & machine, int batchSize);
  void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold);
  void extractActionSequence(BaseConfig & config);
  void makeDataLoader(std::filesystem::path dir);
  void makeDevDataLoader(std::filesystem::path dir);
  float epoch(bool printAdvancement);
  float evalOnDev(bool printAdvancement);
};