diff --git a/code/bolsonaro/models/kmeans_forest_regressor.py b/code/bolsonaro/models/kmeans_forest_regressor.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1a3dee940844a1e48a5fbd5df416bdea6eae903
--- /dev/null
+++ b/code/bolsonaro/models/kmeans_forest_regressor.py
@@ -0,0 +1,78 @@
+from bolsonaro.utils import tqdm_joblib
+
+from sklearn.ensemble import RandomForestRegressor
+from sklearn.metrics import mean_squared_error
+from sklearn.base import BaseEstimator
+from sklearn.cluster import KMeans
+from abc import ABCMeta
+import numpy as np
+from joblib import Parallel, delayed
+from tqdm import tqdm
+
+
+class KMeansForestRegressor(BaseEstimator, metaclass=ABCMeta):
+    """
+    'On Extreme Pruning of Random Forest Ensembles for Real-Time Predictive Applications', by Khaled Fawagreh, Mohamed Medhat Gaber and Eyad Elyan.
+    """
+
+    def __init__(self, models_parameters, score_metric=mean_squared_error):
+        self._models_parameters = models_parameters
+        self._estimator = RandomForestRegressor(**self._models_parameters.hyperparameters,
+            random_state=self._models_parameters.seed, n_jobs=-1)
+        self._extracted_forest_size = self._models_parameters.extracted_forest_size
+        self._score_metric = score_metric
+
+    @property
+    def models_parameters(self):
+        return self._models_parameters
+
+    def fit(self, X_train, y_train, X_val, y_val):
+        self._estimator.fit(X_train, y_train)
+
+        predictions = list()
+        for tree in self._estimator.estimators_:
+            predictions.append(tree.predict(X_train))
+        predictions = np.array(predictions)
+
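+        # Each tree is represented by its vector of predictions on the training set,
+        # so k-means groups trees that behave similarly.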
+        kmeans = KMeans(n_clusters=self._extracted_forest_size, random_state=self._models_parameters.seed).fit(predictions)
+        labels = np.array(kmeans.labels_)
+
+        # For each cluster, select the best tree on the validation set
+        with tqdm_joblib(tqdm(total=self._extracted_forest_size, disable=True)) as prune_forest_job_pb:
+            pruned_forest = Parallel(n_jobs=-1)(delayed(self._prune_forest_job)(prune_forest_job_pb,
+                cluster_label, labels, X_val, y_val, self._score_metric)
+                for cluster_label in range(self._extracted_forest_size))
+
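+        # The pruned forest keeps exactly one representative tree per cluster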
+        self._estimator.estimators_ = pruned_forest
+
+    def _prune_forest_job(self, prune_forest_job_pb, c, labels, X_val, y_val, score_metric):
+        index = np.where(labels == c)[0]
+        with tqdm_joblib(tqdm(total=len(index), disable=True)) as cluster_job_pb:
+            cluster = Parallel(n_jobs=-1)(delayed(self._cluster_job)(cluster_job_pb, index[i], X_val, 
+                y_val, score_metric) for i in range(len(index)))
+        # The score metric is a loss (lower is better, e.g. mean_squared_error), so take the argmin
+        best_tree_index = np.argmin(cluster)
+        prune_forest_job_pb.update()
+        return self._estimator.estimators_[index[best_tree_index]]
+
+    def _cluster_job(self, cluster_job_pb, i, X_val, y_val, score_metric):
+        y_val_pred = self._estimator.estimators_[i].predict(X_val)
+        tree_score = score_metric(y_val, y_val_pred)
+        cluster_job_pb.update()
+        return tree_score
+
+    def predict(self, X):
+        return self._estimator.predict(X)
+
+    def score(self, X, y):
+        predictions = list()
+        for tree in self._estimator.estimators_:
+            predictions.append(tree.predict(X))
+        predictions = np.array(predictions)
+        mean_predictions = np.mean(predictions, axis=0)
+        score = self._score_metric(y, mean_predictions)
+        return score
+
+    def predict_base_estimator(self, X):
+        return self._estimator.predict(X)
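+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (not part of the original class). SimpleNamespace merely
+    # stands in for the project's ModelParameters: any object exposing
+    # `hyperparameters`, `seed` and `extracted_forest_size` works here.
+    from types import SimpleNamespace
+    from sklearn.datasets import fetch_california_housing
+    from sklearn.model_selection import train_test_split
+
+    X, y = fetch_california_housing(return_X_y=True)
+    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.25, random_state=0)
+
+    params = SimpleNamespace(
+        hyperparameters={'n_estimators': 100, 'max_depth': 7},
+        seed=0,
+        extracted_forest_size=10)
+
+    model = KMeansForestRegressor(params)
+    model.fit(X_train, y_train, X_val, y_val)
+    # score() averages the per-tree predictions and applies the score metric (MSE by default)
+    print('pruned forest MSE:', model.score(X_val, y_val))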
diff --git a/code/bolsonaro/models/model_factory.py b/code/bolsonaro/models/model_factory.py
index 74993cc0a30b754595a490de40d69e064687bc24..bbda6cae89d218c7831780f71b9fc6a7bc022d54 100644
--- a/code/bolsonaro/models/model_factory.py
+++ b/code/bolsonaro/models/model_factory.py
@@ -2,6 +2,7 @@ from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, Om
 from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
 from bolsonaro.models.model_parameters import ModelParameters
 from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor
+from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor
 from bolsonaro.data.task import Task
 
 from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
@@ -22,9 +23,11 @@ class ModelFactory(object):
             elif model_parameters.extraction_strategy == 'random':
                 return RandomForestClassifier(n_estimators=model_parameters.extracted_forest_size,
                     random_state=model_parameters.seed)
-            else:
+            elif model_parameters.extraction_strategy == 'none':
                 return RandomForestClassifier(n_estimators=model_parameters.hyperparameters['n_estimators'],
                     random_state=model_parameters.seed)
+            else:
+                raise ValueError('Invalid extraction strategy: {}'.format(model_parameters.extraction_strategy))
         elif task == Task.REGRESSION:
             if model_parameters.extraction_strategy == 'omp':
                 return OmpForestRegressor(model_parameters)
@@ -33,15 +36,21 @@ class ModelFactory(object):
                     random_state=model_parameters.seed)
             elif model_parameters.extraction_strategy == 'similarity':
                 return SimilarityForestRegressor(model_parameters)
-            else:
+            elif model_parameters.extraction_strategy == 'kmeans':
+                return KMeansForestRegressor(model_parameters)
+            elif model_parameters.extraction_strategy == 'none':
                 return RandomForestRegressor(n_estimators=model_parameters.hyperparameters['n_estimators'],
                     random_state=model_parameters.seed)
+            else:
+                raise ValueError('Invalid extraction strategy: {}'.format(model_parameters.extraction_strategy))
         elif task == Task.MULTICLASSIFICATION:
             if model_parameters.extraction_strategy == 'omp':
                 return OmpForestMulticlassClassifier(model_parameters)
             elif model_parameters.extraction_strategy == 'random':
                 return RandomForestClassifier(n_estimators=model_parameters.extracted_forest_size,
                     random_state=model_parameters.seed)
-            else:
+            elif model_parameters.extraction_strategy == 'none':
                 return RandomForestClassifier(n_estimators=model_parameters.hyperparameters['n_estimators'],
                     random_state=model_parameters.seed)
+            else:
+                raise ValueError('Invalid extraction strategy: {}'.format(model_parameters.extraction_strategy))
diff --git a/code/bolsonaro/models/similarity_forest_regressor.py b/code/bolsonaro/models/similarity_forest_regressor.py
index f8d9c3ed349cf8c9e27acbcd7982694a65e11636..647e8695da88c0f84817a602471fd90f9bd1f1b0 100644
--- a/code/bolsonaro/models/similarity_forest_regressor.py
+++ b/code/bolsonaro/models/similarity_forest_regressor.py
@@ -3,6 +3,7 @@ from sklearn.metrics import mean_squared_error
 from sklearn.base import BaseEstimator
 from abc import abstractmethod, ABCMeta
 import numpy as np
+from tqdm import tqdm
 
 
 class SimilarityForestRegressor(BaseEstimator, metaclass=ABCMeta):
@@ -10,56 +11,69 @@ class SimilarityForestRegressor(BaseEstimator, metaclass=ABCMeta):
     https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2822360/
     """
 
-    def __init__(self, models_parameters):
+    def __init__(self, models_parameters, score_metric=mean_squared_error):
         self._models_parameters = models_parameters
-        self._regressor = RandomForestRegressor(n_estimators=self._models_parameters.hyperparameters['n_estimators'],
-            random_state=models_parameters.seed)
+        self._estimator = RandomForestRegressor(**self._models_parameters.hyperparameters,
+            random_state=self._models_parameters.seed, n_jobs=-1)
         self._extracted_forest_size = self._models_parameters.extracted_forest_size
+        self._score_metric = score_metric
 
     @property
     def models_parameters(self):
         return self._models_parameters
 
-    def fit(self, X_train, y_train, X_val, y_val, score_metric=mean_squared_error):
+    def fit(self, X_train, y_train, X_val, y_val):
+        self._estimator.fit(X_train, y_train)
 
-        self._regressor.fit(X_train, y_train)
-
-        y_val_pred = self._regressor.predict(X_val)
-        forest_pred = score_metric(y_val, y_val_pred)
-        forest = self._regressor.estimators_
+        y_val_pred = self._estimator.predict(X_val)
+        forest_pred = self._score_metric(y_val, y_val_pred)
+        forest = self._estimator.estimators_
         selected_trees = list()
-        tree_list = list(self._regressor.estimators_)
+        tree_list = list(self._estimator.estimators_)
+
+        val_scores = list()
+        with tqdm(tree_list) as tree_pred_bar:
+            tree_pred_bar.set_description('[Initial tree predictions]')
+            for tree in tree_pred_bar:
+                val_scores.append(tree.predict(X_val))
 
-        for _ in range(self._extracted_forest_size):
-            best_similarity = 100000
-            found_index = 0
-            for i in range(len(tree_list)):
-                lonely_tree = tree_list[i]
-                del tree_list[i]
-                val_list = list()
-                for tree in tree_list:
-                    val_pred = tree.predict(X_val)
-                    val_list.append(val_pred)
-                val_list = np.array(val_list)
-                val_mean = np.mean(val_list, axis=0)
-                val_score = score_metric(val_mean, y_val)
-                temp_similarity = abs(forest_pred - val_score)
-                if (temp_similarity < best_similarity):
-                    found_index = i
-                    best_similarity = temp_similarity
-                tree_list.insert(i, lonely_tree)
-            selected_trees.append(tree_list[found_index])
-            del tree_list[found_index]
+        # Backward elimination: repeatedly drop the tree whose removal keeps the
+        # remaining ensemble's validation score closest to the full forest's score,
+        # until only extracted_forest_size trees remain.
+        nb_trees_to_remove = len(tree_list) - self._extracted_forest_size
+        with tqdm(range(nb_trees_to_remove), disable=True) as pruning_forest_bar:
+            pruning_forest_bar.set_description(f'[Pruning forest s={self._extracted_forest_size}]')
+            for i in pruning_forest_bar:
+                best_similarity = float('inf')
+                found_index = 0
+                with tqdm(range(len(tree_list)), disable=True) as tree_list_bar:
+                    tree_list_bar.set_description(f'[Tree selection s={self._extracted_forest_size} #{i}]')
+                    for j in tree_list_bar:
+                        # Temporarily remove tree j (and its cached predictions),
+                        # then score the remaining ensemble on the validation set
+                        lonely_tree = tree_list[j]
+                        lonely_score = val_scores[j]
+                        del tree_list[j]
+                        del val_scores[j]
+                        val_mean = np.mean(np.asarray(val_scores), axis=0)
+                        val_score = self._score_metric(y_val, val_mean)
+                        temp_similarity = abs(forest_pred - val_score)
+                        if temp_similarity < best_similarity:
+                            found_index = j
+                            best_similarity = temp_similarity
+                        tree_list.insert(j, lonely_tree)
+                        val_scores.insert(j, lonely_score)
+                        tree_list_bar.update(1)
+                selected_trees.append(tree_list[found_index])
+                del tree_list[found_index]
+                del val_scores[found_index]
+                pruning_forest_bar.update(1)
 
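+        # selected_trees holds the trees chosen for removal; the remaining set is the pruned forest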
         pruned_forest = list(set(forest) - set(selected_trees))
-        self._regressor.estimators_ = pruned_forest
+        self._estimator.estimators_ = pruned_forest
 
     def score(self, X, y):
         test_list = list()
-        for mod in self._regressor.estimators_:
+        for mod in self._estimator.estimators_:
             test_pred = mod.predict(X)
             test_list.append(test_pred)
         test_list = np.array(test_list)
         test_mean = np.mean(test_list, axis=0)
-        score = mean_squared_error(test_mean, y)
+        score = self._score_metric(y, test_mean)
         return score
+
+    def predict_base_estimator(self, X):
+        return self._estimator.predict(X)
diff --git a/code/compute_results.py b/code/compute_results.py
index f15a7ff80249c538f2a408b564965de125b21cc4..5f7fac2c7718cf887d3d83a5b3a7eb9cdebfb9d9 100644
--- a/code/compute_results.py
+++ b/code/compute_results.py
@@ -400,23 +400,51 @@ if __name__ == "__main__":
             xlabel='Number of trees extracted',
             ylabel=experiments_score_metric,
             title='Loss values of {}\nusing best params of previous stages'.format(args.dataset_name))
+    elif args.stage == 5:
+        # Retrieve the number of extracted forest sizes used, so that the base forest axis is as long as necessary
+        extracted_forest_sizes_number = retreive_extracted_forest_sizes_number(args.models_dir, args.experiment_ids[1])
+
+        # base_with_params
+        logger.info('Loading base_with_params experiment scores...')
+        base_with_params_train_scores, base_with_params_dev_scores, base_with_params_test_scores, \
+            base_with_params_experiment_score_metric = \
+            extract_scores_across_seeds_and_forest_size(args.models_dir, args.results_dir, args.experiment_ids[0],
+            extracted_forest_sizes_number)
+        # random_with_params
+        logger.info('Loading random_with_params experiment scores...')
+        random_with_params_train_scores, random_with_params_dev_scores, random_with_params_test_scores, \
+            with_params_extracted_forest_sizes, random_with_params_experiment_score_metric = \
+            extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, args.experiment_ids[1])
+        # omp_with_params
+        logger.info('Loading omp_with_params experiment scores...')
+        omp_with_params_train_scores, omp_with_params_dev_scores, omp_with_params_test_scores, _, \
+            omp_with_params_experiment_score_metric = extract_scores_across_seeds_and_extracted_forest_sizes(
+                args.models_dir, args.results_dir, args.experiment_ids[2])
+        # kmeans_with_params
+        logger.info('Loading kmeans_with_params experiment scores...')
+        kmeans_with_params_train_scores, kmeans_with_params_dev_scores, kmeans_with_params_test_scores, _, \
+            kmeans_with_params_experiment_score_metric = extract_scores_across_seeds_and_extracted_forest_sizes(
+                args.models_dir, args.results_dir, args.experiment_ids[3])
+
+        # Sanity check on the metrics retrieved
+        if not (base_with_params_experiment_score_metric == random_with_params_experiment_score_metric
+            == omp_with_params_experiment_score_metric == kmeans_with_params_experiment_score_metric):
+            raise ValueError('Score metrics of all experiments must be the same.')
+        experiments_score_metric = base_with_params_experiment_score_metric
+
+        output_path = os.path.join(args.results_dir, args.dataset_name, 'stage5_kmeans')
+        pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)
 
-        # experiment_weights
-        #Plotter.weight_density(experiment_weights, output_path + os.sep + 'weight_density.png')
+        Plotter.plot_stage2_losses(
+            file_path=output_path + os.sep + 'losses.png',
+            all_experiment_scores=[base_with_params_test_scores, random_with_params_test_scores, omp_with_params_test_scores,
+                kmeans_with_params_test_scores],
+            all_labels=['base', 'random', 'omp', 'kmeans'],
+            x_value=with_params_extracted_forest_sizes,
+            xlabel='Number of trees extracted',
+            ylabel=experiments_score_metric,
+            title='Loss values of {}\nusing best params of previous stages'.format(args.dataset_name))
     else:
         raise ValueError('This stage number is not supported yet, but it will be!')
 
     logger.info('Done.')
-
-    """
-    TODO:
-    For each dataset:
-    Stage 1) [DONE for california_housing] A figure for the selection of the best base forest model hyperparameters (best vs default/random hyperparams)
-    Stage 2) [DONE for california_housing] A figure for the selection of the best combination of normalization: D normalization vs weights normalization (4 combinations)
-    Stage 3) [DONE for california_housing] A figure for the selection of the most relevant subsets combination: train,dev vs train+dev,train+dev vs train,train+dev
-    Stage 4) A figure to finally compare the perf of our approach using the previous selected
-        parameters vs the baseline vs other papers using different extracted forest size
-        (percentage of the tree size found previously in best hyperparams search) on the abscissa.
-
-    IMPORTANT: Compare experiments that used the same seeds among them (except for stage 1).
-    """
diff --git a/code/ensemble_selection.py b/code/ensemble_selection.py
new file mode 100644
index 0000000000000000000000000000000000000000..a09a85e9f726e1dbe1df962da53416bbbb0237e3
--- /dev/null
+++ b/code/ensemble_selection.py
@@ -0,0 +1,161 @@
+# Implementation of the paper 'Ensemble selection from libraries of models' by Rich Caruana et al.
+# A library of trees is trained, then trees are greedily added to the ensemble
+# according to the validation (dev set) performance of the resulting ensemble.
+
+
+from sklearn.datasets import fetch_california_housing
+from sklearn.model_selection import train_test_split
+from sklearn.tree import DecisionTreeRegressor
+import numpy as np
+from sklearn.metrics import r2_score
+import matplotlib.pyplot as plt
+
+(data, target) = fetch_california_housing(return_X_y=True)
+X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=10000, random_state=2019)
+X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=3000, random_state=2019)
+
+criterion_arr = ["mse"]#, "friedman_mse", "mae"]
+splitter_arr = ["best"]#, "random"]
+depth_arr = [i for i in range(5, 20, 1)]
+min_samples_split_arr = [i for i in range(2, 20, 1)]
+min_samples_leaf_arr = [i for i in range(2, 20, 1)]
+max_features_arr = ["sqrt"]#["auto", "sqrt", "log2"]
+
+library = list()
+
+for criterion in criterion_arr:
+    for splitter in splitter_arr:
+        for depth in depth_arr:
+            for min_samples_split in min_samples_split_arr:
+                for min_samples_leaf in min_samples_leaf_arr:
+                    for max_features in max_features_arr:
+                        t = DecisionTreeRegressor(criterion=criterion, splitter=splitter, max_depth=depth, min_samples_split=min_samples_split,
+                        min_samples_leaf=min_samples_leaf, max_features=max_features, random_state=2017)
+                        t.fit(X_train, y_train)
+                        #filename= "t_{}_{}_{}_{}_{}_{}.sav".format(criterion, splitter, depth, min_sample_split, min_sample_leaf, max_features)
+                        library.append(t)
+
+print("classifiers", len(library))
+
+scores_list = list()
+
+for classif in library:
+    r2 = classif.score(X_val, y_val)
+    scores_list.append(r2)
+    
+print("scores", len(scores_list))
+#print(scores_list)
+
+##########################
+
+np_scores_list = np.array(scores_list)
+#sort_ind = np.argsort(np_scores_list)[::-1]
+#sorted_scores = [scores_list[i] for i in sort_ind]
+#sorted_class = [class_list[i] for i in sort_ind]
+
+#print(sorted_class)
+#print(sorted_scores)
+
+#res = list()
+#for s in [10, 20, 30]:
+#    best_class = sorted_class[:s]
+#    temp_res = list()
+#    for r in best_class:
+#        r2 = r.score(X_test, y_test)
+#        temp_res.append(r2)
+#    res.append(np.mean(temp_res))
+    
+#print("scores on test set", res)
+    
+
+###########################
+
+trees_in_forest = list()
+perf_prun_forest = list()
+
+for num_sel_tree in [2, 4, 6, 8, 10, 15, 20, 30, 40, 50]:
+    class_list = list(library)
+    print("class list", len(class_list))
+    m = np.argmax(np_scores_list)
+    ens_sel = [class_list[m]]
+    #scores_sel = [scores_list[m]]
+    #del scores_list[m]
+    temp_pred = class_list[m].predict(X_val)
+    del class_list[m]
+    #print("prima di entrare nel for", len(class_list))  
+    for k in range(num_sel_tree-1):
+        cand_index = 0
+        r2_best = -10000
+        #print("ad ogni loop", len(class_list))
+        for j in range(len(class_list)):
+            temp_pred = np.vstack((temp_pred, class_list[j].predict(X_val)))
+            temp_mean = np.mean(temp_pred, axis=0)
+            #print("temp pred and temp  mean shapes", temp_pred.shape, temp_mean.shape)
+            r2_temp = r2_score(y_val, temp_mean)
+            if (r2_temp > r2_best):
+                r2_best = r2_temp
+                cand_index = j
+            temp_pred = np.delete(temp_pred, -1, 0)
+            #print(temp_pred.shape)
+        ens_sel.append(class_list[cand_index])
+        #scores_sel.append(scores_list[cand_index])
+        temp_pred = np.vstack((temp_pred, class_list[cand_index].predict(X_val)))
+        #del scores_list[cand_index]
+        del class_list[cand_index]
+        
+    #print("ens_sel", len(ens_sel))
+    test_list = list()
+    for mod in ens_sel:
+        test_pred = mod.predict(X_test)
+        test_list.append(test_pred)
+        #print("scores sep", mod.score(X_test, y_test))
+        
+    test_list = np.array(test_list)
+    #print("test list shape", test_list.shape)
+    test_mean = np.mean(test_list, axis=0)
+    #print("test list shape", test_mean.shape)
+    r2_test = r2_score(y_test, test_mean)
+    #print(r2_test)
+    #print(ens_sel[0].score(X_test, y_test), ens_sel[1].score(X_test, y_test))
+    print(num_sel_tree, r2_test)
+    trees_in_forest.append(num_sel_tree)
+    perf_prun_forest.append(r2_test)
+
+
+print(trees_in_forest)
+print(perf_prun_forest)
+ax = plt.gca()
+ax.plot(trees_in_forest, perf_prun_forest, label='ensemble selection')
+ax.legend()
+#plt.title('fashion mnist')
+plt.xlabel('num trees')
+plt.ylabel('r2 score')
+plt.savefig("ensemble_selection.pdf") 
+plt.show()
diff --git a/code/forest_similarity.py b/code/forest_similarity.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f772a93109f26b23c889f4e2eb7b021ae85b3d0
--- /dev/null
+++ b/code/forest_similarity.py
@@ -0,0 +1,85 @@
+from sklearn.datasets import fetch_california_housing
+from sklearn.model_selection import train_test_split
+from sklearn.tree import DecisionTreeRegressor
+import numpy as np
+from sklearn.metrics import r2_score
+from sklearn.ensemble import RandomForestRegressor
+import matplotlib.pyplot as plt
+
+(data, target) = fetch_california_housing(return_X_y=True)
+X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=10000, random_state=2019)
+X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=3000, random_state=2019)
+
+num_trees = 100
+prun_for_size=[2, 4, 6, 8, 10, 12, 15, 20]
+
+randfor = RandomForestRegressor(num_trees, max_depth=7, random_state=2019)
+randfor.fit(X_train, y_train)
+randfor_pred = randfor.score(X_val, y_val)
+
+trees_forest = randfor.estimators_
+
+trees_in_forest = list()
+perf_prun_forest = list()
+
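+# Similarity-based backward elimination: repeatedly drop the tree whose removal keeps
+# the remaining forest's validation R^2 closest to the full forest's score, until only
+# prun_for_size[k] trees remain.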
+for k in range(len(prun_for_size)):
+    ens_sel = list()
+    trees_list = list(randfor.estimators_)
+    #print("dovrebbe essere la taglia iniziale", len(trees_list))
+    for j in range(num_trees - prun_for_size[k]):
+        best_simil = 100000
+        cand_ind = 0
+        for i in range(len(trees_list)):
+            lonely_tree = trees_list[i]
+            del trees_list[i]
+            val_list = list()
+            #print("quando poto", len(trees_list))
+            for tree in trees_list:
+                val_pred = tree.predict(X_val)
+                val_list.append(val_pred)
+            val_list = np.array(val_list)
+            val_mean = np.mean(val_list, axis=0)
+            r2_val = r2_score(y_val, val_mean)
+            temp_simil = abs(randfor_pred - r2_val)
+            if (temp_simil < best_simil):
+                cand_ind = i
+                best_simil = temp_simil
+            trees_list.insert(i, lonely_tree)
+            #print("quando innesto", len(trees_list))
+        ens_sel.append(trees_list[cand_ind])
+        del trees_list[cand_ind]
+
+    prun_for = list(set(trees_forest) - set(ens_sel))
+    print("prun_for", len(prun_for))
+    print("trees forest", len(trees_forest))
+    print("ens_sel", len(ens_sel))
+
+    test_list = list()
+    for mod in prun_for:
+        test_pred = mod.predict(X_test)
+        test_list.append(test_pred)
+        #print("scores sep", mod.score(X_test, y_test))
+            
+    test_list = np.array(test_list)
+    #print("test list shape", test_list.shape)
+    test_mean = np.mean(test_list, axis=0)
+    #print("test list shape", test_mean.shape)
+    r2_test = r2_score(y_test, test_mean)
+    #print(r2_test)
+    #print(ens_sel[0].score(X_test, y_test), ens_sel[1].score(X_test, y_test))
+    print(len(prun_for), r2_test)
+    trees_in_forest.append(len(prun_for))
+    perf_prun_forest.append(r2_test)
+
+
+print(trees_in_forest)
+print(perf_prun_forest)
+ax = plt.gca()
+ax.plot(trees_in_forest, perf_prun_forest, label='pruned forest')
+ax.legend()
+#plt.title('fashion mnist')
+plt.xlabel('num trees')
+plt.ylabel('r2 score')
+plt.savefig("pruned_forest.pdf") 
+plt.show()
diff --git a/code/train.py b/code/train.py
index e51514cc254ee564993243a676b05d07e3aa7597..0e438cf89f7a187b274ab8c3d0a2f352de2a9e8a 100644
--- a/code/train.py
+++ b/code/train.py
@@ -21,7 +21,7 @@ import numpy as np
 import shutil
 
 
-def process_job(seed, parameters, experiment_id, hyperparameters):
+def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verbose):
     """
     Experiment function.
 
@@ -34,7 +34,6 @@ def process_job(seed, parameters, experiment_id, hyperparameters):
     """
     logger = LoggerFactory.create(LOG_PATH, 'training_seed{}_ti{}'.format(
         seed, threading.get_ident()))
-    logger.info('seed={}'.format(seed))
 
     seed_str = str(seed)
     experiment_id_str = str(experiment_id)
@@ -55,13 +54,31 @@ def process_job(seed, parameters, experiment_id, hyperparameters):
     trainer = Trainer(dataset)
 
     if parameters['extraction_strategy'] != 'none':
-        for extracted_forest_size in parameters['extracted_forest_size']:
-            logger.info('extracted_forest_size={}'.format(extracted_forest_size))
-            sub_models_dir = models_dir + os.sep + 'extracted_forest_sizes' + os.sep + str(extracted_forest_size)
-            pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True)
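+        # tqdm_joblib (from bolsonaro.utils) bridges joblib and tqdm so the progress
+        # bar advances as the parallel jobs complete.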
+        with tqdm_joblib(tqdm(total=len(parameters['extracted_forest_size']), disable=not verbose)) as extracted_forest_size_job_pb:
+            Parallel(n_jobs=-1)(delayed(extracted_forest_size_job)(extracted_forest_size_job_pb, parameters['extracted_forest_size'][i],
+                models_dir, seed, parameters, dataset, hyperparameters, experiment_id, trainer)
+                for i in range(len(parameters['extracted_forest_size'])))
+    else:
+        forest_size = hyperparameters['n_estimators']
+        logger.info('Base forest training with fixed forest size of {}'.format(forest_size))
+        sub_models_dir = models_dir + os.sep + 'forest_size' + os.sep + str(forest_size)
 
+        # Check if the result file already exists
+        already_exists = False
+        if os.path.isdir(sub_models_dir):
+            sub_models_dir_files = os.listdir(sub_models_dir)
+            for file_name in sub_models_dir_files:
+                if '.pickle' != os.path.splitext(file_name)[1]:
+                    continue
+                else:
+                    already_exists = os.path.getsize(os.path.join(sub_models_dir, file_name)) > 0
+                    break
+        if already_exists:
+            logger.info('Base forest result already exists. Skipping...')
+        else:
+            pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True)
             model_parameters = ModelParameters(
-                extracted_forest_size=extracted_forest_size,
+                extracted_forest_size=forest_size,
                 normalize_D=parameters['normalize_D'],
                 subsets_used=parameters['subsets_used'],
                 normalize_weights=parameters['normalize_weights'],
@@ -76,29 +93,50 @@ def process_job(seed, parameters, experiment_id, hyperparameters):
             trainer.init(model, subsets_used=parameters['subsets_used'])
             trainer.train(model)
             trainer.compute_results(model, sub_models_dir)
-    else:
-        forest_size = hyperparameters['n_estimators']
-        logger.info('Base forest training with fixed forest size of {}'.format(forest_size))
-        sub_models_dir = models_dir + os.sep + 'forest_size' + os.sep + str(forest_size)
-        pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True)
-
-        model_parameters = ModelParameters(
-            extracted_forest_size=forest_size,
-            normalize_D=parameters['normalize_D'],
-            subsets_used=parameters['subsets_used'],
-            normalize_weights=parameters['normalize_weights'],
-            seed=seed,
-            hyperparameters=hyperparameters,
-            extraction_strategy=parameters['extraction_strategy']
-        )
-        model_parameters.save(sub_models_dir, experiment_id)
-
-        model = ModelFactory.build(dataset.task, model_parameters)
-
-        trainer.init(model, subsets_used=parameters['subsets_used'])
-        trainer.train(model)
-        trainer.compute_results(model, sub_models_dir)
-    logger.info('Training done')
+    logger.info(f'Training done for seed {seed_str}')
+    seed_job_pb.update(1)
+
+def extracted_forest_size_job(extracted_forest_size_job_pb, extracted_forest_size, models_dir,
+    seed, parameters, dataset, hyperparameters, experiment_id, trainer):
+
+    logger = LoggerFactory.create(LOG_PATH, 'training_seed{}_extracted_forest_size{}_ti{}'.format(
+        seed, extracted_forest_size, threading.get_ident()))
+    logger.info('extracted_forest_size={}'.format(extracted_forest_size))
+
+    sub_models_dir = models_dir + os.sep + 'extracted_forest_sizes' + os.sep + str(extracted_forest_size)
+
+    # Check if the result file already exists
+    already_exists = False
+    if os.path.isdir(sub_models_dir):
+        sub_models_dir_files = os.listdir(sub_models_dir)
+        for file_name in sub_models_dir_files:
+            if '.pickle' != os.path.splitext(file_name)[1]:
+                continue
+            else:
+                already_exists = os.path.getsize(os.path.join(sub_models_dir, file_name)) > 0
+                break
+    if already_exists:
+        logger.info(f'Extracted forest {extracted_forest_size} result already exists. Skipping...')
+        return
+
+    pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True)
+
+    model_parameters = ModelParameters(
+        extracted_forest_size=extracted_forest_size,
+        normalize_D=parameters['normalize_D'],
+        subsets_used=parameters['subsets_used'],
+        normalize_weights=parameters['normalize_weights'],
+        seed=seed,
+        hyperparameters=hyperparameters,
+        extraction_strategy=parameters['extraction_strategy']
+    )
+    model_parameters.save(sub_models_dir, experiment_id)
+
+    model = ModelFactory.build(dataset.task, model_parameters)
+
+    trainer.init(model, subsets_used=parameters['subsets_used'])
+    trainer.train(model)
+    trainer.compute_results(model, sub_models_dir)
 
 """
 Command lines example for stage 1:
@@ -138,6 +176,7 @@ if __name__ == "__main__":
     DEFAULT_SKIP_BEST_HYPERPARAMS = False
     DEFAULT_JOB_NUMBER = -1
     DEFAULT_EXTRACTION_STRATEGY = 'omp'
+    DEFAULT_OVERWRITE = False
 
     begin_random_seed_range = 1
     end_random_seed_range = 2000
@@ -163,7 +202,8 @@ if __name__ == "__main__":
     parser.add_argument('--skip_best_hyperparams', action='store_true', default=DEFAULT_SKIP_BEST_HYPERPARAMS, help='Do not use the best hyperparameters if there exist.')
     parser.add_argument('--save_experiment_configuration', nargs='+', default=None, help='Save the experiment parameters specified in the command line in a file. Args: {{stage_num}} {{name}}')
     parser.add_argument('--job_number', nargs='?', type=int, default=DEFAULT_JOB_NUMBER, help='Specify the number of job used during the parallelisation across seeds.')
-    parser.add_argument('--extraction_strategy', nargs='?', type=str, default=DEFAULT_EXTRACTION_STRATEGY, help='Specify the strategy to apply to extract the trees from the forest. Either omp, random, none or similarity.')
+    parser.add_argument('--extraction_strategy', nargs='?', type=str, default=DEFAULT_EXTRACTION_STRATEGY, help='Specify the strategy to apply to extract the trees from the forest. Either omp, random, none, similarity or kmeans.')
+    parser.add_argument('--overwrite', action='store_true', default=DEFAULT_OVERWRITE, help='Overwrite the results of the specified experiment id, if they already exist.')
     args = parser.parse_args()
 
     if args.experiment_configuration:
@@ -173,7 +213,7 @@ if __name__ == "__main__":
     else:
         parameters = args.__dict__
 
-    if parameters['extraction_strategy'] not in ['omp', 'random', 'none', 'similarity']:
+    if parameters['extraction_strategy'] not in ['omp', 'random', 'none', 'similarity', 'kmeans']:
         raise ValueError('Specified extraction strategy {} is not supported.'.format(parameters['extraction_strategy']))
 
     pathlib.Path(parameters['models_dir']).mkdir(parents=True, exist_ok=True)
@@ -220,7 +260,8 @@ if __name__ == "__main__":
 
     if args.experiment_id:
         experiment_id = args.experiment_id
-        shutil.rmtree(os.path.join(parameters['models_dir'], str(experiment_id)), ignore_errors=True)
+        if args.overwrite:
+            shutil.rmtree(os.path.join(parameters['models_dir'], str(experiment_id)), ignore_errors=True)
     else:
         # Resolve the next experiment id number (last id + 1)
         experiment_id = resolve_experiment_id(parameters['models_dir'])
@@ -255,6 +296,6 @@ if __name__ == "__main__":
             )
 
     # Run as much job as there are seeds
-    with tqdm_joblib(tqdm(total=len(seeds), disable=not args.verbose)) as progress_bar:
-        Parallel(n_jobs=args.job_number)(delayed(process_job)(seeds[i],
-            parameters, experiment_id, hyperparameters) for i in range(len(seeds)))
+    with tqdm_joblib(tqdm(total=len(seeds), disable=not args.verbose)) as seed_job_pb:
+        Parallel(n_jobs=args.job_number)(delayed(seed_job)(seed_job_pb, seeds[i],
+            parameters, experiment_id, hyperparameters, args.verbose) for i in range(len(seeds)))
diff --git a/experiments/steel-plates/stage1/params.json b/experiments/steel-plates/stage1/params.json
new file mode 100644
index 0000000000000000000000000000000000000000..b9fe822796a2fbf4f0a4bdf681d9b7d89b6fb719
--- /dev/null
+++ b/experiments/steel-plates/stage1/params.json
@@ -0,0 +1,18 @@
+{
+    "scorer": "accuracy",
+    "best_score_train": 0.9879129734085416,
+    "best_score_test": 0.9938303341902314,
+    "best_parameters": {
+        "min_samples_leaf": 1,
+        "n_estimators": 1000,
+        "max_features": "log2",
+        "max_depth": 15
+    },
+    "random_seed": [
+        1604,
+        237,
+        1671,
+        1652,
+        37
+    ]
+}
\ No newline at end of file
diff --git a/scripts/run_stage5_experiments.sh b/scripts/run_stage5_experiments.sh
deleted file mode 100644
index c36433ee72d0092bd3aa79d9a08a093dde78a696..0000000000000000000000000000000000000000
--- a/scripts/run_stage5_experiments.sh
+++ /dev/null
@@ -1,12 +0,0 @@
-#!/bin/bash
-core_number=5
-walltime=1:00
-seeds='1 2 3'
-
-for dataset in diabetes #diamonds california_housing boston linnerud
-do
-    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=1:00 "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=none --extracted_forest_size_stop=0.40 --extracted_forest_size_samples=30 --experiment_id=1 --models_dir=models/$dataset/stage5 --subsets_used train+dev,train+dev"
-    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=1:00 "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=random --extracted_forest_size_stop=0.40 --extracted_forest_size_samples=30 --experiment_id=2 --models_dir=models/$dataset/stage5 --subsets_used train+dev,train+dev"
-    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=1:00 "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=omp --extracted_forest_size_stop=0.40 --extracted_forest_size_samples=30 --experiment_id=3 --models_dir=models/$dataset/stage5 --subsets_used train+dev,train+dev"
-    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=1:00 "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=similarity --extracted_forest_size_stop=0.40 --extracted_forest_size_samples=30 --experiment_id=4 --models_dir=models/$dataset/stage5 --subsets_used train+dev,train+dev"
-done