diff --git a/code/bolsonaro/models/ensemble_selection_forest_regressor.py b/code/bolsonaro/models/ensemble_selection_forest_regressor.py
index 9f537785bc665478934d4a00df291ae1889cd8f2..1edbab0479b4553e290eb6eafcb1f1ad84a984e9 100644
--- a/code/bolsonaro/models/ensemble_selection_forest_regressor.py
+++ b/code/bolsonaro/models/ensemble_selection_forest_regressor.py
@@ -1,3 +1,6 @@
+import time
+
+from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
 from sklearn.metrics import mean_squared_error
 from sklearn.base import BaseEstimator
 from sklearn.tree import DecisionTreeRegressor
@@ -5,91 +8,129 @@ from abc import abstractmethod, ABCMeta
 import numpy as np
 from tqdm import tqdm
 
+from bolsonaro.models.forest_pruning_sota import ForestPruningSOTA
+from bolsonaro.models.utils import score_metric_mse, aggregation_regression, aggregation_classification, score_metric_indicator
+
 
-class EnsembleSelectionForestRegressor(BaseEstimator, metaclass=ABCMeta):
+class EnsembleSelectionForest(ForestPruningSOTA, metaclass=ABCMeta):
     """
     'Ensemble selection from libraries of models' by Rich Caruana et al
     """
 
-    def __init__(self, models_parameters, library, score_metric=mean_squared_error):
-        self._models_parameters = models_parameters
-        self._library = library
-        self._extracted_forest_size = self._models_parameters.extracted_forest_size
-        self._score_metric = score_metric
-        self._selected_trees = list()
-
-    @property
-    def models_parameters(self):
-        return self._models_parameters
-
-    @property
-    def library(self):
-        return self._library
-
-    @property
-    def selected_trees(self):
-        return self._selected_trees
-
-    def fit(self, X_train, y_train, X_val, y_val):
-        scores_list = list()
-        for estimator in self._library:
-            val_score = self._score_metric(estimator.predict(X_val), y_val)
-            scores_list.append(val_score)
-
-        class_list = list(self._library)
-        m = np.argmax(np.asarray(scores_list))
-        self._selected_trees = [class_list[m]]
-        temp_pred = class_list[m].predict(X_val)
-        del class_list[m]
-        for k in range(self._extracted_forest_size - 1):
-            candidate_index = 0
-            best_score = 100000
-            for j in range(len(class_list)):
-                temp_pred = np.vstack((temp_pred, class_list[j].predict(X_val)))
-                temp_mean = np.mean(temp_pred, axis=0)
-                temp_score = self._score_metric(temp_mean, y_val)
-                if (temp_score < best_score):
-                    candidate_index = j
-                    best_score = temp_score
-                temp_pred = np.delete(temp_pred, -1, 0)
-            self._selected_trees.append(class_list[candidate_index])
-            temp_pred = np.vstack((temp_pred, class_list[candidate_index].predict(X_val)))
-            del class_list[candidate_index]
-
-    def score(self, X, y):
-        predictions = self.predict_base_estimator(X)
-        return self._score_metric(predictions, y)
-
-    def predict_base_estimator(self, X):
-        predictions = list()
-        for tree in self._selected_trees:
-            predictions.append(tree.predict(X))
-        mean_predictions = np.mean(np.array(predictions), axis=0)
-        return mean_predictions
+    def _fit(self, X_train, y_train, X_val, y_val):
+        self._base_estimator.fit(X_train, y_train)
+
+        val_predictions = self._base_estimator_predictions(X_val).T
+        scores_predictions_val = self._score_metric(val_predictions, y_val)
+        idx_best_score = self._best_score_idx(scores_predictions_val)
+
+        lst_pruned_forest = [self._base_estimator.estimators_[idx_best_score]]
+
+        nb_selected_trees = 1
+        mean_so_far = val_predictions[idx_best_score]
+        while nb_selected_trees < self._extracted_forest_size:
+            # every new tree is selected with replacement as specified in the base paper
+            # mean update formula: u_{t+1} = (n_t * u_t + x_t) / (n_t + 1)
+            mean_prediction_subset_with_extra_tree = (nb_selected_trees * mean_so_far + val_predictions) / (nb_selected_trees + 1)
+            predictions_subset_with_extra_tree = self._activation(mean_prediction_subset_with_extra_tree)
+            scores_subset_with_extra_tree = self._score_metric(predictions_subset_with_extra_tree, y_val)
+            idx_best_extra_tree = self._best_score_idx(scores_subset_with_extra_tree)
+            lst_pruned_forest.append(self._base_estimator.estimators_[idx_best_extra_tree])
+
+            mean_so_far = mean_prediction_subset_with_extra_tree[idx_best_extra_tree]
+            nb_selected_trees += 1
+
+        return lst_pruned_forest
+
+
+    @abstractmethod
+    def _activation(self, leave_one_tree_out_predictions_val):
+        pass
+
+
+class EnsembleSelectionForestClassifier(EnsembleSelectionForest, metaclass=ABCMeta):
+    @staticmethod
+    def init_estimator(model_parameters):
+        return RandomForestClassifier(**model_parameters.hyperparameters,
+                                    random_state=model_parameters.seed, n_jobs=2)
+
+    def _aggregate(self, predictions):
+        return aggregation_classification(predictions)
+
+    def _score_metric(self, y_preds, y_true):
+        return score_metric_indicator(y_preds, y_true)
+
+    def _activation(self, predictions):
+        return np.sign(predictions)
+
+    def _selected_tree_predictions(self, X):
+        predictions_0_1 = super()._selected_tree_predictions(X)
+        predictions = (predictions_0_1 - 0.5) * 2
+        return predictions
+
+    def _base_estimator_predictions(self, X):
+        predictions_0_1 = super()._base_estimator_predictions(X)
+        predictions = (predictions_0_1 - 0.5) * 2
+        return predictions
 
     @staticmethod
-    def generate_library(X_train, y_train, random_state=None):
-        criterion_arr = ["mse"]#, "friedman_mse", "mae"]
-        splitter_arr = ["best"]#, "random"]
-        depth_arr = [i for i in range(5, 20, 1)]
-        min_samples_split_arr = [i for i in range(2, 20, 1)]
-        min_samples_leaf_arr = [i for i in range(2, 20, 1)]
-        max_features_arr = ["sqrt"]#["auto", "sqrt", "log2"]
-
-        library = list()
-        with tqdm(total=len(criterion_arr) * len(splitter_arr) * \
-            len(depth_arr) * len(min_samples_split_arr) * len(min_samples_leaf_arr) * \
-            len(max_features_arr)) as bar:
-            bar.set_description('Generating library')
-            for criterion in criterion_arr:
-                for splitter in splitter_arr:
-                    for depth in depth_arr:
-                        for min_samples_split in min_samples_split_arr:
-                            for min_samples_leaf in min_samples_leaf_arr:
-                                for max_features in max_features_arr:
-                                    t = DecisionTreeRegressor(criterion=criterion, splitter=splitter, max_depth=depth, min_samples_split=min_samples_split,
-                                        min_samples_leaf=min_samples_leaf, max_features=max_features, random_state=random_state)
-                                    t.fit(X_train, y_train)
-                                    library.append(t)
-                                    bar.update(1)
-        return library
+    def _best_score_idx(array):
+        return np.argmax(array)
+
+    @staticmethod
+    def _worse_score_idx(array):
+        return np.argmin(array)
+
+
+class EnsembleSelectionForestRegressor(EnsembleSelectionForest, metaclass=ABCMeta):
+
+    @staticmethod
+    def init_estimator(model_parameters):
+        return RandomForestRegressor(**model_parameters.hyperparameters,
+                              random_state=model_parameters.seed, n_jobs=2)
+
+    def _aggregate(self, predictions):
+        return aggregation_regression(predictions)
+
+    def _score_metric(self, y_preds, y_true):
+        return score_metric_mse(y_preds, y_true)
+
+    def _activation(self, predictions):
+        return predictions
+
+    @staticmethod
+    def _best_score_idx(array):
+        return np.argmin(array)
+
+    @staticmethod
+    def _worse_score_idx(array):
+        return np.argmax(array)
+
+
+
+    # @staticmethod
+    # def generate_library(X_train, y_train, random_state=None):
+    #     criterion_arr = ["mse"]#, "friedman_mse", "mae"]
+    #     splitter_arr = ["best"]#, "random"]
+    #     depth_arr = [i for i in range(5, 20, 1)]
+    #     min_samples_split_arr = [i for i in range(2, 20, 1)]
+    #     min_samples_leaf_arr = [i for i in range(2, 20, 1)]
+    #     max_features_arr = ["sqrt"]#["auto", "sqrt", "log2"]
+    #
+    #     library = list()
+    #     with tqdm(total=len(criterion_arr) * len(splitter_arr) * \
+    #         len(depth_arr) * len(min_samples_split_arr) * len(min_samples_leaf_arr) * \
+    #         len(max_features_arr)) as bar:
+    #         bar.set_description('Generating library')
+    #         for criterion in criterion_arr:
+    #             for splitter in splitter_arr:
+    #                 for depth in depth_arr:
+    #                     for min_samples_split in min_samples_split_arr:
+    #                         for min_samples_leaf in min_samples_leaf_arr:
+    #                             for max_features in max_features_arr:
+    #                                 t = DecisionTreeRegressor(criterion=criterion, splitter=splitter, max_depth=depth, min_samples_split=min_samples_split,
+    #                                     min_samples_leaf=min_samples_leaf, max_features=max_features, random_state=random_state)
+    #                                 t.fit(X_train, y_train)
+    #                                 library.append(t)
+    #                                 bar.update(1)
+    #     return library
diff --git a/code/bolsonaro/models/model_factory.py b/code/bolsonaro/models/model_factory.py
index d11af3b09b2538557f140d885c5f88ee1c8c97e7..4a70a1cf936ce3511dcedb5c2cd9aada0675523c 100644
--- a/code/bolsonaro/models/model_factory.py
+++ b/code/bolsonaro/models/model_factory.py
@@ -3,7 +3,7 @@ from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
 from bolsonaro.models.model_parameters import ModelParameters
 from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor, SimilarityForestClassifier
 from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor, KMeansForestClassifier
-from bolsonaro.models.ensemble_selection_forest_regressor import EnsembleSelectionForestRegressor
+from bolsonaro.models.ensemble_selection_forest_regressor import EnsembleSelectionForestRegressor, EnsembleSelectionForestClassifier
 from bolsonaro.data.task import Task
 
 from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
@@ -27,6 +27,8 @@ class ModelFactory(object):
             elif model_parameters.extraction_strategy == 'none':
                 return RandomForestClassifier(**model_parameters.hyperparameters,
                     random_state=model_parameters.seed)
+            elif model_parameters.extraction_strategy == 'ensemble':
+                return EnsembleSelectionForestClassifier(model_parameters)
             elif model_parameters.extraction_strategy == 'kmeans':
                 return KMeansForestClassifier(model_parameters)
             elif model_parameters.extraction_strategy in ['similarity_similarities', 'similarity_predictions']:
@@ -44,7 +46,7 @@ class ModelFactory(object):
             elif model_parameters.extraction_strategy == 'kmeans':
                 return KMeansForestRegressor(model_parameters)
             elif model_parameters.extraction_strategy == 'ensemble':
-                return EnsembleSelectionForestRegressor(model_parameters, library=library)
+                return EnsembleSelectionForestRegressor(model_parameters)
             elif model_parameters.extraction_strategy == 'none':
                 return RandomForestRegressor(**model_parameters.hyperparameters,
                     random_state=model_parameters.seed)
diff --git a/code/train.py b/code/train.py
index e7a319de1b1dcf87c51bb96b537d4d2df80499fb..07e3a74c0843402af29ff3346a2b5e4f3cf40985 100644
--- a/code/train.py
+++ b/code/train.py
@@ -55,7 +55,8 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb
 
     trainer = Trainer(dataset)
 
-    if parameters['extraction_strategy'] == 'ensemble':
+    # if parameters['extraction_strategy'] == 'ensemble':
+    if False:
         library = EnsembleSelectionForestRegressor.generate_library(dataset.X_train, dataset.y_train, random_state=seed)
     else:
         library = None