Skip to content
Snippets Groups Projects
Commit 0b74433a authored by Charly Lamothe
Browse files

Merge branch 'master' into 24-non-negative-omp

parents e207ec6f d65fd5bb
Branches
No related tags found
1 merge request: !24 Resolve "non negative omp"
...@@ -19,7 +19,7 @@ class ModelFactory(object): ...@@ -19,7 +19,7 @@ class ModelFactory(object):
raise ValueError("Unsupported task '{}'".format(task)) raise ValueError("Unsupported task '{}'".format(task))
if task == Task.BINARYCLASSIFICATION: if task == Task.BINARYCLASSIFICATION:
if model_parameters.extraction_strategy in ['omp', 'omp_distillation']: if model_parameters.extraction_strategy in ['omp', 'omp_nn', 'omp_distillation']:
return OmpForestBinaryClassifier(model_parameters) return OmpForestBinaryClassifier(model_parameters)
elif model_parameters.extraction_strategy == 'random': elif model_parameters.extraction_strategy == 'random':
return RandomForestClassifier(**model_parameters.hyperparameters, return RandomForestClassifier(**model_parameters.hyperparameters,
...@@ -36,7 +36,7 @@ class ModelFactory(object): ...@@ -36,7 +36,7 @@ class ModelFactory(object):
else: else:
raise ValueError('Invalid extraction strategy') raise ValueError('Invalid extraction strategy')
elif task == Task.REGRESSION: elif task == Task.REGRESSION:
if model_parameters.extraction_strategy in ['omp', 'omp_distillation']: if model_parameters.extraction_strategy in ['omp', 'omp_nn', 'omp_distillation']:
return OmpForestRegressor(model_parameters) return OmpForestRegressor(model_parameters)
elif model_parameters.extraction_strategy == 'random': elif model_parameters.extraction_strategy == 'random':
return RandomForestRegressor(**model_parameters.hyperparameters, return RandomForestRegressor(**model_parameters.hyperparameters,
...@@ -53,7 +53,7 @@ class ModelFactory(object): ...@@ -53,7 +53,7 @@ class ModelFactory(object):
else: else:
raise ValueError('Invalid extraction strategy') raise ValueError('Invalid extraction strategy')
elif task == Task.MULTICLASSIFICATION: elif task == Task.MULTICLASSIFICATION:
if model_parameters.extraction_strategy in ['omp', 'omp_distillation']: if model_parameters.extraction_strategy in ['omp', 'omp_nn', 'omp_distillation']:
return OmpForestMulticlassClassifier(model_parameters) return OmpForestMulticlassClassifier(model_parameters)
elif model_parameters.extraction_strategy == 'random': elif model_parameters.extraction_strategy == 'random':
return RandomForestClassifier(**model_parameters.hyperparameters, return RandomForestClassifier(**model_parameters.hyperparameters,
......
...@@ -76,14 +76,7 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb ...@@ -76,14 +76,7 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb
pretrained_estimator = None pretrained_estimator = None
pretrained_model_parameters = None pretrained_model_parameters = None
if parameters['extraction_strategy'] != 'none': if parameters['extraction_strategy'] == 'none':
with tqdm_joblib(tqdm(total=len(parameters['extracted_forest_size']), disable=not verbose)) as extracted_forest_size_job_pb:
Parallel(n_jobs=-1)(delayed(extracted_forest_size_job)(extracted_forest_size_job_pb, parameters['extracted_forest_size'][i],
models_dir, seed, parameters, dataset, hyperparameters, experiment_id, trainer,
pretrained_estimator=pretrained_estimator, pretrained_model_parameters=pretrained_model_parameters,
use_distillation=parameters['extraction_strategy'] == 'omp_distillation')
for i in range(len(parameters['extracted_forest_size'])))
else:
forest_size = hyperparameters['n_estimators'] forest_size = hyperparameters['n_estimators']
logger.info('Base forest training with fixed forest size of {}'.format(forest_size)) logger.info('Base forest training with fixed forest size of {}'.format(forest_size))
sub_models_dir = models_dir + os.sep + 'forest_size' + os.sep + str(forest_size) sub_models_dir = models_dir + os.sep + 'forest_size' + os.sep + str(forest_size)
...@@ -118,6 +111,32 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb ...@@ -118,6 +111,32 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb
trainer.init(model, subsets_used=parameters['subsets_used']) trainer.init(model, subsets_used=parameters['subsets_used'])
trainer.train(model) trainer.train(model)
trainer.compute_results(model, sub_models_dir) trainer.compute_results(model, sub_models_dir)
elif parameters['extraction_strategy'] == 'omp_nn':
forest_size = hyperparameters['n_estimators']
model_parameters = ModelParameters(
extracted_forest_size=forest_size,
normalize_D=parameters['normalize_D'],
subsets_used=parameters['subsets_used'],
normalize_weights=parameters['normalize_weights'],
seed=seed,
hyperparameters=hyperparameters,
extraction_strategy=parameters['extraction_strategy']
)
model_parameters.save(sub_models_dir, experiment_id)
model = ModelFactory.build(dataset.task, model_parameters)
trainer.init(model, subsets_used=parameters['subsets_used'])
trainer.train(model)
trainer.compute_results(model, sub_models_dir)
else:
with tqdm_joblib(tqdm(total=len(parameters['extracted_forest_size']), disable=not verbose)) as extracted_forest_size_job_pb:
Parallel(n_jobs=-1)(delayed(extracted_forest_size_job)(extracted_forest_size_job_pb, parameters['extracted_forest_size'][i],
models_dir, seed, parameters, dataset, hyperparameters, experiment_id, trainer,
pretrained_estimator=pretrained_estimator, pretrained_model_parameters=pretrained_model_parameters,
use_distillation=parameters['extraction_strategy'] == 'omp_distillation')
for i in range(len(parameters['extracted_forest_size'])))
logger.info(f'Training done for seed {seed_str}') logger.info(f'Training done for seed {seed_str}')
seed_job_pb.update(1) seed_job_pb.update(1)
...@@ -232,7 +251,7 @@ if __name__ == "__main__": ...@@ -232,7 +251,7 @@ if __name__ == "__main__":
parser.add_argument('--skip_best_hyperparams', action='store_true', default=DEFAULT_SKIP_BEST_HYPERPARAMS, help='Do not use the best hyperparameters if there exist.') parser.add_argument('--skip_best_hyperparams', action='store_true', default=DEFAULT_SKIP_BEST_HYPERPARAMS, help='Do not use the best hyperparameters if there exist.')
parser.add_argument('--save_experiment_configuration', nargs='+', default=None, help='Save the experiment parameters specified in the command line in a file. Args: {{stage_num}} {{name}}') parser.add_argument('--save_experiment_configuration', nargs='+', default=None, help='Save the experiment parameters specified in the command line in a file. Args: {{stage_num}} {{name}}')
parser.add_argument('--job_number', nargs='?', type=int, default=DEFAULT_JOB_NUMBER, help='Specify the number of job used during the parallelisation across seeds.') parser.add_argument('--job_number', nargs='?', type=int, default=DEFAULT_JOB_NUMBER, help='Specify the number of job used during the parallelisation across seeds.')
parser.add_argument('--extraction_strategy', nargs='?', type=str, default=DEFAULT_EXTRACTION_STRATEGY, help='Specify the strategy to apply to extract the trees from the forest. Either omp, random, none, similarity_similarities, similarity_predictions, kmeans, ensemble.') parser.add_argument('--extraction_strategy', nargs='?', type=str, default=DEFAULT_EXTRACTION_STRATEGY, help='Specify the strategy to apply to extract the trees from the forest. Either omp, omp_nn, random, none, similarity_similarities, similarity_predictions, kmeans, ensemble.')
parser.add_argument('--overwrite', action='store_true', default=DEFAULT_OVERWRITE, help='Overwrite the experiment id') parser.add_argument('--overwrite', action='store_true', default=DEFAULT_OVERWRITE, help='Overwrite the experiment id')
args = parser.parse_args() args = parser.parse_args()
...@@ -243,7 +262,7 @@ if __name__ == "__main__": ...@@ -243,7 +262,7 @@ if __name__ == "__main__":
else: else:
parameters = args.__dict__ parameters = args.__dict__
if parameters['extraction_strategy'] not in ['omp', 'omp_distillation', 'random', 'none', 'similarity_similarities', 'similarity_predictions', 'kmeans', 'ensemble']: if parameters['extraction_strategy'] not in ['omp', 'omp_nn', 'omp_distillation', 'random', 'none', 'similarity_similarities', 'similarity_predictions', 'kmeans', 'ensemble']:
raise ValueError('Specified extraction strategy {} is not supported.'.format(parameters['extraction_strategy'])) raise ValueError('Specified extraction strategy {} is not supported.'.format(parameters['extraction_strategy']))
pathlib.Path(parameters['models_dir']).mkdir(parents=True, exist_ok=True) pathlib.Path(parameters['models_dir']).mkdir(parents=True, exist_ok=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment