Skip to content
Snippets Groups Projects
Commit bd349760 authored by Luc Giffon
Browse files

fix classification for similarity forest

parent 24cb371b
No related branches found
No related tags found
1 merge request: !23 Resolve "integration-sota"
import time import time
from sklearn.ensemble import RandomForestRegressor from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.metrics import mean_squared_error from sklearn.metrics import mean_squared_error
from sklearn.base import BaseEstimator from sklearn.base import BaseEstimator
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
...@@ -19,10 +19,15 @@ class SimilarityForest(BaseEstimator, metaclass=ABCMeta): ...@@ -19,10 +19,15 @@ class SimilarityForest(BaseEstimator, metaclass=ABCMeta):
def __init__(self, models_parameters): def __init__(self, models_parameters):
self._models_parameters = models_parameters self._models_parameters = models_parameters
self._estimator = RandomForestRegressor(**self._models_parameters.hyperparameters,
random_state=self._models_parameters.seed, n_jobs=-1)
self._extracted_forest_size = self._models_parameters.extracted_forest_size self._extracted_forest_size = self._models_parameters.extracted_forest_size
self._selected_trees = list() self._selected_trees = list()
self._base_estimator = self.init_estimator(models_parameters)
@staticmethod
@abstractmethod
def init_estimator(model_parameters):
    """Build the underlying forest estimator; subclasses must override this.

    ``__init__`` stores the returned object as ``self._base_estimator``.

    :param model_parameters: the parameter object received by ``__init__``.
    :return: the estimator to use as base forest (this abstract stub
        returns ``None``).
    """
@property @property
def models_parameters(self): def models_parameters(self):
...@@ -32,32 +37,33 @@ class SimilarityForest(BaseEstimator, metaclass=ABCMeta): ...@@ -32,32 +37,33 @@ class SimilarityForest(BaseEstimator, metaclass=ABCMeta):
def selected_trees(self): def selected_trees(self):
return self._selected_trees return self._selected_trees
def _base_estimator_predictions(self, X):
    """Collect the prediction of every tree of the base forest on ``X``.

    Each tree in ``self._base_estimator.estimators_`` predicts on ``X``; the
    stacked results are transposed before being returned. With 1-D per-tree
    outputs this gives one row per sample and one column per tree.

    :param X: samples passed unchanged to each tree's ``predict``.
    :return: transposed array of the stacked per-tree predictions.
    """
    per_tree = [tree.predict(X) for tree in self._base_estimator.estimators_]
    return np.array(per_tree).T
def _selected_tree_predictions(self, X):
    """Collect the prediction of every selected tree on ``X``.

    Each tree in ``self.selected_trees`` predicts on ``X``; the stacked
    results are transposed before being returned. With 1-D per-tree outputs
    this gives one row per sample and one column per selected tree.

    :param X: samples passed unchanged to each tree's ``predict``.
    :return: transposed array of the stacked per-tree predictions.
    """
    per_tree = [tree.predict(X) for tree in self.selected_trees]
    return np.array(per_tree).T
def predict(self, X): def predict(self, X):
predictions = np.empty((len(self._selected_trees), X.shape[0])) predictions = self._selected_tree_predictions(X).T
for idx_tree, tree in enumerate(self._selected_trees):
predictions[idx_tree, :] = tree.predict(X)
final_predictions = self._aggregate(predictions) final_predictions = self._aggregate(predictions)
return final_predictions return final_predictions
def predict_base_estimator(self, X): def predict_base_estimator(self, X):
return self._estimator.predict(X) return self._base_estimator.predict(X)
def fit(self, X_train, y_train, X_val, y_val): def fit(self, X_train, y_train, X_val, y_val):
self._estimator.fit(X_train, y_train) self._base_estimator.fit(X_train, y_train)
param = self._models_parameters.extraction_strategy param = self._models_parameters.extraction_strategy
# get score of base forest on val # get score of base forest on val
tree_list = list(self._estimator.estimators_) # get score of base forest on val tree_list = list(self._base_estimator.estimators_) # get score of base forest on val
trees_to_remove = list() trees_to_remove = list()
# get score of each single tree of forest on val # get score of each single tree of forest on val
val_predictions = np.empty((len(tree_list), X_val.shape[0])) val_predictions = self._base_estimator_predictions(X_val).T
with tqdm(tree_list) as tree_pred_bar:
tree_pred_bar.set_description('[Initial tree predictions]')
for idx_tree, tree in enumerate(tree_pred_bar):
val_predictions[idx_tree, :] = tree.predict(X_val)
tree_pred_bar.update(1)
# boolean mask of trees to take into account for next evaluation of trees importance # boolean mask of trees to take into account for next evaluation of trees importance
mask_trees_to_consider = np.ones(val_predictions.shape[0], dtype=bool) mask_trees_to_consider = np.ones(val_predictions.shape[0], dtype=bool)
...@@ -132,6 +138,12 @@ class SimilarityForest(BaseEstimator, metaclass=ABCMeta): ...@@ -132,6 +138,12 @@ class SimilarityForest(BaseEstimator, metaclass=ABCMeta):
class SimilarityForestRegressor(SimilarityForest, metaclass=ABCMeta): class SimilarityForestRegressor(SimilarityForest, metaclass=ABCMeta):
@staticmethod
def init_estimator(model_parameters):
    """Return the ``RandomForestRegressor`` used as base forest.

    The forest is configured from ``model_parameters.hyperparameters``,
    seeded with ``model_parameters.seed`` and created with ``n_jobs=2``.

    :param model_parameters: object exposing ``hyperparameters`` (unpacked
        as keyword arguments) and ``seed``.
    :return: an unfitted ``RandomForestRegressor``.
    """
    return RandomForestRegressor(**model_parameters.hyperparameters,
                                 random_state=model_parameters.seed, n_jobs=2)
