diff --git a/code/train.py b/code/train.py index 34d2f860aa46088e38152ae7efb04efb8580bd1c..a410e53bb4ccefc90f5f976617da37e8b709cb40 100644 --- a/code/train.py +++ b/code/train.py @@ -12,7 +12,7 @@ import argparse import pathlib import random import os -import errno +from tqdm import tqdm if __name__ == "__main__": @@ -67,51 +67,42 @@ if __name__ == "__main__": experiment_id = resolve_experiment_id(args.models_dir) experiment_id_str = str(experiment_id) - for seed in seeds: - logger.debug('Seed={}'.format(seed)) - seed_str = str(seed) - models_dir = args.models_dir + os.sep + experiment_id_str + os.sep + 'seeds' + \ - os.sep + seed_str - try: - os.makedirs(models_dir) - except OSError as e: - if e.errno != errno.EEXIST: - raise - - dataset_parameters = DatasetParameters( - name=args.dataset_name, - test_size=args.test_size, - dev_size=args.dev_size, - random_state=seed, - normalize=normalize, - train_on_subset=args.train_on_subset - ) - dataset_parameters.save(models_dir, experiment_id_str) - - dataset = DatasetLoader.load(dataset_parameters) - - trainer = Trainer(dataset) - - for extracted_forest_size in args.extracted_forest_size: - logger.debug('extracted_forest_size={}'.format(extracted_forest_size)) - sub_models_dir = models_dir + os.sep + 'extracted_forest_size' + os.sep + str(extracted_forest_size) - try: - os.makedirs(sub_models_dir) - except OSError as e: - if e.errno != errno.EEXIST: - raise - - model_parameters = ModelParameters( - forest_size=args.forest_size, - extracted_forest_size=extracted_forest_size, - seed=seed, - normalize=normalize + with tqdm(seeds) as seed_bar: + for seed in seed_bar: + seed_bar.set_description('seed={}'.format(seed)) + seed_str = str(seed) + models_dir = args.models_dir + os.sep + experiment_id_str + os.sep + 'seeds' + \ + os.sep + seed_str + pathlib.Path(models_dir).mkdir(parents=True, exist_ok=True) + + dataset_parameters = DatasetParameters( + name=args.dataset_name, + test_size=args.test_size, + dev_size=args.dev_size, + random_state=seed, + normalize=normalize, + train_on_subset=args.train_on_subset ) - model_parameters.save(sub_models_dir, experiment_id) + dataset_parameters.save(models_dir, experiment_id_str) - model = ModelFactory.build(dataset.task, model_parameters) + dataset = DatasetLoader.load(dataset_parameters) - trainer.train(model, sub_models_dir) + trainer = Trainer(dataset) - logger.info('Error on test set: {}'.format(model.score(dataset.X_test, dataset.y_test))) - logger.info('Accuracy on test set: {}'.format(model.score_regressor(dataset.X_test, dataset.y_test))) + with tqdm(args.extracted_forest_size) as extracted_forest_size_bar: + for extracted_forest_size in extracted_forest_size_bar: + extracted_forest_size_bar.set_description('extracted_forest_size={}'.format(extracted_forest_size)) + sub_models_dir = models_dir + os.sep + 'extracted_forest_size' + os.sep + str(extracted_forest_size) + pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True) + + model_parameters = ModelParameters( + forest_size=args.forest_size, + extracted_forest_size=extracted_forest_size, + seed=seed, + normalize=normalize + ) + model_parameters.save(sub_models_dir, experiment_id) + + model = ModelFactory.build(dataset.task, model_parameters) + + trainer.train(model, sub_models_dir) diff --git a/requirements.txt b/requirements.txt index 3b11a342f72cb6c2a1a437124bb81c48da9e3c3c..8976e07d33247b3cf69d2670742f17e4d50d5c83 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,6 @@ coverage awscli flake8 python-dotenv>=0.5.1 -scikit-learn \ No newline at end of file +scikit-learn +python-dotenv +tqdm \ No newline at end of file