Skip to content
Snippets Groups Projects
Commit f077de3c authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Merge branch 'master' into 19-add-some-tests

parents 5cf23c9b c04845a9
No related branches found
No related tags found
1 merge request!21Resolve "Add some tests"
Showing
with 130 additions and 41 deletions
......@@ -16,6 +16,7 @@ class EnsembleSelectionForestRegressor(BaseEstimator, metaclass=ABCMeta):
self._library = library
self._extracted_forest_size = self._models_parameters.extracted_forest_size
self._score_metric = score_metric
self._selected_trees = list()
@property
def models_parameters(self):
......@@ -25,6 +26,10 @@ class EnsembleSelectionForestRegressor(BaseEstimator, metaclass=ABCMeta):
def library(self):
return self._library
@property
def selected_trees(self):
return self._selected_trees
def fit(self, X_train, y_train, X_val, y_val):
scores_list = list()
for estimator in self._library:
......@@ -33,7 +38,7 @@ class EnsembleSelectionForestRegressor(BaseEstimator, metaclass=ABCMeta):
class_list = list(self._library)
m = np.argmax(np.asarray(scores_list))
self._ensemble_selected = [class_list[m]]
self._selected_trees = [class_list[m]]
temp_pred = class_list[m].predict(X_val)
del class_list[m]
for k in range(self._extracted_forest_size - 1):
......@@ -47,17 +52,17 @@ class EnsembleSelectionForestRegressor(BaseEstimator, metaclass=ABCMeta):
candidate_index = j
best_score = temp_score
temp_pred = np.delete(temp_pred, -1, 0)
self._ensemble_selected.append(class_list[candidate_index])
self._selected_trees.append(class_list[candidate_index])
temp_pred = np.vstack((temp_pred, class_list[candidate_index].predict(X_val)))
del class_list[candidate_index]
def score(self, X, y):
predictions = self._predict_base_estimator(X)
predictions = self.predict_base_estimator(X)
return self._score_metric(predictions, y)
def predict_base_estimator(self, X):
predictions = list()
for tree in self._ensemble_selected:
for tree in self._selected_trees:
predictions.append(tree.predict(X))
mean_predictions = np.mean(np.array(predictions), axis=0)
return mean_predictions
......
......@@ -22,11 +22,16 @@ class KMeansForestRegressor(BaseEstimator, metaclass=ABCMeta):
random_state=self._models_parameters.seed, n_jobs=-1)
self._extracted_forest_size = self._models_parameters.extracted_forest_size
self._score_metric = score_metric
self._selected_trees = list()
@property
def models_parameters(self):
return self._models_parameters
@property
def selected_trees(self):
return self._selected_trees
def fit(self, X_train, y_train, X_val, y_val):
self._estimator.fit(X_train, y_train)
......@@ -45,6 +50,7 @@ class KMeansForestRegressor(BaseEstimator, metaclass=ABCMeta):
extracted_forest_sizes[i], labels, X_val, y_val, self._score_metric)
for i in range(self._extracted_forest_size))
self._selected_trees = pruned_forest
self._estimator.estimators_ = pruned_forest
def _prune_forest_job(self, prune_forest_job_pb, c, labels, X_val, y_val, score_metric):
......
from bolsonaro import LOG_PATH
from bolsonaro.error_handling.logger_factory import LoggerFactory
from bolsonaro.utils import omp_premature_warning
from abc import abstractmethod, ABCMeta
import numpy as np
from sklearn.linear_model import OrthogonalMatchingPursuit
from sklearn.base import BaseEstimator
import warnings
class OmpForest(BaseEstimator, metaclass=ABCMeta):
......@@ -13,6 +15,7 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta):
self._base_forest_estimator = base_forest_estimator
self._models_parameters = models_parameters
self._logger = LoggerFactory.create(LOG_PATH, __name__)
self._selected_trees = list()
@property
def models_parameters(self):
......@@ -109,8 +112,18 @@ class SingleOmpForest(OmpForest):
super().__init__(models_parameters, base_forest_estimator)
def fit_omp(self, atoms, objective):
with warnings.catch_warnings(record=True) as caught_warnings:
# Cause all warnings to always be triggered.
warnings.simplefilter("always")
self._omp.fit(atoms, objective)
# ignore any non-custom warnings that may be in the list
caught_warnings = list(filter(lambda i: i.message != RuntimeWarning(omp_premature_warning), caught_warnings))
if len(caught_warnings) > 0:
logger.error(f'number of linear dependences in the dictionary: {len(caught_warnings)}. model parameters: {str(self._models_parameters.__dict__)}')
def predict(self, X):
"""
Apply the SingleOmpForest to X.
......@@ -123,9 +136,7 @@ class SingleOmpForest(OmpForest):
forest_predictions = self._base_estimator_predictions(X)
if self._models_parameters.normalize_D:
forest_predictions = forest_predictions.T
forest_predictions /= self._forest_norms
forest_predictions = forest_predictions.T
return self._make_omp_weighted_prediction(forest_predictions, self._omp, self._models_parameters.normalize_weights)
......@@ -136,7 +147,7 @@ class SingleOmpForest(OmpForest):
Make all the base tree predictions
:param X: a Forest
:return: a np.array of the predictions of the trees selected by OMP without applyong the weight
:return: a np.array of the predictions of the trees selected by OMP without applying the weight
"""
forest_predictions = np.array([tree.predict(X) for tree in self._base_forest_estimator.estimators_])
......
from bolsonaro.models.omp_forest import OmpForest, SingleOmpForest
from bolsonaro.utils import binarize_class_data
from bolsonaro.utils import binarize_class_data, omp_premature_warning
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import OrthogonalMatchingPursuit
import warnings
class OmpForestBinaryClassifier(SingleOmpForest):
......@@ -92,7 +93,19 @@ class OmpForestMulticlassClassifier(OmpForest):
omp_class = OrthogonalMatchingPursuit(
n_nonzero_coefs=self.models_parameters.extracted_forest_size,
fit_intercept=True, normalize=False)
with warnings.catch_warnings(record=True) as caught_warnings:
# Cause all warnings to always be triggered.
warnings.simplefilter("always")
omp_class.fit(atoms_binary, objective_binary)
# ignore any non-custom warnings that may be in the list
caught_warnings = list(filter(lambda i: i.message != RuntimeWarning(omp_premature_warning), caught_warnings))
if len(caught_warnings) > 0:
logger.error(f'number of linear dependences in the dictionary: {len(caught_warnings)}. model parameters: {str(self._models_parameters.__dict__)}')
self._dct_class_omp[class_label] = omp_class
return self._dct_class_omp
......@@ -119,9 +132,7 @@ class OmpForestMulticlassClassifier(OmpForest):
forest_predictions = np.array([tree.predict_proba(X) for tree in self._base_forest_estimator.estimators_]).T
if self._models_parameters.normalize_D:
forest_predictions = forest_predictions.T
forest_predictions /= self._forest_norms
forest_predictions = forest_predictions.T
label_names = []
preds = []
......@@ -149,7 +160,9 @@ class OmpForestMulticlassClassifier(OmpForest):
forest_predictions = np.array([tree.predict_proba(X) for tree in self._base_forest_estimator.estimators_]).T
if self._models_parameters.normalize_D:
forest_predictions = forest_predictions.T
forest_predictions /= self._forest_norms
forest_predictions = forest_predictions.T
label_names = []
preds = []
......
......@@ -17,18 +17,22 @@ class SimilarityForestRegressor(BaseEstimator, metaclass=ABCMeta):
random_state=self._models_parameters.seed, n_jobs=-1)
self._extracted_forest_size = self._models_parameters.extracted_forest_size
self._score_metric = score_metric
self._selected_trees = list()
@property
def models_parameters(self):
return self._models_parameters
@property
def selected_trees(self):
return self._selected_trees
def fit(self, X_train, y_train, X_val, y_val):
self._estimator.fit(X_train, y_train)
y_val_pred = self._estimator.predict(X_val)
forest_pred = self._score_metric(y_val, y_val_pred)
forest = self._estimator.estimators_
selected_trees = list()
tree_list = list(self._estimator.estimators_)
val_scores = list()
......@@ -57,12 +61,13 @@ class SimilarityForestRegressor(BaseEstimator, metaclass=ABCMeta):
tree_list.insert(j, lonely_tree)
val_scores.insert(j, lonely_tree.predict(X_val))
tree_list_bar.update(1)
selected_trees.append(tree_list[found_index])
self._selected_trees.append(tree_list[found_index])
del tree_list[found_index]
del val_scores[found_index]
pruning_forest_bar.update(1)
pruned_forest = list(set(forest) - set(selected_trees))
self._selected_trees = set(self._selected_trees)
pruned_forest = list(set(forest) - self._selected_trees)
self._estimator.estimators_ = pruned_forest
def score(self, X, y):
......
......@@ -13,6 +13,8 @@ from sklearn.metrics import mean_squared_error, accuracy_score
import time
import datetime
import numpy as np
import os
import pickle
class Trainer(object):
......@@ -36,6 +38,7 @@ class Trainer(object):
else classification_score_metric.__name__
self._base_score_metric_name = base_regression_score_metric.__name__ if dataset.task == Task.REGRESSION \
else base_classification_score_metric.__name__
self._selected_trees = ''
@property
def score_metric_name(self):
......@@ -93,6 +96,7 @@ class Trainer(object):
X=self._X_forest,
y=self._y_forest
)
self._selected_trees = model.estimators_
else:
model.fit(
self._X_forest,
......@@ -151,6 +155,13 @@ class Trainer(object):
elif type(model) == OmpForestBinaryClassifier:
model_weights = model._omp
if type(model) in [SimilarityForestRegressor, EnsembleSelectionForestRegressor, KMeansForestRegressor]:
self._selected_trees = model.selected_trees
if len(self._selected_trees) > 0:
with open(os.path.join(models_dir, 'selected_trees.pickle'), 'wb') as output_file:
pickle.dump(self._selected_trees, output_file)
results = ModelRawResults(
model_weights=model_weights,
training_time=self._end_time - self._begin_time,
......
......@@ -124,3 +124,7 @@ def is_float(value):
return True
except ValueError:
return False
omp_premature_warning = """ Orthogonal matching pursuit ended prematurely due to linear
dependence in the dictionary. The requested precision might not have been met.
"""
......@@ -156,6 +156,7 @@ if __name__ == "__main__":
DEFAULT_RESULTS_DIR = os.environ["project_dir"] + os.sep + 'results'
DEFAULT_MODELS_DIR = os.environ["project_dir"] + os.sep + 'models'
DEFAULT_PLOT_WEIGHT_DENSITY = False
DEFAULT_WO_LOSS_PLOTS = False
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--stage', nargs='?', type=int, required=True, help='Specify the stage number among [1, 5].')
......@@ -168,6 +169,7 @@ if __name__ == "__main__":
parser.add_argument('--results_dir', nargs='?', type=str, default=DEFAULT_RESULTS_DIR, help='The output directory of the results.')
parser.add_argument('--models_dir', nargs='?', type=str, default=DEFAULT_MODELS_DIR, help='The output directory of the trained models.')
parser.add_argument('--plot_weight_density', action='store_true', default=DEFAULT_PLOT_WEIGHT_DENSITY, help='Plot the weight density. Only working for regressor models for now.')
parser.add_argument('--wo_loss_plots', action='store_true', default=DEFAULT_WO_LOSS_PLOTS, help='Do not compute the loss plots.')
args = parser.parse_args()
if args.stage not in list(range(1, 6)):
......@@ -181,7 +183,7 @@ if __name__ == "__main__":
# Create recursively the results dir tree
pathlib.Path(args.results_dir).mkdir(parents=True, exist_ok=True)
if args.stage == 1:
if args.stage == 1 and not args.wo_loss_plots:
if len(args.experiment_ids) != 6:
raise ValueError('In the case of stage 1, the number of specified experiment ids must be 6.')
......@@ -221,8 +223,8 @@ if __name__ == "__main__":
wo_params_extracted_forest_sizes, random_wo_params_experiment_score_metric = \
extract_scores_across_seeds_and_extracted_forest_sizes(
args.models_dir, args.results_dir, int(args.experiment_ids[4]))
# base_wo_params
logger.info('Loading base_wo_params experiment scores...')
# omp_wo_params
logger.info('Loading omp_wo_params experiment scores...')
omp_wo_params_train_scores, omp_wo_params_dev_scores, omp_wo_params_test_scores, _, \
omp_wo_params_experiment_score_metric = extract_scores_across_seeds_and_extracted_forest_sizes(
args.models_dir, args.results_dir, int(args.experiment_ids[5]))
......@@ -262,7 +264,7 @@ if __name__ == "__main__":
ylabel=experiments_score_metric,
title='Loss values of {}\nusing best and default hyperparameters'.format(args.dataset_name)
)
elif args.stage == 2:
elif args.stage == 2 and not args.wo_loss_plots:
if len(args.experiment_ids) != 4:
raise ValueError('In the case of stage 2, the number of specified experiment ids must be 4.')
......@@ -308,7 +310,7 @@ if __name__ == "__main__":
xlabel='Number of trees extracted',
ylabel=experiments_score_metric,
title='Loss values of {}\nusing different normalizations'.format(args.dataset_name))
elif args.stage == 3:
elif args.stage == 3 and not args.wo_loss_plots:
if len(args.experiment_ids) != 3:
raise ValueError('In the case of stage 3, the number of specified experiment ids must be 3.')
......@@ -365,7 +367,7 @@ if __name__ == "__main__":
xlabel='Number of trees extracted',
ylabel=experiments_score_metric,
title='Loss values of {}\nusing different training subsets'.format(args.dataset_name))"""
elif args.stage == 4:
elif args.stage == 4 and not args.wo_loss_plots:
if len(args.experiment_ids) != 3:
raise ValueError('In the case of stage 4, the number of specified experiment ids must be 3.')
......@@ -427,11 +429,7 @@ if __name__ == "__main__":
xlabel='Number of trees extracted',
ylabel=experiments_score_metric,
title='Loss values of {}\nusing best params of previous stages'.format(args.dataset_name))
experiment_weights = extract_weights_across_seeds(args.models_dir, args.results_dir, args.experiment_ids[2])
Plotter.weight_density(experiment_weights, os.path.join(output_path, 'weight_density.png'))
elif args.stage == 5:
elif args.stage == 5 and not args.wo_loss_plots:
# Retreive the extracted forest sizes number used in order to have a base forest axis as long as necessary
extracted_forest_sizes_number = retreive_extracted_forest_sizes_number(args.models_dir, int(args.experiment_ids[1]))
all_labels = list()
......@@ -475,8 +473,9 @@ if __name__ == "__main__":
continue
logger.info(f'Loading {label} experiment scores...')
current_experiment_id = int(args.experiment_ids[i].split('=')[1])
_, _, current_test_scores, _, _ = extract_scores_across_seeds_and_extracted_forest_sizes(
args.models_dir, args.results_dir, int(args.experiment_ids[i].split('=')[1]))
args.models_dir, args.results_dir, current_experiment_id)
all_labels.append(label)
all_scores.append(current_test_scores)
......@@ -491,7 +490,42 @@ if __name__ == "__main__":
xlabel='Number of trees extracted',
ylabel=base_with_params_experiment_score_metric,
title='Loss values of {}\nusing best params of previous stages'.format(args.dataset_name))
if args.plot_weight_density:
root_output_path = os.path.join(args.results_dir, args.dataset_name, f'stage{args.stage}')
if args.stage == 1:
omp_experiment_ids = [('omp_with_params', args.experiment_ids[2]), ('omp_wo_params', args.experiment_ids[2])]
elif args.stage == 2:
omp_experiment_ids = [('no_normalization', args.experiment_ids[0]),
('normalize_D', args.experiment_ids[1]),
('normalize_weights', args.experiment_ids[2]),
('normalize_D_and_weights', args.experiment_ids[3])]
elif args.stage == 3:
omp_experiment_ids = [('train-dev_subset', args.experiment_ids[0]),
('train-dev_train-dev_subset', args.experiment_ids[1]),
('train-train-dev_subset', args.experiment_ids[2])]
elif args.stage == 4:
omp_experiment_ids = [('omp_with_params', args.experiment_ids[2])]
elif args.stage == 5:
omp_experiment_ids = [('omp_with_params', args.experiment_ids[2])]
for i in range(3, len(args.experiment_ids)):
if 'kmeans' in args.experiment_ids[i]:
label = 'kmeans'
elif 'similarity' in args.experiment_ids[i]:
label = 'similarity'
elif 'ensemble' in args.experiment_ids[i]:
label = 'ensemble'
else:
raise ValueError('This stage number is not supported yet, but it will be!')
logger.error('Invalid value encountered')
continue
current_experiment_id = int(args.experiment_ids[i].split('=')[1])
omp_experiment_ids.append((label, current_experiment_id))
for (experiment_label, experiment_id) in omp_experiment_ids:
logger.info(f'Computing weight density plot for experiment {experiment_label}...')
experiment_weights = extract_weights_across_seeds(args.models_dir, args.results_dir, experiment_id)
Plotter.weight_density(experiment_weights, os.path.join(root_output_path, f'weight_density_{experiment_label}.png'))
logger.info('Done.')
......@@ -141,7 +141,7 @@ def extracted_forest_size_job(extracted_forest_size_job_pb, extracted_forest_siz
sub_models_dir_files = os.listdir(sub_models_dir)
for file_name in sub_models_dir_files:
if '.pickle' != os.path.splitext(file_name)[1]:
return
continue
else:
already_exists = os.path.getsize(os.path.join(sub_models_dir, file_name)) > 0
break
......
......@@ -29,7 +29,7 @@
],
"job_number": -1,
"extraction_strategy": "omp",
"overwrite": false,
"overwrite": true,
"extracted_forest_size": [
33,
67,
......
......@@ -29,7 +29,7 @@
],
"job_number": -1,
"extraction_strategy": "omp",
"overwrite": false,
"overwrite": true,
"extracted_forest_size": [
33,
67,
......
......@@ -29,7 +29,7 @@
],
"job_number": -1,
"extraction_strategy": "omp",
"overwrite": false,
"overwrite": true,
"extracted_forest_size": [
33,
67,
......
......@@ -29,7 +29,7 @@
],
"job_number": -1,
"extraction_strategy": "omp",
"overwrite": false,
"overwrite": true,
"extracted_forest_size": [
33,
67,
......
......@@ -29,7 +29,7 @@
],
"job_number": -1,
"extraction_strategy": "omp",
"overwrite": false,
"overwrite": true,
"extracted_forest_size": [
33,
67,
......
......@@ -29,7 +29,7 @@
],
"job_number": -1,
"extraction_strategy": "omp",
"overwrite": false,
"overwrite": true,
"extracted_forest_size": [
33,
67,
......
......@@ -29,7 +29,7 @@
],
"job_number": -1,
"extraction_strategy": "omp",
"overwrite": false,
"overwrite": true,
"extracted_forest_size": [
4,
7,
......
......@@ -29,7 +29,7 @@
],
"job_number": -1,
"extraction_strategy": "omp",
"overwrite": false,
"overwrite": true,
"extracted_forest_size": [
4,
7,
......
......@@ -29,7 +29,7 @@
],
"job_number": -1,
"extraction_strategy": "omp",
"overwrite": false,
"overwrite": true,
"extracted_forest_size": [
14,
29,
......
......@@ -29,7 +29,7 @@
],
"job_number": -1,
"extraction_strategy": "omp",
"overwrite": false,
"overwrite": true,
"extracted_forest_size": [
14,
29,
......
......@@ -29,7 +29,7 @@
],
"job_number": -1,
"extraction_strategy": "omp",
"overwrite": false,
"overwrite": true,
"extracted_forest_size": [
33,
67,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment