diff --git a/code/bolsonaro/models/kmeans_forest_regressor.py b/code/bolsonaro/models/kmeans_forest_regressor.py
index 551c6d878b8ff8445c3385493195a450a346c899..4f15372186898ba66792d39ac5c2f8b810e682f9 100644
--- a/code/bolsonaro/models/kmeans_forest_regressor.py
+++ b/code/bolsonaro/models/kmeans_forest_regressor.py
@@ -1,5 +1,6 @@
 import time
 
+from bolsonaro.models.utils import score_metric_mse, score_metric_indicator, aggregation_classification, aggregation_regression
 from bolsonaro.utils import tqdm_joblib
 
 from sklearn.ensemble import RandomForestRegressor
@@ -53,72 +54,85 @@ class KmeansForest(BaseEstimator, metaclass=ABCMeta):
             lst_pruned_forest.append(self._estimator.estimators_[index_trees_cluster[best_tree_index]])
 
         self._selected_trees = lst_pruned_forest
-        self._estimator.estimators_ = lst_pruned_forest
+        # self._estimator.estimators_ = lst_pruned_forest
 
     def score(self, X, y):
-        predictions = np.empty((len(self._estimator.estimators_), X.shape[0]))
-        for idx_tree, tree in enumerate(self._estimator.estimators_):
-            predictions[idx_tree, :] = tree.predict(X)
-        final_predictions = self._aggregate(predictions)
+        final_predictions = self.predict(X)
         score = self._score_metric(final_predictions, y)[0]
         return score
 
     def predict(self, X):
-        return self._estimator.predict(X)
+        predictions = np.empty((len(self._selected_trees), X.shape[0]))
+        for idx_tree, tree in enumerate(self._selected_trees):
+            predictions[idx_tree, :] = tree.predict(X)
+        final_predictions = self._aggregate(predictions)
+        return final_predictions
 
     def predict_base_estimator(self, X):
         return self._estimator.predict(X)
 
+    def _get_best_tree_index(self, y_preds, y_true):
+        score = self._score_metric(y_preds, y_true)
+        best_tree_index = self._best(score)  # index of the best-scoring tree; whether "best" means lowest or highest score is decided by the subclass's _best
+        return best_tree_index
+
     @abstractmethod
     def _score_metric(self, y_preds, y_true):
+        """
+        Get the score of each predictor in y_preds.
+
+        y_preds.shape == (nb_trees, nb_sample)
+        y_true.shape == (1, nb_sample)
+
+        :param y_preds:
+        :param y_true:
+        :return:
+        """
         pass
 
+    @staticmethod
     @abstractmethod
-    def _get_best_tree_index(self, y_preds, y_true):
+    def _best(array):
+        """
+        Return the index of the best element in array.
+
+        :param array:
+        :return:
+        """
         pass
 
     @abstractmethod
     def _aggregate(self, predictions):
+        """
+        Aggregates votes of predictors in predictions
+
+        predictions shape: (nb_trees, nb_samples)
+        :param predictions:
+        :return:
+        """
         pass
 
 class KMeansForestRegressor(KmeansForest, metaclass=ABCMeta):
 
     def _aggregate(self, predictions):
-        return np.mean(predictions, axis=0)
+        return aggregation_regression(predictions)
 
     def _score_metric(self, y_preds, y_true):
-        if len(y_true.shape) == 1:
-            y_true = y_true[np.newaxis, :]
-        if len(y_preds.shape) == 1:
-            y_preds = y_preds[np.newaxis, :]
-        assert y_preds.shape[1] == y_true.shape[1], "Number of examples to compare should be the same in y_preds and y_true"
+        return score_metric_mse(y_preds, y_true)
 
-        diff = y_preds - y_true
-        squared_diff = diff ** 2
-        mean_squared_diff = np.mean(squared_diff, axis=1)
-        return mean_squared_diff
+    @staticmethod
+    def _best(array):
+        return np.argmin(array)
 
-    def _get_best_tree_index(self, y_preds, y_true):
-        score = self._score_metric(y_preds, y_true)
-        best_tree_index = np.argmin(score)  # get best scoring tree (the one with lowest mse)
-        return best_tree_index
 
 class KMeansForestClassifier(KmeansForest, metaclass=ABCMeta):
 
     def _aggregate(self, predictions):
-        return np.sign(np.sum(predictions, axis=0))
+        return aggregation_classification(predictions)
 
     def _score_metric(self, y_preds, y_true):
