diff --git a/code/bolsonaro/models/similarity_forest_regressor.py b/code/bolsonaro/models/similarity_forest_regressor.py
index 368d3741e49224c881dfe7a6de38ceec81a4f156..32446e1acb31e1f5e512d634df391d538e4319f0 100644
--- a/code/bolsonaro/models/similarity_forest_regressor.py
+++ b/code/bolsonaro/models/similarity_forest_regressor.py
@@ -1,3 +1,5 @@
+import time
+
 from sklearn.ensemble import RandomForestRegressor
 from sklearn.metrics import mean_squared_error
 from sklearn.base import BaseEstimator
@@ -11,13 +13,11 @@ class SimilarityForestRegressor(BaseEstimator, metaclass=ABCMeta):
     https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2822360/
     """
 
-    def __init__(self, models_parameters, score_metric=mean_squared_error):
+    def __init__(self, models_parameters):
         self._models_parameters = models_parameters
         self._estimator = RandomForestRegressor(**self._models_parameters.hyperparameters,
             random_state=self._models_parameters.seed, n_jobs=-1)
         self._extracted_forest_size = self._models_parameters.extracted_forest_size
-        self._score_metric = score_metric
-        self._selected_trees = list()
 
     @property
     def models_parameters(self):
@@ -27,57 +27,96 @@ class SimilarityForestRegressor(BaseEstimator, metaclass=ABCMeta):
     def selected_trees(self):
         return self._selected_trees
 
+    def _score_metric(self, y_preds, y_true):
+        if len(y_true.shape) == 1:
+            y_true = y_true[np.newaxis, :]
+        if len(y_preds.shape) == 1:
+            y_preds = y_preds[np.newaxis, :]
+        assert y_preds.shape[1] == y_true.shape[1], "Number of examples to compare should be the same in y_preds and y_true"
+
+        diff = y_preds - y_true
+        squared_diff = diff ** 2
+        mean_squared_diff = np.mean(squared_diff, axis=1)
+        return mean_squared_diff
+
+
     def fit(self, X_train, y_train, X_val, y_val):
         self._estimator.fit(X_train, y_train)
 
-        y_val_pred = self._estimator.predict(X_val)
-        forest_pred = self._score_metric(y_val, y_val_pred)
-        forest = self._estimator.estimators_
-        tree_list = list(self._estimator.estimators_)
+        # param = self._models_parameters.extraction_strategy
+        param = "similarity_predictions"
 
-        val_scores = list()
+        #
+        # if param == "similarity_similarities":
+        #     pass
+        # elif param == "similarity_predictions":
+        #     pass
+        # else:
+        #     raise ValueError
+
+        # get score of base forest on val
+        tree_list = list(self._estimator.estimators_)        # get score of base forest on val
+        trees_to_remove = list()
+
+        # get score of each single tree of forest on val
+        val_predictions = np.empty((len(tree_list), X_val.shape[0]))
         with tqdm(tree_list) as tree_pred_bar:
             tree_pred_bar.set_description('[Initial tree predictions]')
-            for tree in tree_pred_bar:
-                val_scores.append(tree.predict(X_val))
+            for idx_tree, tree in enumerate(tree_pred_bar):
+                val_predictions[idx_tree, :] = tree.predict(X_val)
             tree_pred_bar.update(1)
 
-        with tqdm(range(self._extracted_forest_size), disable=True) as pruning_forest_bar:
+        # boolean mask of trees to take into account for next evaluation of trees importance
+        mask_trees_to_consider = np.ones(val_predictions.shape[0], dtype=bool)
+        # the technique does backward selection, that is: trees are removed one after an other
+        nb_tree_to_remove = len(tree_list) - self._extracted_forest_size
+        with tqdm(range(nb_tree_to_remove), disable=True) as pruning_forest_bar:
             pruning_forest_bar.set_description(f'[Pruning forest s={self._extracted_forest_size}]')
-            for i in pruning_forest_bar:
-                best_similarity = 100000
-                found_index = 0
-                with tqdm(range(len(tree_list)), disable=True) as tree_list_bar:
-                    tree_list_bar.set_description(f'[Tree selection s={self._extracted_forest_size} #{i}]')
-                    for j in tree_list_bar:
-                        lonely_tree = tree_list[j]
-                        del tree_list[j]
-                        val_mean = np.mean(np.asarray(val_scores), axis=0)
-                        val_score = self._score_metric(val_mean, y_val)
-                        temp_similarity = abs(forest_pred - val_score)
-                        if (temp_similarity < best_similarity):
-                            found_index = j
-                            best_similarity = temp_similarity
-                        tree_list.insert(j, lonely_tree)
-                        val_scores.insert(j, lonely_tree.predict(X_val))
-                        tree_list_bar.update(1)
-                self._selected_trees.append(tree_list[found_index])
-                del tree_list[found_index]
-                del val_scores[found_index]
+            for _ in pruning_forest_bar:  # pour chaque arbre a extraire
+                # get indexes of trees to take into account
+                idx_trees_to_consider = np.arange(val_predictions.shape[0])[mask_trees_to_consider]
+                val_predictions_to_consider = val_predictions[idx_trees_to_consider]
+                nb_trees_to_consider = val_predictions_to_consider.shape[0]
+
+                if param == "similarity_predictions":
+                    # this matrix has zero on the diag and 1/(L-1) everywhere else.
+                    # When multiplying left the matrix of predictions (having L lines) by this zero_diag_matrix (square L), the result has on each
+                    # line, the average of all other lines in the initial matrix of predictions
+                    zero_diag_matrix = np.ones((nb_trees_to_consider, nb_trees_to_consider)) * (1 / (nb_trees_to_consider - 1))
+                    np.fill_diagonal(zero_diag_matrix, 0)
+
+                    leave_one_tree_out_predictions_val = zero_diag_matrix @ val_predictions_to_consider
+                    leave_one_tree_out_scores_val = self._score_metric(leave_one_tree_out_predictions_val, y_val)
+                    # difference with base forest is actually useless
+                    # delta_score = forest_score - leave_one_tree_out_scores_val
+
+                    # get index of tree to remove
+                    index_worse_tree = int(np.argmax(leave_one_tree_out_scores_val))  # correlation and MSE: both greater is worse
+
+                elif param == "similarity_similarities":
+                    correlation_matrix = val_predictions_to_consider @ val_predictions_to_consider.T
+                    average_correlation_by_tree = np.average(correlation_matrix, axis=1)
+
+                    # get index of tree to remove
+                    index_worse_tree = int(np.argmax(average_correlation_by_tree))  # correlation and MSE: both greater is worse
+
+                index_worse_tree_in_base_forest = idx_trees_to_consider[index_worse_tree]
+                trees_to_remove.append(tree_list[index_worse_tree_in_base_forest])
+                mask_trees_to_consider[index_worse_tree_in_base_forest] = False
                 pruning_forest_bar.update(1)
 
-        self._selected_trees = set(self._selected_trees)
-        pruned_forest = list(set(forest) - self._selected_trees)
+        pruned_forest = list(set(tree_list) - set(trees_to_remove))
+
+        self._selected_trees = pruned_forest
         self._estimator.estimators_ = pruned_forest
 
     def score(self, X, y):
-        test_list = list()
-        for mod in self._estimator.estimators_:
-            test_pred = mod.predict(X)
-            test_list.append(test_pred)
-        test_list = np.array(test_list)
-        test_mean = np.mean(test_list, axis=0)
-        score = self._score_metric(test_mean, y)
+        test_predictions = np.empty((len(self._estimator.estimators_), X.shape[0]))
+        for idx_tree, mod in enumerate(self._estimator.estimators_):
+            test_predictions[idx_tree, :] = mod.predict(X)
+
+        test_mean = np.mean(test_predictions, axis=0)
+        score = self._score_metric(test_mean, y)[0]
         return score
 
     def predict_base_estimator(self, X):