def _aggregate(self, predictions):
    """Combine per-tree predictions by delegating to ``aggregation_regression``.

    :param predictions: per-tree predictions to aggregate.
    :return: whatever ``aggregation_regression`` returns for them.
    """
    return aggregation_regression(predictions)
...@@ -143,6 +155,12 @@ class SimilarityForestRegressor(SimilarityForest, metaclass=ABCMeta): ...@@ -143,6 +155,12 @@ class SimilarityForestRegressor(SimilarityForest, metaclass=ABCMeta):
class SimilarityForestClassifier(SimilarityForest, metaclass=ABCMeta): class SimilarityForestClassifier(SimilarityForest, metaclass=ABCMeta):
@staticmethod
def init_estimator(model_parameters):
    """Return the ``RandomForestClassifier`` used as base forest.

    The forest is configured from ``model_parameters.hyperparameters``,
    seeded with ``model_parameters.seed`` and created with ``n_jobs=2``.

    :param model_parameters: object exposing ``hyperparameters`` (unpacked
        as keyword arguments) and ``seed``.
    :return: an unfitted ``RandomForestClassifier``.
    """
    return RandomForestClassifier(**model_parameters.hyperparameters,
                                  random_state=model_parameters.seed, n_jobs=2)
def _aggregate(self, predictions):
    """Combine per-tree predictions by delegating to ``aggregation_classification``.

    :param predictions: per-tree predictions to aggregate.
    :return: whatever ``aggregation_classification`` returns for them.
    """
    return aggregation_classification(predictions)
...@@ -152,3 +170,12 @@ class SimilarityForestClassifier(SimilarityForest, metaclass=ABCMeta): ...@@ -152,3 +170,12 @@ class SimilarityForestClassifier(SimilarityForest, metaclass=ABCMeta):
def _activation(self, predictions):
    """Return the element-wise sign of ``predictions``.

    :param predictions: numeric array-like of scores.
    :return: ``np.sign(predictions)``, i.e. -1, 0 or +1 per element.
    """
    return np.sign(predictions)
def _selected_tree_predictions(self, X):
    """Return the selected trees' predictions rescaled by ``(p - 0.5) * 2``.

    The parent implementation supplies the raw per-tree predictions; the
    rescaling maps 0 to -1 and 1 to +1.
    NOTE(review): presumably the trees emit binary labels in {0, 1} --
    confirm against the datasets used.

    :param X: samples forwarded to the parent implementation.
    :return: the rescaled prediction array.
    """
    predictions_0_1 = super()._selected_tree_predictions(X)
    return (predictions_0_1 - 0.5) * 2
def _base_estimator_predictions(self, X):
    """Return the base forest's per-tree predictions rescaled by ``(p - 0.5) * 2``.

    The parent implementation supplies the raw per-tree predictions; the
    rescaling maps 0 to -1 and 1 to +1.
    NOTE(review): presumably the trees emit binary labels in {0, 1} --
    confirm against the datasets used.

    :param X: samples forwarded to the parent implementation.
    :return: the rescaled prediction array.
    """
    predictions_0_1 = super()._base_estimator_predictions(X)
    return (predictions_0_1 - 0.5) * 2
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment