Skip to content
Snippets Groups Projects
Select Git revision
  • cdc9ed544399c72b0ee5f84ee4817fd7de415a4f
  • master default protected
  • loss
  • producer
4 results

Trainer.cpp

Blame
  • Trainer.cpp 15.17 KiB
    #include "Trainer.hpp"
    #include "SubConfig.hpp"
    
    // Constructs a Trainer.
    //
    // machine   : used to initialise the `machine` member.
    // batchSize : used to initialise the `batchSize` member, which is later
    //             handed to torch::data::DataLoaderOptions when the data
    //             loaders are built.
    Trainer::Trainer(ReadingMachine & machine, int batchSize)
      : machine(machine),
        batchSize(batchSize)
    {
    }
    
    void Trainer::makeDataLoader(std::filesystem::path dir)
    {
      trainDataset.reset(new Dataset(dir));
      dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
    }
    
    void Trainer::makeDevDataLoader(std::filesystem::path dir)
    {
      devDataset.reset(new Dataset(dir));
      devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
    }
    
    void Trainer::createDataset(std::vector<BaseConfig> & goldConfigs, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold)
    {
      std::vector<SubConfig> configs;
      for (auto & goldConfig : goldConfigs)
        configs.emplace_back(goldConfig, goldConfig.getNbLines());
    
      machine.trainMode(false);
    
      extractExamples(configs, debug, dir, epoch, dynamicOracle, explorationThreshold);
    
      machine.saveDicts();
    }
    
    void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold)
    {
      torch::AutoGradMode useGrad(false);
    
      int maxNbExamplesPerFile = 50000;
      std::map<std::string, Examples> examplesPerState;
    
      std::filesystem::create_directories(dir);
    
      auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}.{}", epoch, dynamicOracle);
    
      if (std::filesystem::exists(currentEpochAllExtractedFile))
        return;
    
      fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");
    
      int totalNbExamples = 0;
    
      for (auto & config : configs)
      {
        config.addPredicted(machine.getPredicted());
        config.setStrategy(machine.getStrategyDefinition());
        config.setState(config.getStrategy().getInitialState());
        machine.getClassifier(config.getState())->setState(config.getState());
    
        while (true)
        {
          if (debug)
            config.printForDebug(stderr);
    
          if (machine.hasSplitWordTransitionSet())
            config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
    
          auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config);
          config.setAppliableTransitions(appliableTransitions);
    
          std::vector<std::vector<long>> context;