"""Train a ``Model`` with PyTorch Lightning, then test the best checkpoint.

Metric logging and checkpoint saving are delegated to the project's ``Logger``
(it is handed ``trainer.save_checkpoint``), and after training the checkpoint
it reports as best is reloaded and run through the test loop.
"""
from argparse import ArgumentParser
import json
import os
import shlex
import sys
import warnings

import pytorch_lightning

# Silence two noisy Lightning messages. Installed before the project imports
# below, so the filters are already active for anything those imports trigger.
warnings.filterwarnings('ignore', message='Displayed epoch numbers in the progress bar start from.*')
warnings.filterwarnings('ignore', message='.*does not have many workers which may be a bottleneck.*')

from model import Model
from logger import Logger

# --loss value -> (metric name the Logger checkpoints on, aggregator that picks
# the best value of that metric: higher F-score is better, lower BCE is better).
_CHECKPOINT_CRITERIA = {
    'f1': ('fscore', max),
    'bce': ('bce', min),
}


def _make_logger(hparams):
    """Create the experiment ``Logger`` whose checkpoint criterion matches ``hparams.loss``.

    Raises:
        ValueError: if ``hparams.loss`` is not a key of ``_CHECKPOINT_CRITERIA``.
    """
    try:
        metric, aggregator = _CHECKPOINT_CRITERIA[hparams.loss]
    except KeyError:
        raise ValueError('invalid loss "%s"' % hparams.loss) from None
    return Logger(hparams.name, checkpoint_metric=metric, metric_aggregator=aggregator)


def main(hparams):
    """Seed, train ``Model``, then reload the best checkpoint and run the test loop.

    Args:
        hparams: parsed ``argparse.Namespace``. This function reads ``seed``,
            ``loss``, ``name``, ``epochs``, ``gpus``, ``nodes`` and
            ``fast_dev_run``; the whole namespace is also passed to ``Model``
            and to ``Logger.log_hparams``. (``loss`` and ``epochs`` are not
            defined by this script's parser — presumably they come from
            ``Model.add_model_specific_args``; confirm.)

    Raises:
        ValueError: if ``hparams.loss`` is neither ``'f1'`` nor ``'bce'``.
    """
    pytorch_lightning.seed_everything(hparams.seed)
    logger = _make_logger(hparams)

    model = Model(hparams)
    model.custom_logger = logger
    logger.log_hparams(hparams)

    trainer = pytorch_lightning.Trainer(
        # max_epochs / num_nodes replace the deprecated max_nb_epochs /
        # nb_gpu_nodes (same naming scheme as num_sanity_val_steps below).
        max_epochs=hparams.epochs,
        gpus=hparams.gpus,
        num_nodes=hparams.nodes,
        check_val_every_n_epoch=1,
        progress_bar_refresh_rate=1,
        # Lightning's own logger and checkpoint callback are passed as None;
        # the custom Logger is handed trainer.save_checkpoint instead.
        logger=None,
        checkpoint_callback=None,
        num_sanity_val_steps=0,
        fast_dev_run=hparams.fast_dev_run,
        deterministic=True,
    )
    logger.set_save_function(trainer.save_checkpoint)
    trainer.fit(model)

    # Release this function's reference to the trained model before loading
    # the best checkpoint into a fresh instance.
    del model
    # NOTE(review): assumes the Logger recorded a best checkpoint during fit;
    # with --fast_dev_run this may not hold — confirm.
    best_model = Model.load_from_checkpoint(logger.best_checkpoint)
    best_model.custom_logger = logger
    trainer.test(best_model)


def _build_arg_parser():
    """Return the CLI parser: script-level options plus ``Model``'s own options."""
    # NOTE(review): add_help=False means this parser defines no -h/--help;
    # presumably Model.add_model_specific_args builds the final parser — confirm
    # that --help still works.
    parser = ArgumentParser(add_help=False)
    parser.add_argument('--gpus', type=str, default=None,
                        help='ids of GPUs to use (use -1 for all available GPUs, defaults to CPU)')
    parser.add_argument('--nodes', type=int, default=1,
                        help='number of computation nodes for distributed training (see lightning docs, defaults to 1)')
    parser.add_argument('--name', type=str, required=True,
                        help='name of experiment (required)')
    parser.add_argument('--fast_dev_run', default=False, action='store_true',
                        help='run a single batch through the whole training loop for catching bugs')
    parser.add_argument('--seed', default=123, type=int,
                        help='set global random seed (defaults to 123)')
    return Model.add_model_specific_args(parser)


def _command_line():
    """Return this invocation as a shell-quoted string so it can be re-run verbatim.

    Arguments made only of shell-safe characters are emitted unchanged; others
    (e.g. containing spaces) are quoted.
    """
    return 'python ' + ' '.join(shlex.quote(arg) for arg in sys.argv)


if __name__ == '__main__':
    hparams = _build_arg_parser().parse_args()
    hparams.cmd = _command_line()
    main(hparams)