diff --git a/code/train.py b/code/train.py index 6daa502ac3dc41ec036e2642c2a658965221475c..475843c9dbd005935fe7e61d77b7467a781af1fc 100644 --- a/code/train.py +++ b/code/train.py @@ -30,6 +30,7 @@ if __name__ == "__main__": DEFAULT_TEST_SIZE = 0.2 DEFAULT_RANDOM_SEED_NUMBER = 1 DEFAULT_TRAIN_ON_SUBSET = 'train' + DEFAULT_DISABLE_PROGRESS = False begin_random_seed_range = 1 end_random_seed_range = 2000 @@ -46,6 +47,7 @@ if __name__ == "__main__": parser.add_argument('--random_seed_number', nargs='?', type=int, default=DEFAULT_RANDOM_SEED_NUMBER, help='Number of random seeds used.') parser.add_argument('--seeds', nargs='+', type=int, default=None, help='Specific a list of seeds instead of generate them randomly') parser.add_argument('--train_on_subset', nargs='?', type=str, default=DEFAULT_TRAIN_ON_SUBSET, help='Specify on witch subset the model will be trained (either train or dev).') + parser.add_argument('--disable_progress', action='store_true', default=DEFAULT_DISABLE_PROGRESS, help='Disable the progress bars.') args = parser.parse_args() pathlib.Path(args.models_dir).mkdir(parents=True, exist_ok=True) @@ -68,7 +70,7 @@ if __name__ == "__main__": logger.info('Experiment id: {}'.format(experiment_id_str)) - with tqdm(seeds) as seed_bar: + with tqdm(seeds, disable=args.disable_progress) as seed_bar: for seed in seed_bar: seed_bar.set_description('seed={}'.format(seed)) seed_str = str(seed) @@ -90,7 +92,7 @@ if __name__ == "__main__": trainer = Trainer(dataset) - with tqdm(args.extracted_forest_size) as extracted_forest_size_bar: + with tqdm(args.extracted_forest_size, disable=args.disable_progress) as extracted_forest_size_bar: for extracted_forest_size in extracted_forest_size_bar: extracted_forest_size_bar.set_description('extracted_forest_size={}'.format(extracted_forest_size)) sub_models_dir = models_dir + os.sep + 'extracted_forest_size' + os.sep + str(extracted_forest_size)