diff --git a/code/bolsonaro/models/model_factory.py b/code/bolsonaro/models/model_factory.py index 6f1836278675662ac7ef57f8bf98cc4c8284dc26..56e267fb8208bd733a428d43f9b6bfe6d19b16da 100644 --- a/code/bolsonaro/models/model_factory.py +++ b/code/bolsonaro/models/model_factory.py @@ -19,7 +19,7 @@ class ModelFactory(object): raise ValueError("Unsupported task '{}'".format(task)) if task == Task.BINARYCLASSIFICATION: - if model_parameters.extraction_strategy in ['omp', 'omp_distillation']: + if model_parameters.extraction_strategy in ['omp', 'omp_nn', 'omp_distillation']: return OmpForestBinaryClassifier(model_parameters) elif model_parameters.extraction_strategy == 'random': return RandomForestClassifier(**model_parameters.hyperparameters, @@ -36,7 +36,7 @@ class ModelFactory(object): else: raise ValueError('Invalid extraction strategy') elif task == Task.REGRESSION: - if model_parameters.extraction_strategy in ['omp', 'omp_distillation']: + if model_parameters.extraction_strategy in ['omp', 'omp_nn', 'omp_distillation']: return OmpForestRegressor(model_parameters) elif model_parameters.extraction_strategy == 'random': return RandomForestRegressor(**model_parameters.hyperparameters, @@ -53,7 +53,7 @@ class ModelFactory(object): else: raise ValueError('Invalid extraction strategy') elif task == Task.MULTICLASSIFICATION: - if model_parameters.extraction_strategy in ['omp', 'omp_distillation']: + if model_parameters.extraction_strategy in ['omp', 'omp_nn', 'omp_distillation']: return OmpForestMulticlassClassifier(model_parameters) elif model_parameters.extraction_strategy == 'random': return RandomForestClassifier(**model_parameters.hyperparameters, diff --git a/code/train.py b/code/train.py index 10dbf7354837cab803202a8307c671f0def0f274..457e1c405d203e79c4b58f366ef6adfd596d8948 100644 --- a/code/train.py +++ b/code/train.py @@ -76,14 +76,7 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb pretrained_estimator = None pretrained_model_parameters = None - if parameters['extraction_strategy'] != 'none': - with tqdm_joblib(tqdm(total=len(parameters['extracted_forest_size']), disable=not verbose)) as extracted_forest_size_job_pb: - Parallel(n_jobs=-1)(delayed(extracted_forest_size_job)(extracted_forest_size_job_pb, parameters['extracted_forest_size'][i], - models_dir, seed, parameters, dataset, hyperparameters, experiment_id, trainer, - pretrained_estimator=pretrained_estimator, pretrained_model_parameters=pretrained_model_parameters, - use_distillation=parameters['extraction_strategy'] == 'omp_distillation') - for i in range(len(parameters['extracted_forest_size']))) - else: + if parameters['extraction_strategy'] == 'none': forest_size = hyperparameters['n_estimators'] logger.info('Base forest training with fixed forest size of {}'.format(forest_size)) sub_models_dir = models_dir + os.sep + 'forest_size' + os.sep + str(forest_size) @@ -118,6 +111,32 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb trainer.init(model, subsets_used=parameters['subsets_used']) trainer.train(model) trainer.compute_results(model, sub_models_dir) + elif parameters['extraction_strategy'] == 'omp_nn': + forest_size = hyperparameters['n_estimators'] + model_parameters = ModelParameters( + extracted_forest_size=forest_size, + normalize_D=parameters['normalize_D'], + subsets_used=parameters['subsets_used'], + normalize_weights=parameters['normalize_weights'], + seed=seed, + hyperparameters=hyperparameters, + extraction_strategy=parameters['extraction_strategy'] + ) + model_parameters.save(sub_models_dir, experiment_id) + + model = ModelFactory.build(dataset.task, model_parameters) + + trainer.init(model, subsets_used=parameters['subsets_used']) + trainer.train(model) + trainer.compute_results(model, sub_models_dir) + else: + with tqdm_joblib(tqdm(total=len(parameters['extracted_forest_size']), disable=not verbose)) as extracted_forest_size_job_pb: + Parallel(n_jobs=-1)(delayed(extracted_forest_size_job)(extracted_forest_size_job_pb, parameters['extracted_forest_size'][i], + models_dir, seed, parameters, dataset, hyperparameters, experiment_id, trainer, + pretrained_estimator=pretrained_estimator, pretrained_model_parameters=pretrained_model_parameters, + use_distillation=parameters['extraction_strategy'] == 'omp_distillation') + for i in range(len(parameters['extracted_forest_size']))) + logger.info(f'Training done for seed {seed_str}') seed_job_pb.update(1) @@ -232,7 +251,7 @@ if __name__ == "__main__": parser.add_argument('--skip_best_hyperparams', action='store_true', default=DEFAULT_SKIP_BEST_HYPERPARAMS, help='Do not use the best hyperparameters if there exist.') parser.add_argument('--save_experiment_configuration', nargs='+', default=None, help='Save the experiment parameters specified in the command line in a file. Args: {{stage_num}} {{name}}') parser.add_argument('--job_number', nargs='?', type=int, default=DEFAULT_JOB_NUMBER, help='Specify the number of job used during the parallelisation across seeds.') - parser.add_argument('--extraction_strategy', nargs='?', type=str, default=DEFAULT_EXTRACTION_STRATEGY, help='Specify the strategy to apply to extract the trees from the forest. Either omp, random, none, similarity_similarities, similarity_predictions, kmeans, ensemble.') + parser.add_argument('--extraction_strategy', nargs='?', type=str, default=DEFAULT_EXTRACTION_STRATEGY, help='Specify the strategy to apply to extract the trees from the forest. Either omp, omp_nn, random, none, similarity_similarities, similarity_predictions, kmeans, ensemble.') parser.add_argument('--overwrite', action='store_true', default=DEFAULT_OVERWRITE, help='Overwrite the experiment id') args = parser.parse_args() @@ -243,7 +262,7 @@ if __name__ == "__main__": else: parameters = args.__dict__ - if parameters['extraction_strategy'] not in ['omp', 'omp_distillation', 'random', 'none', 'similarity_similarities', 'similarity_predictions', 'kmeans', 'ensemble']: + if parameters['extraction_strategy'] not in ['omp', 'omp_nn', 'omp_distillation', 'random', 'none', 'similarity_similarities', 'similarity_predictions', 'kmeans', 'ensemble']: raise ValueError('Specified extraction strategy {} is not supported.'.format(parameters['extraction_strategy'])) pathlib.Path(parameters['models_dir']).mkdir(parents=True, exist_ok=True)