-        if len(y_true.shape) == 1:
-            y_true = y_true[np.newaxis, :]
-        if len(y_preds.shape) == 1:
-            y_preds = y_preds[np.newaxis, :]
-        assert y_preds.shape[1] == y_true.shape[1], "Number of examples to compare should be the same in y_preds and y_true"
+        return score_metric_indicator(y_preds, y_true)
 
-        bool_arr_correct_predictions = y_preds == y_true
-        return np.average(bool_arr_correct_predictions, axis=1)
-
-    def _get_best_tree_index(self, y_preds, y_true):
-        score = self._score_metric(y_preds, y_true)
-        best_tree_index = np.argmax(score)  # get best scoring tree (the one with lowest mse)
-        return best_tree_index
\ No newline at end of file
+    @staticmethod
+    def _best(array):
+        return np.argmax(array)
diff --git a/code/bolsonaro/models/model_factory.py b/code/bolsonaro/models/model_factory.py
index 785d65332045cd6c9db828cbe6810648e4c76d31..07799ceb966e2a40b10e98fdd134fe458674cf8b 100644
--- a/code/bolsonaro/models/model_factory.py
+++ b/code/bolsonaro/models/model_factory.py
@@ -1,7 +1,7 @@
 from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier
 from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
 from bolsonaro.models.model_parameters import ModelParameters
-from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor
+from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor, SimilarityForestClassifier
 from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor, KMeansForestClassifier
 from bolsonaro.models.ensemble_selection_forest_regressor import EnsembleSelectionForestRegressor
 from bolsonaro.data.task import Task
@@ -29,6 +29,8 @@ class ModelFactory(object):
                     random_state=model_parameters.seed)
             elif model_parameters.extraction_strategy == 'kmeans':
                 return KMeansForestClassifier(model_parameters)
+            elif model_parameters.extraction_strategy in ['similarity_similarities', 'similarity_predictions']:  # the only names SimilarityForest.fit accepts
+                return SimilarityForestClassifier(model_parameters)
             else:
                 raise ValueError('Invalid extraction strategy')
         elif task == Task.REGRESSION:
diff --git a/code/bolsonaro/models/similarity_forest_regressor.py b/code/bolsonaro/models/similarity_forest_regressor.py
index 32446e1acb31e1f5e512d634df391d538e4319f0..8008bffa97beabe57f1ffe2ab8b05a267464592f 100644
--- a/code/bolsonaro/models/similarity_forest_regressor.py
+++ b/code/bolsonaro/models/similarity_forest_regressor.py
@@ -7,17 +7,22 @@ from abc import abstractmethod, ABCMeta
 import numpy as np
 from tqdm import tqdm
 
+from bolsonaro.models.utils import score_metric_mse, aggregation_regression, aggregation_classification, score_metric_indicator
 
-class SimilarityForestRegressor(BaseEstimator, metaclass=ABCMeta):
+
+class SimilarityForest(BaseEstimator, metaclass=ABCMeta):
     """
     https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2822360/
     """
+    similarity_similarities = "similarity_similarities"
+    similarity_predictions = "similarity_predictions"
 
     def __init__(self, models_parameters):
         self._models_parameters = models_parameters
         self._estimator = RandomForestRegressor(**self._models_parameters.hyperparameters,
-            random_state=self._models_parameters.seed, n_jobs=-1)
+                                                random_state=self._models_parameters.seed, n_jobs=-1)
         self._extracted_forest_size = self._models_parameters.extracted_forest_size
+        self._selected_trees = list()
 
     @property
     def models_parameters(self):
@@ -27,32 +32,20 @@ class SimilarityForestRegressor(BaseEstimator, metaclass=ABCMeta):
     def selected_trees(self):
         return self._selected_trees
 
-    def _score_metric(self, y_preds, y_true):
-        if len(y_true.shape) == 1:
-            y_true = y_true[np.newaxis, :]
-        if len(y_preds.shape) == 1:
-            y_preds = y_preds[np.newaxis, :]
-        assert y_preds.shape[1] == y_true.shape[1], "Number of examples to compare should be the same in y_preds and y_true"
-
-        diff = y_preds - y_true
-        squared_diff = diff ** 2
-        mean_squared_diff = np.mean(squared_diff, axis=1)
-        return mean_squared_diff
+    def predict(self, X):
+        predictions = np.empty((len(self._selected_trees), X.shape[0]))
+        for idx_tree, tree in enumerate(self._selected_trees):
+            predictions[idx_tree, :] = tree.predict(X)
+        final_predictions = self._aggregate(predictions)
+        return final_predictions
 
+    def predict_base_estimator(self, X):
+        return self._estimator.predict(X)
 
     def fit(self, X_train, y_train, X_val, y_val):
         self._estimator.fit(X_train, y_train)
 
-        # param = self._models_parameters.extraction_strategy
-        param = "similarity_predictions"
-
-        #
-        # if param == "similarity_similarities":
-        #     pass
-        # elif param == "similarity_predictions":
-        #     pass
-        # else:
-        #     raise ValueError
+        param = self._models_parameters.extraction_strategy
 
         # get score of base forest on val
         tree_list = list(self._estimator.estimators_)        # get score of base forest on val
@@ -78,7 +71,7 @@ class SimilarityForestRegressor(BaseEstimator, metaclass=ABCMeta):
                 val_predictions_to_consider = val_predictions[idx_trees_to_consider]
                 nb_trees_to_consider = val_predictions_to_consider.shape[0]
 
-                if param == "similarity_predictions":
+                if param == self.similarity_predictions:
                     # this matrix has zero on the diag and 1/(L-1) everywhere else.
                     # When multiplying left the matrix of predictions (having L lines) by this zero_diag_matrix (square L), the result has on each
                     # line, the average of all other lines in the initial matrix of predictions
@@ -86,6 +79,7 @@ class SimilarityForestRegressor(BaseEstimator, metaclass=ABCMeta):
                     np.fill_diagonal(zero_diag_matrix, 0)
 
                     leave_one_tree_out_predictions_val = zero_diag_matrix @ val_predictions_to_consider
+                    leave_one_tree_out_predictions_val = self._activation(leave_one_tree_out_predictions_val)  # identity for regression; sign for classification
                     leave_one_tree_out_scores_val = self._score_metric(leave_one_tree_out_predictions_val, y_val)
                     # difference with base forest is actually useless
                     # delta_score = forest_score - leave_one_tree_out_scores_val
@@ -93,13 +87,16 @@ class SimilarityForestRegressor(BaseEstimator, metaclass=ABCMeta):
                     # get index of tree to remove
                     index_worse_tree = int(np.argmax(leave_one_tree_out_scores_val))  # correlation and MSE: both greater is worse
 
-                elif param == "similarity_similarities":
+                elif param == self.similarity_similarities:
                     correlation_matrix = val_predictions_to_consider @ val_predictions_to_consider.T
                     average_correlation_by_tree = np.average(correlation_matrix, axis=1)
 
                     # get index of tree to remove
                     index_worse_tree = int(np.argmax(average_correlation_by_tree))  # correlation and MSE: both greater is worse
 
+                else:
+                    raise ValueError("Unknown similarity method {}. Should be {} or {}".format(param, self.similarity_similarities, self.similarity_predictions))
+
                 index_worse_tree_in_base_forest = idx_trees_to_consider[index_worse_tree]
                 trees_to_remove.append(tree_list[index_worse_tree_in_base_forest])
                 mask_trees_to_consider[index_worse_tree_in_base_forest] = False
@@ -108,16 +105,50 @@ class SimilarityForestRegressor(BaseEstimator, metaclass=ABCMeta):
         pruned_forest = list(set(tree_list) - set(trees_to_remove))
 
         self._selected_trees = pruned_forest
-        self._estimator.estimators_ = pruned_forest
 
     def score(self, X, y):
-        test_predictions = np.empty((len(self._estimator.estimators_), X.shape[0]))
-        for idx_tree, mod in enumerate(self._estimator.estimators_):
-            test_predictions[idx_tree, :] = mod.predict(X)
-
-        test_mean = np.mean(test_predictions, axis=0)
-        score = self._score_metric(test_mean, y)[0]
+        final_predictions = self.predict(X)
+        score = self._score_metric(final_predictions, y)[0]
         return score
 
-    def predict_base_estimator(self, X):
-        return self._estimator.predict(X)
+    @abstractmethod
+    def _score_metric(self, y_preds, y_true):
+        pass
+
+    @abstractmethod
+    def _aggregate(self, predictions):
+        """
+        Aggregates votes of predictors in predictions
+
+        predictions shape: (nb_trees, nb_samples)
+        :param predictions:
+        :return:
+        """
+        pass
+
+    @abstractmethod
+    def _activation(self, leave_one_tree_out_predictions_val):
+        pass
+
+
+class SimilarityForestRegressor(SimilarityForest, metaclass=ABCMeta):
+    def _aggregate(self, predictions):
+        return aggregation_regression(predictions)
+
+    def _score_metric(self, y_preds, y_true):
+        return score_metric_mse(y_preds, y_true)
+
+    def _activation(self, predictions):
+        return predictions
+
+
+class SimilarityForestClassifier(SimilarityForest, metaclass=ABCMeta):
+    def _aggregate(self, predictions):
+        return aggregation_classification(predictions)
+
+    def _score_metric(self, y_preds, y_true):
+        return score_metric_indicator(y_preds, y_true)
+
+    def _activation(self, predictions):
+        return np.sign(predictions)
+
diff --git a/code/bolsonaro/models/utils.py b/code/bolsonaro/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a054407ea17b88d95328458ad8d86aa6030ff202
--- /dev/null
+++ b/code/bolsonaro/models/utils.py
@@ -0,0 +1,29 @@
+import numpy as np
+
+def score_metric_mse(y_preds, y_true):
+    if len(y_true.shape) == 1:
+        y_true = y_true[np.newaxis, :]
+    if len(y_preds.shape) == 1:
+        y_preds = y_preds[np.newaxis, :]
+    assert y_preds.shape[1] == y_true.shape[1], "Number of examples to compare should be the same in y_preds and y_true"
+
+    diff = y_preds - y_true
+    squared_diff = diff ** 2
+    mean_squared_diff = np.mean(squared_diff, axis=1)
+    return mean_squared_diff
+
+def score_metric_indicator(y_preds, y_true):
+    if len(y_true.shape) == 1:
+        y_true = y_true[np.newaxis, :]
+    if len(y_preds.shape) == 1:
+        y_preds = y_preds[np.newaxis, :]
+    assert y_preds.shape[1] == y_true.shape[1], "Number of examples to compare should be the same in y_preds and y_true"
+
+    bool_arr_correct_predictions = y_preds == y_true
+    return np.average(bool_arr_correct_predictions, axis=1)
+
+def aggregation_classification(predictions):
+    return np.sign(np.sum(predictions, axis=0))
+
+def aggregation_regression(predictions):
+    return np.mean(predictions, axis=0)
diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py
index 49a63c1a004fa16d9dcf43424b96b1d512b2825d..1adb387ca5cf639b8bea16b72d27b46ac190fb14 100644
--- a/code/bolsonaro/trainer.py
+++ b/code/bolsonaro/trainer.py
@@ -1,7 +1,7 @@
 from bolsonaro.models.model_raw_results import ModelRawResults
 from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
 from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier
-from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor
+from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor, SimilarityForestClassifier
 from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor, KMeansForestClassifier
 from bolsonaro.models.ensemble_selection_forest_regressor import EnsembleSelectionForestRegressor
 from bolsonaro.error_handling.logger_factory import LoggerFactory
@@ -122,7 +122,7 @@ class Trainer(object):
                 y_pred = np.sign(y_pred)
                 y_pred = np.where(y_pred == 0, 1, y_pred)
             result = self._classification_score_metric(y_true, y_pred)
-        elif type(model) in [SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor, KMeansForestClassifier]:
+        elif type(model) in [SimilarityForestRegressor, SimilarityForestClassifier, KMeansForestRegressor, EnsembleSelectionForestRegressor, KMeansForestClassifier]:
             result = model.score(X, y_true)
         return result
 
@@ -130,7 +130,7 @@ class Trainer(object):
         if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor]:
             y_pred = model.predict_base_estimator(X)
             result = self._base_regression_score_metric(y_true, y_pred)
-        elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
+        elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier, KMeansForestClassifier, SimilarityForestClassifier]:
             y_pred = model.predict_base_estimator(X)
             result = self._base_classification_score_metric(y_true, y_pred)
         elif type(model) == RandomForestClassifier:
@@ -139,8 +139,6 @@ class Trainer(object):
         elif type(model) is RandomForestRegressor:
             y_pred = model.predict(X)
             result = self._base_regression_score_metric(y_true, y_pred)
-        elif type(model) in [ KMeansForestClassifier]:
-            result = model.score(X, y_true)
         return result
 
     def compute_results(self, model, models_dir):