Skip to content
Snippets Groups Projects
Commit bd349760 authored by Luc Giffon's avatar Luc Giffon
Browse files

fix classification for similarity forest

parent 24cb371b
Branches
No related tags found
1 merge request!23Resolve "integration-sota"
This commit is part of merge request !23. Comments created here will be created in the context of that merge request.
import time
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.metrics import mean_squared_error
from sklearn.base import BaseEstimator
from abc import abstractmethod, ABCMeta
......@@ -19,10 +19,15 @@ class SimilarityForest(BaseEstimator, metaclass=ABCMeta):
def __init__(self, models_parameters):
self._models_parameters = models_parameters
self._estimator = RandomForestRegressor(**self._models_parameters.hyperparameters,
random_state=self._models_parameters.seed, n_jobs=-1)
self._extracted_forest_size = self._models_parameters.extracted_forest_size
self._selected_trees = list()
self._base_estimator = self.init_estimator(models_parameters)
@staticmethod
@abstractmethod
def init_estimator(model_parameters):
pass
@property
def models_parameters(self):
......@@ -32,32 +37,33 @@ class SimilarityForest(BaseEstimator, metaclass=ABCMeta):
def selected_trees(self):
return self._selected_trees
def _base_estimator_predictions(self, X):
base_predictions = np.array([tree.predict(X) for tree in self._base_estimator.estimators_]).T
return base_predictions
def _selected_tree_predictions(self, X):
base_predictions = np.array([tree.predict(X) for tree in self.selected_trees]).T
return base_predictions
def predict(self, X):
predictions = np.empty((len(self._selected_trees), X.shape[0]))
for idx_tree, tree in enumerate(self._selected_trees):
predictions[idx_tree, :] = tree.predict(X)
predictions = self._selected_tree_predictions(X).T
final_predictions = self._aggregate(predictions)
return final_predictions
def predict_base_estimator(self, X):
return self._estimator.predict(X)
return self._base_estimator.predict(X)
def fit(self, X_train, y_train, X_val, y_val):
self._estimator.fit(X_train, y_train)
self._base_estimator.fit(X_train, y_train)
param = self._models_parameters.extraction_strategy
# get score of base forest on val
tree_list = list(self._estimator.estimators_) # get score of base forest on val
tree_list = list(self._base_estimator.estimators_) # get score of base forest on val
trees_to_remove = list()
# get score of each single tree of forest on val
val_predictions = np.empty((len(tree_list), X_val.shape[0]))
with tqdm(tree_list) as tree_pred_bar:
tree_pred_bar.set_description('[Initial tree predictions]')
for idx_tree, tree in enumerate(tree_pred_bar):
val_predictions[idx_tree, :] = tree.predict(X_val)
tree_pred_bar.update(1)
val_predictions = self._base_estimator_predictions(X_val).T
# boolean mask of trees to take into account for next evaluation of trees importance
mask_trees_to_consider = np.ones(val_predictions.shape[0], dtype=bool)
......@@ -132,6 +138,12 @@ class SimilarityForest(BaseEstimator, metaclass=ABCMeta):
class SimilarityForestRegressor(SimilarityForest, metaclass=ABCMeta):
@staticmethod
def init_estimator(model_parameters):
return RandomForestRegressor(**model_parameters.hyperparameters,
random_state=model_parameters.seed, n_jobs=2)
def _aggregate(self, predictions):
return aggregation_regression(predictions)
......@@ -143,6 +155,12 @@ class SimilarityForestRegressor(SimilarityForest, metaclass=ABCMeta):
class SimilarityForestClassifier(SimilarityForest, metaclass=ABCMeta):
@staticmethod
def init_estimator(model_parameters):
return RandomForestClassifier(**model_parameters.hyperparameters,
random_state=model_parameters.seed, n_jobs=2)
def _aggregate(self, predictions):
return aggregation_classification(predictions)
......@@ -152,3 +170,12 @@ class SimilarityForestClassifier(SimilarityForest, metaclass=ABCMeta):
def _activation(self, predictions):
return np.sign(predictions)
def _selected_tree_predictions(self, X):
predictions_0_1 = super()._selected_tree_predictions(X)
predictions = (predictions_0_1 - 0.5) * 2
return predictions
def _base_estimator_predictions(self, X):
predictions_0_1 = super()._base_estimator_predictions(X)
predictions = (predictions_0_1 - 0.5) * 2
return predictions
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment