diff --git a/code/bolsonaro/models/forest_pruning_sota.py b/code/bolsonaro/models/forest_pruning_sota.py
new file mode 100644
index 0000000000000000000000000000000000000000..80d6fe7bc021f9ecc46f79e2bfe160cd3b820ac6
--- /dev/null
+++ b/code/bolsonaro/models/forest_pruning_sota.py
@@ -0,0 +1,111 @@
+import time
+
+from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
+from sklearn.metrics import mean_squared_error
+from sklearn.base import BaseEstimator
+from abc import abstractmethod, ABCMeta
+import numpy as np
+from tqdm import tqdm
+
+from bolsonaro.models.utils import score_metric_mse, aggregation_regression, aggregation_classification, score_metric_indicator
+
+
+class ForestPruningSOTA(BaseEstimator, metaclass=ABCMeta):
+    """
+    Base class of the forest pruning baselines: subclasses implement `_fit` to select
+    `extracted_forest_size` trees, whose votes are then aggregated by `predict`.
+    """
+
+    def __init__(self, models_parameters):
+        self._models_parameters = models_parameters
+        self._extracted_forest_size = self._models_parameters.extracted_forest_size
+        self._selected_trees = list()
+        self._base_estimator = self.init_estimator(models_parameters)
+
+    @staticmethod
+    @abstractmethod
+    def init_estimator(model_parameters):
+        """Build the base estimator whose trees (`estimators_`) are the pruning candidates."""
+
+    @abstractmethod
+    def _fit(self, X_train, y_train, X_val, y_val):
+        """Select the trees to keep and return them as a list of `extracted_forest_size` trees."""
+
+    @property
+    def models_parameters(self):
+        return self._models_parameters
+
+    @property
+    def selected_trees(self):
+        return self._selected_trees
+
+    def fit(self, X_train, y_train, X_val, y_val):
+        """
+        Prune the forest with the subclass strategy and store the kept trees.
+
+        :return: self, following the scikit-learn estimator convention
+        """
+        pruned_forest = self._fit(X_train, y_train, X_val, y_val)
+        # every pruning strategy must hand back exactly the requested number of trees
+        assert len(pruned_forest) == self._extracted_forest_size, "Pruned forest size isn't the size of expected forest: {} != {}".format(len(pruned_forest), self._extracted_forest_size)
+        self._selected_trees = pruned_forest
+        return self
+
+    def _base_estimator_predictions(self, X):
+        """Per-tree predictions of the base estimator, transposed so that trees are on the last axis."""
+        base_predictions = np.array([tree.predict(X) for tree in self._base_estimator.estimators_]).T
+        return base_predictions
+
+    def _selected_tree_predictions(self, X):
+        """Per-tree predictions of the selected trees, transposed so that trees are on the last axis."""
+        base_predictions = np.array([tree.predict(X) for tree in self.selected_trees]).T
+        return base_predictions
+
+    def predict(self, X):
+        """Aggregate the votes of the selected trees on X."""
+        predictions = self._selected_tree_predictions(X).T  # transpose back: trees on the first axis, as `_aggregate` expects
+        final_predictions = self._aggregate(predictions)
+        return final_predictions
+
+    def predict_base_estimator(self, X):
+        """Predict with the base estimator itself (not with the selected trees)."""
+        return self._base_estimator.predict(X)
+
+    def score(self, X, y):
+        """Score the aggregated predictions of the selected trees on (X, y)."""
+        final_predictions = self.predict(X)
+        score = self._score_metric(final_predictions, y)[0]
+        return score
+
+    @staticmethod
+    @abstractmethod
+    def _best_score_idx(array):
+        """Return the index of the best element in array."""
+
+    @staticmethod
+    @abstractmethod
+    def _worse_score_idx(array):
+        """Return the index of the worst element in array."""
+
+    @abstractmethod
+    def _score_metric(self, y_preds, y_true):
+        """
+        Get the score of each predictor in y_preds.
+
+        y_preds.shape == (nb_trees, nb_sample)
+        y_true.shape == (1, nb_sample)
+
+        :param y_preds: predictions to score
+        :param y_true: expected targets
+        :return: the scores; `score` reads the first entry
+        """
+
+    @abstractmethod
+    def _aggregate(self, predictions):
+        """
+        Aggregates votes of predictors in predictions
+
+        predictions shape: (nb_trees, nb_samples)
+        :param predictions: per-tree predictions
+        :return: the aggregated prediction
+        """
\ No newline at end of file
diff --git a/code/bolsonaro/models/kmeans_forest_regressor.py b/code/bolsonaro/models/kmeans_forest_regressor.py
index ba1d9b4e56d98fbb2607d12917eb9d75b013044d..dc4788412554a039f5cf3c83a595cc96ab8fd987 100644
--- a/code/bolsonaro/models/kmeans_forest_regressor.py
+++ b/code/bolsonaro/models/kmeans_forest_regressor.py
@@ -1,5 +1,6 @@
 import time
 
