diff --git a/code/bolsonaro/models/kmeans_forest_regressor.py b/code/bolsonaro/models/kmeans_forest_regressor.py
index 6c0e3a501066feacdaaba0dad920a8232df870fc..591b653e16503a65fa6baf1ab83f2b014eab7093 100644
--- a/code/bolsonaro/models/kmeans_forest_regressor.py
+++ b/code/bolsonaro/models/kmeans_forest_regressor.py
@@ -13,17 +13,15 @@ from joblib import Parallel, delayed
 from tqdm import tqdm
 
 
-class KMeansForestRegressor(BaseEstimator, metaclass=ABCMeta):
+class KmeansForest(BaseEstimator, metaclass=ABCMeta):
     """
-    On extreme pruning of random forest ensembles for ral-time predictive applications', by Khaled Fawagreh, Mohamed Medhat Gaber and Eyad Elyan.
-    """
-
-    def __init__(self, models_parameters, score_metric=mean_squared_error):
+        On extreme pruning of random forest ensembles for ral-time predictive applications', by Khaled Fawagreh, Mohamed Medhat Gaber and Eyad Elyan.
+        """
+    def __init__(self, models_parameters):
         self._models_parameters = models_parameters
         self._estimator = RandomForestRegressor(**self._models_parameters.hyperparameters,
-            random_state=self._models_parameters.seed, n_jobs=2)
+                                                random_state=self._models_parameters.seed, n_jobs=2)
         self._extracted_forest_size = self._models_parameters.extracted_forest_size
-        self._score_metric = score_metric
         self._selected_trees = list()
 
     @property
@@ -37,7 +35,6 @@ class KMeansForestRegressor(BaseEstimator, metaclass=ABCMeta):
     def fit(self, X_train, y_train, X_val, y_val):
         self._estimator.fit(X_train, y_train)
 
-
         predictions_val = np.empty((len(self._estimator.estimators_), X_val.shape[0]))
         predictions = np.empty((len(self._estimator.estimators_), X_train.shape[0]))
         for i_tree, tree in enumerate(self._estimator.estimators_):
@@ -48,64 +45,84 @@ class KMeansForestRegressor(BaseEstimator, metaclass=ABCMeta):
         labels = np.array(kmeans.labels_)
 
         # start_np_version = time.time()
-        pruned_forest_1 = list()
+        lst_pruned_forest = list()
         for cluster_idx in range(self._extracted_forest_size):  # pourrait ĂȘtre parallelise
-            index_trees_cluster =  np.where(labels == cluster_idx)[0]
+            index_trees_cluster = np.where(labels == cluster_idx)[0]
             predictions_val_cluster = predictions_val[index_trees_cluster]  # get predictions of trees in cluster
-            if self._score_metric == mean_squared_error:
-                # compute the mean squared error of all trees at once usng numpy machinery
-                diff = predictions_val_cluster - y_val
-                squared_diff = diff ** 2
-                mean_squared_diff = np.mean(squared_diff, axis=1)
-
-                best_tree_index = np.argmin(mean_squared_diff) # get best scoring tree (the one with lowest mse)
-                pruned_forest_1.append(self._estimator.estimators_[index_trees_cluster[best_tree_index]])
-            else:
-                raise ValueError
-        # stop_np_version = time.time()
-        # print("Time np {}".format(stop_np_version - start_np_version))
-
-        # start_paralel_version = time.time()
-        # # For each cluster select the best tree on the validation set
-        # extracted_forest_sizes = list(range(self._extracted_forest_size))
-        # with tqdm_joblib(tqdm(total=self._extracted_forest_size, disable=True)) as prune_forest_job_pb:
-        #     pruned_forest = Parallel(n_jobs=2)(delayed(self._prune_forest_job)(prune_forest_job_pb, extracted_forest_sizes[i], labels, X_val, y_val, self._score_metric) for i in range(self._extracted_forest_size))
-        # stop_paralel_version = time.time()
-        # print("Time paralel {}".format(stop_paralel_version - start_paralel_version))
-        # assert all([t1 is t2 for (t1, t2) in zip(pruned_forest_1, pruned_forest)])
-
-        self._selected_trees = pruned_forest_1
-        self._estimator.estimators_ = pruned_forest_1
-
-    def _prune_forest_job(self, prune_forest_job_pb, c, labels, X_val, y_val, score_metric):
-        index = np.where(labels == c)[0]
-        with tqdm_joblib(tqdm(total=len(index), disable=True)) as cluster_job_pb:
-            cluster = Parallel(n_jobs=2)(delayed(self._cluster_job)(cluster_job_pb, index[i], X_val, y_val, score_metric) for i in range(len(index)))
-        best_tree_index = np.argmax(cluster)
-        prune_forest_job_pb.update()
-        return self._estimator.estimators_[index[best_tree_index]]
-
-    def _cluster_job(self, cluster_job_pb, i, X_val, y_val, score_metric):
-        y_val_pred = self._estimator.estimators_[i].predict(X_val)
-        tree_pred = score_metric(y_val, y_val_pred)
-        cluster_job_pb.update()
-        return tree_pred
+            best_tree_index = self._get_best_tree_index(predictions_val_cluster, y_val)
+            lst_pruned_forest.append(self._estimator.estimators_[index_trees_cluster[best_tree_index]])
 
-    def predict(self, X):
-        return self._estimator.predict(X)
+        self._selected_trees = lst_pruned_forest
+        self._estimator.estimators_ = lst_pruned_forest
 
     def score(self, X, y):
-        predictions = list()
-        for tree in self._estimator.estimators_:
-            predictions.append(tree.predict(X))
-        predictions = np.array(predictions)
-        mean_predictions = np.mean(predictions, axis=0)
-        score = self._score_metric(mean_predictions, y)
+        predictions = np.empty((len(self._estimator.estimators_), X.shape[0]))
+        for idx_tree, tree in enumerate(self._estimator.estimators_):
+            predictions[idx_tree, :] = tree.predict(X)
+        final_predictions = self._aggregate(predictions)
+        score = self._score_metric(final_predictions, y)[0]
         return score
 
+    def predict(self, X):
+        return self._estimator.predict(X)
+
     def predict_base_estimator(self, X):
         return self._estimator.predict(X)
 
+    @abstractmethod
+    def _score_metric(self, y_preds, y_true):
+        pass
+
+    @abstractmethod
+    def _get_best_tree_index(self, y_preds, y_true):
+        pass
+
+    @abstractmethod
+    def _aggregate(self, predictions):
+        pass
+
+class KMeansForestRegressor(KmeansForest, metaclass=ABCMeta):
+
+    def _aggregate(self, predictions):
+        return np.mean(predictions, axis=0)
+
+    def _score_metric(self, y_preds, y_true):
+        if len(y_true.shape) == 1:
+            y_true = y_true[np.newaxis, :]
+        if len(y_preds.shape) == 1:
+            y_preds = y_preds[np.newaxis, :]
+        assert y_preds.shape[1] == y_true.shape[1], "Number of examples to compare should be the same in y_preds and y_true"
+
+        diff = y_preds - y_true
+        squared_diff = diff ** 2
+        mean_squared_diff = np.mean(squared_diff, axis=1)
+        return mean_squared_diff
+
+    def _get_best_tree_index(self, y_preds, y_true):
+        score = self._score_metric(y_preds, y_true)
+        best_tree_index = np.argmin(score)  # get best scoring tree (the one with lowest mse)
+        return best_tree_index
+
+class KMeansForestClassifier(KmeansForest, metaclass=ABCMeta):
+
+    def _aggregate(self, predictions):
+        return np.sign(np.sum(predictions, axis=0))
+
+    def _score_metric(self, y_preds, y_true):
+        if len(y_true.shape) == 1:
+            y_true = y_true[np.newaxis, :]
+        if len(y_preds.shape) == 1:
+            y_preds = y_preds[np.newaxis, :]
+        assert y_preds.shape[1] == y_true.shape[1], "Number of examples to compare should be the same in y_preds and y_true"
+
+        bool_arr_correct_predictions = y_preds == y_true
+        return np.average(bool_arr_correct_predictions, axis=1)
+
+    def _get_best_tree_index(self, y_preds, y_true):
+        score = self._score_metric(y_preds, y_true)
+        best_tree_index = np.argmax(score)  # get best scoring tree (the one with lowest mse)
+        return best_tree_index
+
 if __name__ == "__main__":
     from sklearn import datasets
     from bolsonaro.models.model_parameters import ModelParameters
diff --git a/code/bolsonaro/models/model_factory.py b/code/bolsonaro/models/model_factory.py
index 335816b1dd33d28175f4865da2fddbbf73b8027d..785d65332045cd6c9db828cbe6810648e4c76d31 100644
--- a/code/bolsonaro/models/model_factory.py
+++ b/code/bolsonaro/models/model_factory.py
@@ -2,7 +2,7 @@ from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, Om
 from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
 from bolsonaro.models.model_parameters import ModelParameters
 from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor
-from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor
+from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor, KMeansForestClassifier
 from bolsonaro.models.ensemble_selection_forest_regressor import EnsembleSelectionForestRegressor
 from bolsonaro.data.task import Task
 
@@ -27,6 +27,8 @@ class ModelFactory(object):
             elif model_parameters.extraction_strategy == 'none':
                 return RandomForestClassifier(**model_parameters.hyperparameters,
                     random_state=model_parameters.seed)
+            elif model_parameters.extraction_strategy == 'kmeans':
+                return KMeansForestClassifier(model_parameters)
             else:
                 raise ValueError('Invalid extraction strategy')
         elif task == Task.REGRESSION:
diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py
index 3327bee92401c3a993383c2d8b83a0ef80c206ba..49a63c1a004fa16d9dcf43424b96b1d512b2825d 100644
--- a/code/bolsonaro/trainer.py
+++ b/code/bolsonaro/trainer.py
@@ -2,7 +2,7 @@ from bolsonaro.models.model_raw_results import ModelRawResults
 from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
 from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier
 from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor
-from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor
+from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor, KMeansForestClassifier
 from bolsonaro.models.ensemble_selection_forest_regressor import EnsembleSelectionForestRegressor
 from bolsonaro.error_handling.logger_factory import LoggerFactory
 from bolsonaro.data.task import Task
@@ -122,7 +122,7 @@ class Trainer(object):
                 y_pred = np.sign(y_pred)
                 y_pred = np.where(y_pred == 0, 1, y_pred)
             result = self._classification_score_metric(y_true, y_pred)
-        elif type(model) in [SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor]:
+        elif type(model) in [SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor, KMeansForestClassifier]:
             result = model.score(X, y_true)
         return result
 
@@ -139,6 +139,8 @@ class Trainer(object):
         elif type(model) is RandomForestRegressor:
             y_pred = model.predict(X)
             result = self._base_regression_score_metric(y_true, y_pred)
+        elif type(model) in [ KMeansForestClassifier]:
+            result = model.score(X, y_true)
         return result
 
     def compute_results(self, model, models_dir):