+from bolsonaro.models.forest_pruning_sota import ForestPruningSOTA
 from bolsonaro.models.utils import score_metric_mse, score_metric_indicator, aggregation_classification, aggregation_regression
 from bolsonaro.utils import tqdm_joblib
 
@@ -14,38 +15,12 @@ from joblib import Parallel, delayed
 from tqdm import tqdm
 
 
-class KmeansForest(BaseEstimator, metaclass=ABCMeta):
+class KmeansForest(ForestPruningSOTA, metaclass=ABCMeta):
+    """
+    'On extreme pruning of random forest ensembles for real-time predictive applications', by Khaled Fawagreh, Mohamed Medhat Gaber and Eyad Elyan.
     """
-        On extreme pruning of random forest ensembles for ral-time predictive applications', by Khaled Fawagreh, Mohamed Medhat Gaber and Eyad Elyan.
-        """
-    def __init__(self, models_parameters):
-        self._models_parameters = models_parameters
-        self._extracted_forest_size = self._models_parameters.extracted_forest_size
-        self._selected_trees = list()
-        self._base_estimator = self.init_estimator(models_parameters)
-
-    @staticmethod
-    @abstractmethod
-    def init_estimator(model_parameters):
-        pass
-
-    def _base_estimator_predictions(self, X):
-        base_predictions = np.array([tree.predict(X) for tree in self._base_estimator.estimators_]).T
-        return base_predictions
-
-    def _selected_tree_predictions(self, X):
-        base_predictions = np.array([tree.predict(X) for tree in self.selected_trees]).T
-        return base_predictions
-
-    @property
-    def models_parameters(self):
-        return self._models_parameters
-
-    @property
-    def selected_trees(self):
-        return self._selected_trees
 
-    def fit(self, X_train, y_train, X_val, y_val):
+    def _fit(self, X_train, y_train, X_val, y_val):
         self._base_estimator.fit(X_train, y_train)
 
         predictions_val = self._base_estimator_predictions(X_val).T
@@ -62,65 +37,15 @@ class KmeansForest(BaseEstimator, metaclass=ABCMeta):
             best_tree_index = self._get_best_tree_index(predictions_val_cluster, y_val)
             lst_pruned_forest.append(self._base_estimator.estimators_[index_trees_cluster[best_tree_index]])
 
-        self._selected_trees = lst_pruned_forest
-
-    def score(self, X, y):
-        final_predictions = self.predict(X)
-        score = self._score_metric(final_predictions, y)[0]
-        return score
-
-    def predict(self, X):
-        predictions = self._selected_tree_predictions(X).T
-        final_predictions = self._aggregate(predictions)
-        return final_predictions
-
-    def predict_base_estimator(self, X):
-        return self._base_estimator.predict(X)
+        return lst_pruned_forest
 
     def _get_best_tree_index(self, y_preds, y_true):
         score = self._score_metric(y_preds, y_true)
-        best_tree_index = self._best(score)  # get best scoring tree (the one with lowest mse)
+        best_tree_index = self._best_score_idx(score)  # index of the best-scoring tree; the subclass decides the direction (argmin for the regressor, argmax for the classifier)
         return best_tree_index
 
-    @abstractmethod
-    def _score_metric(self, y_preds, y_true):
-        """
-        get score of each predictors in y_preds
-
-        y_preds.shape == (nb_trees, nb_sample)
-        y_true.shape == (1, nb_sample)
-
-        :param y_preds:
-        :param y_true:
-        :return:
-        """
-        pass
-
-    @staticmethod
-    @abstractmethod
-    def _best(array):
-        """
-        return index of best element in array
-
-        :param array:
-        :return:
-        """
-        pass
-
-    @abstractmethod
-    def _aggregate(self, predictions):
-        """
-        Aggregates votes of predictors in predictions
-
-        predictions shape: (nb_trees, nb_samples)
-        :param predictions:
-        :return:
-        """
-        pass
 
 class KMeansForestRegressor(KmeansForest, metaclass=ABCMeta):
-
-
     @staticmethod
     def init_estimator(model_parameters):
         return RandomForestRegressor(**model_parameters.hyperparameters,
@@ -133,12 +58,15 @@ class KMeansForestRegressor(KmeansForest, metaclass=ABCMeta):
         return score_metric_mse(y_preds, y_true)
 
     @staticmethod
-    def _best(array):
+    def _best_score_idx(array):
         return np.argmin(array)
 
+    @staticmethod
+    def _worse_score_idx(array):
+        return np.argmax(array)
 
-class KMeansForestClassifier(KmeansForest, metaclass=ABCMeta):
 
+class KMeansForestClassifier(KmeansForest, metaclass=ABCMeta):
     @staticmethod
     def init_estimator(model_parameters):
         return RandomForestClassifier(**model_parameters.hyperparameters,
@@ -161,5 +89,9 @@ class KMeansForestClassifier(KmeansForest, metaclass=ABCMeta):
         return predictions
 
     @staticmethod
-    def _best(array):
+    def _best_score_idx(array):
         return np.argmax(array)
+
+    @staticmethod
+    def _worse_score_idx(array):
+        return np.argmin(array)
diff --git a/code/bolsonaro/models/similarity_forest_regressor.py b/code/bolsonaro/models/similarity_forest_regressor.py
index 16710b0272ebb11f2fd3e8f962dc5ca46aca1e5b..95a035de2ae7f3004399c6928007153c531fa78d 100644
--- a/code/bolsonaro/models/similarity_forest_regressor.py
+++ b/code/bolsonaro/models/similarity_forest_regressor.py
@@ -7,53 +7,18 @@ from abc import abstractmethod, ABCMeta
 import numpy as np
 from tqdm import tqdm
 
+from bolsonaro.models.forest_pruning_sota import ForestPruningSOTA
 from bolsonaro.models.utils import score_metric_mse, aggregation_regression, aggregation_classification, score_metric_indicator
 
 
-class SimilarityForest(BaseEstimator, metaclass=ABCMeta):
+class SimilarityForest(ForestPruningSOTA, metaclass=ABCMeta):
     """
     https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2822360/
     """
     similarity_similarities = "similarity_similarities"
     similarity_predictions = "similarity_predictions"
 
-    def __init__(self, models_parameters):
-        self._models_parameters = models_parameters
-        self._extracted_forest_size = self._models_parameters.extracted_forest_size
-        self._selected_trees = list()
-        self._base_estimator = self.init_estimator(models_parameters)
-
-    @staticmethod
-    @abstractmethod
-    def init_estimator(model_parameters):
-        pass
-
-
-    @property
-    def models_parameters(self):
-        return self._models_parameters
-
-    @property
-    def selected_trees(self):
-        return self._selected_trees
-
-    def _base_estimator_predictions(self, X):
-        base_predictions = np.array([tree.predict(X) for tree in self._base_estimator.estimators_]).T
-        return base_predictions
-
-    def _selected_tree_predictions(self, X):
-        base_predictions = np.array([tree.predict(X) for tree in self.selected_trees]).T
-        return base_predictions
-
-    def predict(self, X):
-        predictions = self._selected_tree_predictions(X).T
-        final_predictions = self._aggregate(predictions)
-        return final_predictions
-
-    def predict_base_estimator(self, X):
-        return self._base_estimator.predict(X)
-
-    def fit(self, X_train, y_train, X_val, y_val):
+    def _fit(self, X_train, y_train, X_val, y_val):
         self._base_estimator.fit(X_train, y_train)
 
         param = self._models_parameters.extraction_strategy
@@ -91,7 +56,7 @@ class SimilarityForest(BaseEstimator, metaclass=ABCMeta):
                     # delta_score = forest_score - leave_one_tree_out_scores_val
 
                     # get index of tree to remove
-                    index_worse_tree = int(np.argmax(leave_one_tree_out_scores_val))  # correlation and MSE: both greater is worse
+                    index_worse_tree = int(self._worse_score_idx(leave_one_tree_out_scores_val))
 
                 elif param == self.similarity_similarities:
                     correlation_matrix = val_predictions_to_consider @ val_predictions_to_consider.T
@@ -109,34 +74,14 @@ class SimilarityForest(BaseEstimator, metaclass=ABCMeta):
                 pruning_forest_bar.update(1)
 
         pruned_forest = list(set(tree_list) - set(trees_to_remove))
-
-        self._selected_trees = pruned_forest
-
-    def score(self, X, y):
-        final_predictions = self.predict(X)
-        score = self._score_metric(final_predictions, y)[0]
-        return score
-
-    @abstractmethod
-    def _score_metric(self, y_preds, y_true):
-        pass
-
-    @abstractmethod
-    def _aggregate(self, predictions):
-        """
-        Aggregates votes of predictors in predictions
-
-        predictions shape: (nb_trees, nb_samples)
-        :param predictions:
-        :return:
-        """
-        pass
+        return pruned_forest
 
     @abstractmethod
     def _activation(self, leave_one_tree_out_predictions_val):
         pass
 
 
+
 class SimilarityForestRegressor(SimilarityForest, metaclass=ABCMeta):
 
     @staticmethod
@@ -153,6 +98,13 @@ class SimilarityForestRegressor(SimilarityForest, metaclass=ABCMeta):
     def _activation(self, predictions):
         return predictions
 
+    @staticmethod
+    def _best_score_idx(array):
+        return np.argmin(array)
+
+    @staticmethod
+    def _worse_score_idx(array):
+        return np.argmax(array)
 
 class SimilarityForestClassifier(SimilarityForest, metaclass=ABCMeta):
 
@@ -179,3 +131,11 @@ class SimilarityForestClassifier(SimilarityForest, metaclass=ABCMeta):
         predictions_0_1 = super()._base_estimator_predictions(X)
         predictions = (predictions_0_1 - 0.5) * 2
         return predictions
+
+    @staticmethod
+    def _best_score_idx(array):
+        return np.argmax(array)
+
+    @staticmethod
+    def _worse_score_idx(array):
+        return np.argmin(array)