diff --git a/code/bolsonaro/data/dataset_loader.py b/code/bolsonaro/data/dataset_loader.py
index 0bebeaaf6c9f0dac2258a384a8b2bcff50c4c5b3..5e38b1eabb36b077812ebfcf94ac6a05c65e9211 100644
--- a/code/bolsonaro/data/dataset_loader.py
+++ b/code/bolsonaro/data/dataset_loader.py
@@ -1,7 +1,7 @@
 from bolsonaro.data.dataset import Dataset
 from bolsonaro.data.dataset_parameters import DatasetParameters
 from bolsonaro.data.task import Task
-from bolsonaro.utils import change_binary_func_load, change_binary_func_openml
+from bolsonaro.utils import change_binary_func_load, change_binary_func_openml, binarize_class_data
 
 from sklearn.datasets import load_boston, load_iris, load_diabetes, \
     load_digits, load_linnerud, load_wine, load_breast_cancer
@@ -81,7 +81,9 @@ class DatasetLoader(object):
         elif name == 'lfw_pairs':
             dataset = fetch_lfw_pairs()
             X, y = dataset.data, dataset.target
-            task = Task.MULTICLASSIFICATION
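+            # lfw_pairs is now handled as a binary task: binarize the labels against the last class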
+            possible_classes = sorted(set(y))
+            y = binarize_class_data(y, possible_classes[-1])
+            task = Task.BINARYCLASSIFICATION
         elif name == 'covtype':
             X, y = fetch_covtype(random_state=dataset_parameters.random_state, shuffle=True, return_X_y=True)
             task = Task.MULTICLASSIFICATION
diff --git a/code/bolsonaro/models/ensemble_selection_forest_regressor.py b/code/bolsonaro/models/ensemble_selection_forest_regressor.py
index 9f537785bc665478934d4a00df291ae1889cd8f2..aa649fa4d21a04e51fb8429c8ef55e2867041057 100644
--- a/code/bolsonaro/models/ensemble_selection_forest_regressor.py
+++ b/code/bolsonaro/models/ensemble_selection_forest_regressor.py
@@ -1,3 +1,6 @@
+import time
+
+from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
 from sklearn.metrics import mean_squared_error
 from sklearn.base import BaseEstimator
 from sklearn.tree import DecisionTreeRegressor
@@ -5,91 +8,103 @@ from abc import abstractmethod, ABCMeta
 import numpy as np
 from tqdm import tqdm
 
+from bolsonaro.models.forest_pruning_sota import ForestPruningSOTA
+from bolsonaro.models.utils import score_metric_mse, aggregation_regression, aggregation_classification, score_metric_indicator
+
 
-class EnsembleSelectionForestRegressor(BaseEstimator, metaclass=ABCMeta):
+class EnsembleSelectionForest(ForestPruningSOTA, metaclass=ABCMeta):
     """
     'Ensemble selection from libraries of models' by Rich Caruana et al
     """
 
-    def __init__(self, models_parameters, library, score_metric=mean_squared_error):
-        self._models_parameters = models_parameters
-        self._library = library
-        self._extracted_forest_size = self._models_parameters.extracted_forest_size
-        self._score_metric = score_metric
-        self._selected_trees = list()
-
-    @property
-    def models_parameters(self):
-        return self._models_parameters
-
-    @property
-    def library(self):
-        return self._library
-
-    @property
-    def selected_trees(self):
-        return self._selected_trees
-
-    def fit(self, X_train, y_train, X_val, y_val):
-        scores_list = list()
-        for estimator in self._library:
-            val_score = self._score_metric(estimator.predict(X_val), y_val)
-            scores_list.append(val_score)
-
-        class_list = list(self._library)
-        m = np.argmax(np.asarray(scores_list))
-        self._selected_trees = [class_list[m]]
-        temp_pred = class_list[m].predict(X_val)
-        del class_list[m]
-        for k in range(self._extracted_forest_size - 1):
-            candidate_index = 0
-            best_score = 100000
-            for j in range(len(class_list)):
-                temp_pred = np.vstack((temp_pred, class_list[j].predict(X_val)))
-                temp_mean = np.mean(temp_pred, axis=0)
-                temp_score = self._score_metric(temp_mean, y_val)
-                if (temp_score < best_score):
-                    candidate_index = j
-                    best_score = temp_score
-                temp_pred = np.delete(temp_pred, -1, 0)
-            self._selected_trees.append(class_list[candidate_index])
-            temp_pred = np.vstack((temp_pred, class_list[candidate_index].predict(X_val)))
-            del class_list[candidate_index]
-
-    def score(self, X, y):
-        predictions = self.predict_base_estimator(X)
-        return self._score_metric(predictions, y)
-
-    def predict_base_estimator(self, X):
-        predictions = list()
-        for tree in self._selected_trees:
-            predictions.append(tree.predict(X))
-        mean_predictions = np.mean(np.array(predictions), axis=0)
-        return mean_predictions
+    def _fit(self, X_train, y_train, X_val, y_val):
+        self._base_estimator.fit(X_train, y_train)
+
+        val_predictions = self._base_estimator_predictions(X_val).T
+        scores_predictions_val = self._score_metric(val_predictions, y_val)
+        idx_best_score = self._best_score_idx(scores_predictions_val)
+
+        lst_pruned_forest = [self._base_estimator.estimators_[idx_best_score]]
+
+        nb_selected_trees = 1
+        mean_so_far = val_predictions[idx_best_score]
+        while nb_selected_trees < self._extracted_forest_size:
+            # every new tree is selected with replacement as specified in the base paper
+
+            # each row of this matrix holds the mean prediction of the already-selected subset extended with the tree of that row
+            # mean update formula: u_{t+1} = (n_t * u_t + x_t) / (n_t + 1)
+            mean_prediction_subset_with_extra_tree = (nb_selected_trees * mean_so_far + val_predictions) / (nb_selected_trees + 1)
+            predictions_subset_with_extra_tree = self._activation(mean_prediction_subset_with_extra_tree)
+            scores_subset_with_extra_tree = self._score_metric(predictions_subset_with_extra_tree, y_val)
+            idx_best_extra_tree = self._best_score_idx(scores_subset_with_extra_tree)
+            lst_pruned_forest.append(self._base_estimator.estimators_[idx_best_extra_tree])
+
+            # update the running mean prediction
+            mean_so_far = mean_prediction_subset_with_extra_tree[idx_best_extra_tree]
+            nb_selected_trees += 1
+
+        return lst_pruned_forest
+
+
+    @abstractmethod
+    def _activation(self, leave_one_tree_out_predictions_val):
+        pass
+
+
+class EnsembleSelectionForestClassifier(EnsembleSelectionForest, metaclass=ABCMeta):
+    @staticmethod
+    def init_estimator(model_parameters):
+        return RandomForestClassifier(**model_parameters.hyperparameters,
+                                    random_state=model_parameters.seed, n_jobs=-1)
+
+    def _aggregate(self, predictions):
+        return aggregation_classification(predictions)
+
+    def _score_metric(self, y_preds, y_true):
+        return score_metric_indicator(y_preds, y_true)
+
+    def _activation(self, predictions):
+        return np.sign(predictions)
+
+    def _selected_tree_predictions(self, X):
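+        # map the trees' {0,1} predictions to {-1,+1}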
+        predictions_0_1 = super()._selected_tree_predictions(X)
+        predictions = (predictions_0_1 - 0.5) * 2
+        return predictions
+
+    def _base_estimator_predictions(self, X):
+        predictions_0_1 = super()._base_estimator_predictions(X)
+        predictions = (predictions_0_1 - 0.5) * 2
+        return predictions
+
+    @staticmethod
+    def _best_score_idx(array):
+        return np.argmax(array)
+
+    @staticmethod
+    def _worse_score_idx(array):
+        return np.argmin(array)
+
+
+class EnsembleSelectionForestRegressor(EnsembleSelectionForest, metaclass=ABCMeta):
+
+    @staticmethod
+    def init_estimator(model_parameters):
+        return RandomForestRegressor(**model_parameters.hyperparameters,
+                              random_state=model_parameters.seed, n_jobs=-1)
+
+    def _aggregate(self, predictions):
+        return aggregation_regression(predictions)
+
+    def _score_metric(self, y_preds, y_true):
+        return score_metric_mse(y_preds, y_true)
+
+    def _activation(self, predictions):
+        return predictions
+
+    @staticmethod
+    def _best_score_idx(array):
+        return np.argmin(array)
 
     @staticmethod
-    def generate_library(X_train, y_train, random_state=None):
-        criterion_arr = ["mse"]#, "friedman_mse", "mae"]
-        splitter_arr = ["best"]#, "random"]
-        depth_arr = [i for i in range(5, 20, 1)]
-        min_samples_split_arr = [i for i in range(2, 20, 1)]
-        min_samples_leaf_arr = [i for i in range(2, 20, 1)]
-        max_features_arr = ["sqrt"]#["auto", "sqrt", "log2"]
-
-        library = list()
-        with tqdm(total=len(criterion_arr) * len(splitter_arr) * \
-            len(depth_arr) * len(min_samples_split_arr) * len(min_samples_leaf_arr) * \
-            len(max_features_arr)) as bar:
-            bar.set_description('Generating library')
-            for criterion in criterion_arr:
-                for splitter in splitter_arr:
-                    for depth in depth_arr:
-                        for min_samples_split in min_samples_split_arr:
-                            for min_samples_leaf in min_samples_leaf_arr:
-                                for max_features in max_features_arr:
-                                    t = DecisionTreeRegressor(criterion=criterion, splitter=splitter, max_depth=depth, min_samples_split=min_samples_split,
-                                        min_samples_leaf=min_samples_leaf, max_features=max_features, random_state=random_state)
-                                    t.fit(X_train, y_train)
-                                    library.append(t)
-                                    bar.update(1)
-        return library
+    def _worse_score_idx(array):
+        return np.argmax(array)
diff --git a/code/bolsonaro/models/forest_pruning_sota.py b/code/bolsonaro/models/forest_pruning_sota.py
new file mode 100644
index 0000000000000000000000000000000000000000..79bc0068aa87936f1939be24e29acbe392364c4a
--- /dev/null
+++ b/code/bolsonaro/models/forest_pruning_sota.py
@@ -0,0 +1,109 @@
+from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
+from sklearn.metrics import mean_squared_error
+from sklearn.base import BaseEstimator
+from abc import abstractmethod, ABCMeta
+import numpy as np
+from tqdm import tqdm
+
+from bolsonaro.models.utils import score_metric_mse, aggregation_regression, aggregation_classification, score_metric_indicator
+
+
+class ForestPruningSOTA(BaseEstimator, metaclass=ABCMeta):
+
+    def __init__(self, models_parameters):
+        self._models_parameters = models_parameters
+        self._extracted_forest_size = self._models_parameters.extracted_forest_size
+        self._selected_trees = list()
+        self._base_estimator = self.init_estimator(models_parameters)
+
+    @staticmethod
+    @abstractmethod
+    def init_estimator(model_parameters):
+        pass
+
+    @abstractmethod
+    def _fit(self, X_train, y_train, X_val, y_val):
+        pass
+
+    @property
+    def models_parameters(self):
+        return self._models_parameters
+
+    @property
+    def selected_trees(self):
+        return self._selected_trees
+
+    def fit(self, X_train, y_train, X_val, y_val):
+        pruned_forest = self._fit(X_train, y_train, X_val, y_val)
+        assert len(pruned_forest) == self._extracted_forest_size, "Pruned forest size doesn't match the expected size: {} != {}".format(len(pruned_forest), self._extracted_forest_size)
+        self._selected_trees = pruned_forest
+
+    def _base_estimator_predictions(self, X):
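+        # one column of predictions per tree: shape (n_samples, n_trees)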
+        base_predictions = np.array([tree.predict(X) for tree in self._base_estimator.estimators_]).T
+        return base_predictions
+
+    def _selected_tree_predictions(self, X):
+        base_predictions = np.array([tree.predict(X) for tree in self.selected_trees]).T
+        return base_predictions
+
+    def predict(self, X):
+        predictions = self._selected_tree_predictions(X).T
+        final_predictions = self._aggregate(predictions)
+        return final_predictions
+
+    def predict_base_estimator(self, X):
+        return self._base_estimator.predict(X)
+
+    def score(self, X, y):
+        final_predictions = self.predict(X)
+        score = self._score_metric(final_predictions, y)[0]
+        return score
+
+    @staticmethod
+    @abstractmethod
+    def _best_score_idx(array):
+        """
+        Return the index of the best-scoring element in array.
+
+        :param array: array of scores, one entry per tree
+        :return: index of the best score
+        """
+        pass
+
+
+    @staticmethod
+    @abstractmethod
+    def _worse_score_idx(array):
+        """
+        Return the index of the worst-scoring element in array.
+
+        :param array: array of scores, one entry per tree
+        :return: index of the worst score
+        """
+        pass
+
+
+    @abstractmethod
+    def _score_metric(self, y_preds, y_true):
+        """
+        Compute the score of each predictor in y_preds.
+
+        y_preds.shape == (nb_trees, nb_samples)
+        y_true.shape == (1, nb_samples)
+
+        :param y_preds: per-tree predictions
+        :param y_true: ground-truth values
+        :return: one score per tree
+        """
+        pass
+
+    @abstractmethod
+    def _aggregate(self, predictions):
+        """
+        Aggregate the votes of the predictors in predictions.
+
+        predictions shape: (nb_trees, nb_samples)
+        :param predictions: per-tree predictions
+        :return: the aggregated prediction, one value per sample
+        """
+        pass
\ No newline at end of file
diff --git a/code/bolsonaro/models/kmeans_forest_regressor.py b/code/bolsonaro/models/kmeans_forest_regressor.py
index d0d64120d1c391ae31d107d73ed22b1a2306e8c9..2c9be473441c0a356b5faf44037207d6ffdc0370 100644
--- a/code/bolsonaro/models/kmeans_forest_regressor.py
+++ b/code/bolsonaro/models/kmeans_forest_regressor.py
@@ -1,6 +1,10 @@
+import time
+
+from bolsonaro.models.forest_pruning_sota import ForestPruningSOTA
+from bolsonaro.models.utils import score_metric_mse, score_metric_indicator, aggregation_classification, aggregation_regression
 from bolsonaro.utils import tqdm_joblib
 
-from sklearn.ensemble import RandomForestRegressor
+from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
 from sklearn.metrics import mean_squared_error
 from sklearn.base import BaseEstimator
 from sklearn.cluster import KMeans
@@ -11,74 +15,83 @@ from joblib import Parallel, delayed
 from tqdm import tqdm
 
 
-class KMeansForestRegressor(BaseEstimator, metaclass=ABCMeta):
+class KmeansForest(ForestPruningSOTA, metaclass=ABCMeta):
     """
     'On extreme pruning of random forest ensembles for real-time predictive applications', by Khaled Fawagreh, Mohamed Medhat Gaber and Eyad Elyan.
     """
 
-    def __init__(self, models_parameters, score_metric=mean_squared_error):
-        self._models_parameters = models_parameters
-        self._estimator = RandomForestRegressor(**self._models_parameters.hyperparameters,
-            random_state=self._models_parameters.seed, n_jobs=2)
-        self._extracted_forest_size = self._models_parameters.extracted_forest_size
-        self._score_metric = score_metric
-        self._selected_trees = list()
+    def _fit(self, X_train, y_train, X_val, y_val):
+        self._base_estimator.fit(X_train, y_train)
 
-    @property
-    def models_parameters(self):
-        return self._models_parameters
+        predictions_val = self._base_estimator_predictions(X_val).T
+        predictions = self._base_estimator_predictions(X_train).T
 
-    @property
-    def selected_trees(self):
-        return self._selected_trees
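+        # cluster the trees according to their prediction vectors on the training set (one cluster per tree to keep)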
+        kmeans = KMeans(n_clusters=self._extracted_forest_size, random_state=self._models_parameters.seed).fit(predictions)
+        labels = np.array(kmeans.labels_)
 
-    def fit(self, X_train, y_train, X_val, y_val):
-        self._estimator.fit(X_train, y_train)
+        lst_pruned_forest = list()
+        for cluster_idx in range(self._extracted_forest_size):  # could be parallelized
+            index_trees_cluster = np.where(labels == cluster_idx)[0]
+            predictions_val_cluster = predictions_val[index_trees_cluster]  # get predictions of trees in cluster
+            best_tree_index = self._get_best_tree_index(predictions_val_cluster, y_val)
+            lst_pruned_forest.append(self._base_estimator.estimators_[index_trees_cluster[best_tree_index]])
 
-        predictions = list()
-        for tree in self._estimator.estimators_:
-            predictions.append(tree.predict(X_train))
-        predictions = np.array(predictions)
+        return lst_pruned_forest
 
-        kmeans = KMeans(n_clusters=self._extracted_forest_size, random_state=self._models_parameters.seed).fit(predictions)
-        labels = np.array(kmeans.labels_)
+    def _get_best_tree_index(self, y_preds, y_true):
+        score = self._score_metric(y_preds, y_true)
+        best_tree_index = self._best_score_idx(score)  # best scoring tree (lowest MSE for regression, highest accuracy for classification)
+        return best_tree_index
+
+
+class KMeansForestRegressor(KmeansForest, metaclass=ABCMeta):
+    @staticmethod
+    def init_estimator(model_parameters):
+        return RandomForestRegressor(**model_parameters.hyperparameters,
+                              random_state=model_parameters.seed, n_jobs=-1)
+
+    def _aggregate(self, predictions):
+        return aggregation_regression(predictions)
+
+    def _score_metric(self, y_preds, y_true):
+        return score_metric_mse(y_preds, y_true)
+
+    @staticmethod
+    def _best_score_idx(array):
+        return np.argmin(array)
+
+    @staticmethod
+    def _worse_score_idx(array):
+        return np.argmax(array)
+
+
+class KMeansForestClassifier(KmeansForest, metaclass=ABCMeta):
+    @staticmethod
+    def init_estimator(model_parameters):
+        return RandomForestClassifier(**model_parameters.hyperparameters,
+                                                random_state=model_parameters.seed, n_jobs=-1)
+
+    def _aggregate(self, predictions):
+        return aggregation_classification(predictions)
+
+    def _score_metric(self, y_preds, y_true):
+        return score_metric_indicator(y_preds, y_true)
+
+    def _selected_tree_predictions(self, X):
+        predictions_0_1 = super()._selected_tree_predictions(X)
+        predictions = (predictions_0_1 - 0.5) * 2
+        return predictions
+
+    def _base_estimator_predictions(self, X):
+        predictions_0_1 = super()._base_estimator_predictions(X)
+        predictions = (predictions_0_1 - 0.5) * 2
+        return predictions
+
+    @staticmethod
+    def _best_score_idx(array):
+        return np.argmax(array)
 
-        # For each cluster select the best tree on the validation set
-        extracted_forest_sizes = list(range(self._extracted_forest_size))
-        with tqdm_joblib(tqdm(total=self._extracted_forest_size, disable=True)) as prune_forest_job_pb:
-            pruned_forest = Parallel(n_jobs=2)(delayed(self._prune_forest_job)(prune_forest_job_pb,
-                extracted_forest_sizes[i], labels, X_val, y_val, self._score_metric)
-                for i in range(self._extracted_forest_size))
-
-        self._selected_trees = pruned_forest
-        self._estimator.estimators_ = pruned_forest
-
-    def _prune_forest_job(self, prune_forest_job_pb, c, labels, X_val, y_val, score_metric):
-        index = np.where(labels == c)[0]
-        with tqdm_joblib(tqdm(total=len(index), disable=True)) as cluster_job_pb:
-            cluster = Parallel(n_jobs=2)(delayed(self._cluster_job)(cluster_job_pb, index[i], X_val, 
-                y_val, score_metric) for i in range(len(index)))
-        best_tree_index = np.argmax(cluster)
-        prune_forest_job_pb.update()
-        return self._estimator.estimators_[index[best_tree_index]]
-
-    def _cluster_job(self, cluster_job_pb, i, X_val, y_val, score_metric):
-        y_val_pred = self._estimator.estimators_[i].predict(X_val)
-        tree_pred = score_metric(y_val, y_val_pred)
-        cluster_job_pb.update()
-        return tree_pred
-
-    def predict(self, X):
-        return self._estimator.predict(X)
-
-    def score(self, X, y):
-        predictions = list()
-        for tree in self._estimator.estimators_:
-            predictions.append(tree.predict(X))
-        predictions = np.array(predictions)
-        mean_predictions = np.mean(predictions, axis=0)
-        score = self._score_metric(mean_predictions, y)
-        return score
-
-    def predict_base_estimator(self, X):
-        return self._estimator.predict(X)
+    @staticmethod
+    def _worse_score_idx(array):
+        return np.argmin(array)
diff --git a/code/bolsonaro/models/model_factory.py b/code/bolsonaro/models/model_factory.py
index 335816b1dd33d28175f4865da2fddbbf73b8027d..6f1836278675662ac7ef57f8bf98cc4c8284dc26 100644
--- a/code/bolsonaro/models/model_factory.py
+++ b/code/bolsonaro/models/model_factory.py
@@ -1,9 +1,9 @@
 from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier
 from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
 from bolsonaro.models.model_parameters import ModelParameters
-from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor
-from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor
-from bolsonaro.models.ensemble_selection_forest_regressor import EnsembleSelectionForestRegressor
+from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor, SimilarityForestClassifier
+from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor, KMeansForestClassifier
+from bolsonaro.models.ensemble_selection_forest_regressor import EnsembleSelectionForestRegressor, EnsembleSelectionForestClassifier
 from bolsonaro.data.task import Task
 
 from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
@@ -14,12 +14,12 @@ import pickle
 class ModelFactory(object):
 
     @staticmethod
-    def build(task, model_parameters, library=None):
+    def build(task, model_parameters):
         if task not in [Task.BINARYCLASSIFICATION, Task.REGRESSION, Task.MULTICLASSIFICATION]:
             raise ValueError("Unsupported task '{}'".format(task))
 
         if task == Task.BINARYCLASSIFICATION:
-            if model_parameters.extraction_strategy == 'omp':
+            if model_parameters.extraction_strategy in ['omp', 'omp_distillation']:
                 return OmpForestBinaryClassifier(model_parameters)
             elif model_parameters.extraction_strategy == 'random':
                 return RandomForestClassifier(**model_parameters.hyperparameters,
@@ -27,27 +27,33 @@ class ModelFactory(object):
             elif model_parameters.extraction_strategy == 'none':
                 return RandomForestClassifier(**model_parameters.hyperparameters,
                     random_state=model_parameters.seed)
+            elif model_parameters.extraction_strategy == 'ensemble':
+                return EnsembleSelectionForestClassifier(model_parameters)
+            elif model_parameters.extraction_strategy == 'kmeans':
+                return KMeansForestClassifier(model_parameters)
+            elif model_parameters.extraction_strategy in ['similarity_similarities', 'similarity_predictions']:
+                return SimilarityForestClassifier(model_parameters)
             else:
                 raise ValueError('Invalid extraction strategy')
         elif task == Task.REGRESSION:
-            if model_parameters.extraction_strategy == 'omp':
+            if model_parameters.extraction_strategy in ['omp', 'omp_distillation']:
                 return OmpForestRegressor(model_parameters)
             elif model_parameters.extraction_strategy == 'random':
                 return RandomForestRegressor(**model_parameters.hyperparameters,
                     random_state=model_parameters.seed)
-            elif model_parameters.extraction_strategy == 'similarity':
+            elif model_parameters.extraction_strategy in ['similarity_similarities', 'similarity_predictions']:
                 return SimilarityForestRegressor(model_parameters)
             elif model_parameters.extraction_strategy == 'kmeans':
                 return KMeansForestRegressor(model_parameters)
             elif model_parameters.extraction_strategy == 'ensemble':
-                return EnsembleSelectionForestRegressor(model_parameters, library=library)
+                return EnsembleSelectionForestRegressor(model_parameters)
             elif model_parameters.extraction_strategy == 'none':
                 return RandomForestRegressor(**model_parameters.hyperparameters,
                     random_state=model_parameters.seed)
             else:
                 raise ValueError('Invalid extraction strategy')
         elif task == Task.MULTICLASSIFICATION:
-            if model_parameters.extraction_strategy == 'omp':
+            if model_parameters.extraction_strategy in ['omp', 'omp_distillation']:
                 return OmpForestMulticlassClassifier(model_parameters)
             elif model_parameters.extraction_strategy == 'random':
                 return RandomForestClassifier(**model_parameters.hyperparameters,
diff --git a/code/bolsonaro/models/model_raw_results.py b/code/bolsonaro/models/model_raw_results.py
index fbb80a591f9c1f42ac6fe3d1982d43d108faa026..3f7af5fcd31c1eb105a3dd39695e1ddc69f38676 100644
--- a/code/bolsonaro/models/model_raw_results.py
+++ b/code/bolsonaro/models/model_raw_results.py
@@ -9,7 +9,11 @@ class ModelRawResults(object):
     def __init__(self, model_weights, training_time,
         datetime, train_score, dev_score, test_score,
         train_score_base, dev_score_base,
-        test_score_base, score_metric, base_score_metric):
+        test_score_base, score_metric, base_score_metric,
+        train_coherence='', dev_coherence='', test_coherence='',
+        train_correlation='', dev_correlation='', test_correlation='',
+        train_strength='', dev_strength='', test_strength=''):
 
         self._model_weights = model_weights
         self._training_time = training_time
@@ -22,6 +26,17 @@ class ModelRawResults(object):
         self._test_score_base = test_score_base
         self._score_metric = score_metric
         self._base_score_metric = base_score_metric
+        """self._coherence = coherence
+        self._correlation = correlation"""
+        self._train_coherence = train_coherence
+        self._dev_coherence = dev_coherence
+        self._test_coherence = test_coherence
+        self._train_correlation = train_correlation
+        self._dev_correlation = dev_correlation
+        self._test_correlation = test_correlation
+        self._train_strength = train_strength
+        self._dev_strength = dev_strength
+        self._test_strength = test_strength
 
     @property
     def model_weights(self):
@@ -67,6 +82,50 @@ class ModelRawResults(object):
     def base_score_metric(self):
         return self._base_score_metric
 
+    """@property
+    def coherence(self):
+        return self._coherence
+
+    @property
+    def correlation(self):
+        return self._correlation"""
+
+    @property
+    def train_coherence(self):
+        return self._train_coherence
+
+    @property
+    def dev_coherence(self):
+        return self._dev_coherence
+
+    @property
+    def test_coherence(self):
+        return self._test_coherence
+
+    @property
+    def train_correlation(self):
+        return self._train_correlation
+
+    @property
+    def dev_correlation(self):
+        return self._dev_correlation
+
+    @property
+    def test_correlation(self):
+        return self._test_correlation
+
+    @property
+    def train_strength(self):
+        return self._train_strength
+
+    @property
+    def dev_strength(self):
+        return self._dev_strength
+
+    @property
+    def test_strength(self):
+        return self._test_strength
+
     def save(self, models_dir):
         if not os.path.exists(models_dir):
             os.mkdir(models_dir)
diff --git a/code/bolsonaro/models/omp_forest.py b/code/bolsonaro/models/omp_forest.py
index d539f45314a244c410453bb84f726502c6ffe082..5918eea7a3f3cb2a67c0eb8712ab0405ef8fbd8e 100644
--- a/code/bolsonaro/models/omp_forest.py
+++ b/code/bolsonaro/models/omp_forest.py
@@ -28,18 +28,20 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta):
         return self._base_forest_estimator.score(X, y)
 
     def _base_estimator_predictions(self, X):
-        return np.array([tree.predict(X) for tree in self._base_forest_estimator.estimators_]).T
+        base_predictions = np.array([tree.predict(X) for tree in self._base_forest_estimator.estimators_]).T
+        return base_predictions
 
     @property
     def forest(self):
         return self._base_forest_estimator.estimators_
 
     # sklearn baseestimator api methods
-    def fit(self, X_forest, y_forest, X_omp, y_omp):
+    def fit(self, X_forest, y_forest, X_omp, y_omp, use_distillation=False):
         # print(y_forest.shape)
         # print(set([type(y) for y in y_forest]))
         self._base_forest_estimator.fit(X_forest, y_forest)
-        self._extract_subforest(X_omp, y_omp) # type: OrthogonalMatchingPursuit
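+        # with distillation, the OMP weights are fitted on the base forest's predictions rather than on the true labels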
+        self._extract_subforest(X_omp,
+            self.predict_base_estimator(X_omp) if use_distillation else y_omp) # type: OrthogonalMatchingPursuit
         return self
 
     def _extract_subforest(self, X, y):
@@ -151,11 +153,6 @@ class SingleOmpForest(OmpForest):
         """
         forest_predictions = np.array([tree.predict(X) for tree in self._base_forest_estimator.estimators_])
 
-        if self._models_parameters.normalize_D:
-            forest_predictions = forest_predictions.T
-            forest_predictions /= self._forest_norms
-            forest_predictions = forest_predictions.T
-
         weights = self._omp.coef_
         select_trees = np.mean(forest_predictions[weights != 0], axis=0)
         return select_trees
diff --git a/code/bolsonaro/models/omp_forest_classifier.py b/code/bolsonaro/models/omp_forest_classifier.py
index 7a22337b2fcf48b5181e4971836470e17d0f4f62..2381937b214ab37e0f6e18f96971df9606ec52e5 100644
--- a/code/bolsonaro/models/omp_forest_classifier.py
+++ b/code/bolsonaro/models/omp_forest_classifier.py
@@ -19,11 +19,16 @@ class OmpForestBinaryClassifier(SingleOmpForest):
     def _check_classes(self, y):
         assert len(set(y).difference({-1, 1})) == 0, "Classes for binary classifier must be {-1, +1}"
 
-    def fit(self, X_forest, y_forest, X_omp, y_omp):
+    def fit(self, X_forest, y_forest, X_omp, y_omp, use_distillation=False):
         self._check_classes(y_forest)
         self._check_classes(y_omp)
 
-        return super().fit(X_forest, y_forest, X_omp, y_omp)
+        return super().fit(X_forest, y_forest, X_omp, y_omp, use_distillation=use_distillation)
+
+    def _base_estimator_predictions(self, X):
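+        # map the trees' {0,1} predictions to {-1,+1} to match the binary-classification convention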
+        predictions_0_1 = super()._base_estimator_predictions(X)
+        predictions = (predictions_0_1 - 0.5) * 2
+        return predictions
 
     def predict_no_weights(self, X):
         """
@@ -35,22 +40,15 @@ class OmpForestBinaryClassifier(SingleOmpForest):
         :return: a np.array of the predictions of the entire forest
         """
 
-        forest_predictions = np.array([tree.predict_proba(X) for tree in self._base_forest_estimator.estimators_])
-
-        if self._models_parameters.normalize_D:
-            forest_predictions = forest_predictions.T
-            forest_predictions /= self._forest_norms
-            forest_predictions = forest_predictions.T
+        forest_predictions = self._base_estimator_predictions(X)
 
         weights = self._omp.coef_
-        omp_trees_predictions = forest_predictions[weights != 0].T[1]
+        omp_trees_predictions = forest_predictions[:, weights != 0]
 
-        # Here forest_pred is the probability of being class 1.
+        # Average the selected trees' predictions to get the ensemble output.
 
         result_omp = np.mean(omp_trees_predictions, axis=1)
 
-        result_omp = (result_omp - 0.5) * 2
-
         return result_omp
 
     def score(self, X, y, metric=DEFAULT_SCORE_METRIC):
diff --git a/code/bolsonaro/models/similarity_forest_regressor.py b/code/bolsonaro/models/similarity_forest_regressor.py
index 368d3741e49224c881dfe7a6de38ceec81a4f156..edbb8adcc55ceb95cc39e1b00122181760c14d63 100644
--- a/code/bolsonaro/models/similarity_forest_regressor.py
+++ b/code/bolsonaro/models/similarity_forest_regressor.py
@@ -1,84 +1,141 @@
-from sklearn.ensemble import RandomForestRegressor
+import time
+
+from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
 from sklearn.metrics import mean_squared_error
 from sklearn.base import BaseEstimator
 from abc import abstractmethod, ABCMeta
 import numpy as np
 from tqdm import tqdm
 
+from bolsonaro.models.forest_pruning_sota import ForestPruningSOTA
+from bolsonaro.models.utils import score_metric_mse, aggregation_regression, aggregation_classification, score_metric_indicator
+
 
-class SimilarityForestRegressor(BaseEstimator, metaclass=ABCMeta):
+class SimilarityForest(ForestPruningSOTA, metaclass=ABCMeta):
     """
     https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2822360/
     """
+    similarity_similarities = "similarity_similarities"
+    similarity_predictions = "similarity_predictions"
+
+    def _fit(self, X_train, y_train, X_val, y_val):
+        self._base_estimator.fit(X_train, y_train)
+
+        param = self._models_parameters.extraction_strategy
 
-    def __init__(self, models_parameters, score_metric=mean_squared_error):
-        self._models_parameters = models_parameters
-        self._estimator = RandomForestRegressor(**self._models_parameters.hyperparameters,
-            random_state=self._models_parameters.seed, n_jobs=-1)
-        self._extracted_forest_size = self._models_parameters.extracted_forest_size
-        self._score_metric = score_metric
-        self._selected_trees = list()
-
-    @property
-    def models_parameters(self):
-        return self._models_parameters
-
-    @property
-    def selected_trees(self):
-        return self._selected_trees
-
-    def fit(self, X_train, y_train, X_val, y_val):
-        self._estimator.fit(X_train, y_train)
-
-        y_val_pred = self._estimator.predict(X_val)
-        forest_pred = self._score_metric(y_val, y_val_pred)
-        forest = self._estimator.estimators_
-        tree_list = list(self._estimator.estimators_)
-
-        val_scores = list()
-        with tqdm(tree_list) as tree_pred_bar:
-            tree_pred_bar.set_description('[Initial tree predictions]')
-            for tree in tree_pred_bar:
-                val_scores.append(tree.predict(X_val))
-            tree_pred_bar.update(1)
-
-        with tqdm(range(self._extracted_forest_size), disable=True) as pruning_forest_bar:
+        # list of all the trees in the base forest
+        tree_list = list(self._base_estimator.estimators_)
+        trees_to_remove = list()
+
+        # predictions of each single tree of the forest on the validation set
+        val_predictions = self._base_estimator_predictions(X_val).T
+
+        # boolean mask of the trees still considered when evaluating tree importance
+        mask_trees_to_consider = np.ones(val_predictions.shape[0], dtype=bool)
+        # the technique does backward selection, i.e. trees are removed one after another
+        nb_tree_to_remove = len(tree_list) - self._extracted_forest_size
+        with tqdm(range(nb_tree_to_remove), disable=True) as pruning_forest_bar:
             pruning_forest_bar.set_description(f'[Pruning forest s={self._extracted_forest_size}]')
-            for i in pruning_forest_bar:
-                best_similarity = 100000
-                found_index = 0
-                with tqdm(range(len(tree_list)), disable=True) as tree_list_bar:
-                    tree_list_bar.set_description(f'[Tree selection s={self._extracted_forest_size} #{i}]')
-                    for j in tree_list_bar:
-                        lonely_tree = tree_list[j]
-                        del tree_list[j]
-                        val_mean = np.mean(np.asarray(val_scores), axis=0)
-                        val_score = self._score_metric(val_mean, y_val)
-                        temp_similarity = abs(forest_pred - val_score)
-                        if (temp_similarity < best_similarity):
-                            found_index = j
-                            best_similarity = temp_similarity
-                        tree_list.insert(j, lonely_tree)
-                        val_scores.insert(j, lonely_tree.predict(X_val))
-                        tree_list_bar.update(1)
-                self._selected_trees.append(tree_list[found_index])
-                del tree_list[found_index]
-                del val_scores[found_index]
+            for _ in pruning_forest_bar:  # for each tree to remove
+                # get indexes of trees to take into account
+                idx_trees_to_consider = np.arange(val_predictions.shape[0])[mask_trees_to_consider]
+                val_predictions_to_consider = val_predictions[idx_trees_to_consider]
+                nb_trees_to_consider = val_predictions_to_consider.shape[0]
+
+                if param == self.similarity_predictions:
+                    # this matrix has zeros on the diagonal and 1/(L-1) everywhere else.
+                    # Left-multiplying the matrix of predictions (L rows) by this square L x L matrix yields,
+                    # on each row, the average of all the other rows of the prediction matrix
+                    zero_diag_matrix = np.ones((nb_trees_to_consider, nb_trees_to_consider)) * (1 / (nb_trees_to_consider - 1))
+                    np.fill_diagonal(zero_diag_matrix, 0)
+
+                    leave_one_tree_out_predictions_val = zero_diag_matrix @ val_predictions_to_consider
+                    leave_one_tree_out_predictions_val = self._activation(leave_one_tree_out_predictions_val)  # identity for regression; sign for classification
+                    leave_one_tree_out_scores_val = self._score_metric(leave_one_tree_out_predictions_val, y_val)
+                    # the difference with the base forest score is not actually needed:
+                    # delta_score = forest_score - leave_one_tree_out_scores_val
+
+                    # get index of tree to remove
+                    index_worse_tree = int(self._worse_score_idx(leave_one_tree_out_scores_val))
+
+                elif param == self.similarity_similarities:
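+                    # Gram matrix of the trees' prediction vectors on X_val: entry (i, j) measures how correlated trees i and j are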
+                    correlation_matrix = val_predictions_to_consider @ val_predictions_to_consider.T
+                    average_correlation_by_tree = np.average(correlation_matrix, axis=1)
+
+                    # get index of tree to remove
+                    index_worse_tree = int(np.argmax(average_correlation_by_tree))  # correlation and MSE: both greater is worse
+
+                else:
+                    raise ValueError("Unknown similarity method {}. Should be {} or {}".format(param, self.similarity_similarities, self.similarity_predictions))
+
+                index_worse_tree_in_base_forest = idx_trees_to_consider[index_worse_tree]
+                trees_to_remove.append(tree_list[index_worse_tree_in_base_forest])
+                mask_trees_to_consider[index_worse_tree_in_base_forest] = False
                 pruning_forest_bar.update(1)
 
-        self._selected_trees = set(self._selected_trees)
-        pruned_forest = list(set(forest) - self._selected_trees)
-        self._estimator.estimators_ = pruned_forest
-
-    def score(self, X, y):
-        test_list = list()
-        for mod in self._estimator.estimators_:
-            test_pred = mod.predict(X)
-            test_list.append(test_pred)
-        test_list = np.array(test_list)
-        test_mean = np.mean(test_list, axis=0)
-        score = self._score_metric(test_mean, y)
-        return score
-
-    def predict_base_estimator(self, X):
-        return self._estimator.predict(X)
+        pruned_forest = list(set(tree_list) - set(trees_to_remove))
+        return pruned_forest
+
+    @abstractmethod
+    def _activation(self, leave_one_tree_out_predictions_val):
+        pass
+
+
+
+class SimilarityForestRegressor(SimilarityForest, metaclass=ABCMeta):
+
+    @staticmethod
+    def init_estimator(model_parameters):
+        return RandomForestRegressor(**model_parameters.hyperparameters,
+                              random_state=model_parameters.seed, n_jobs=-1)
+
+    def _aggregate(self, predictions):
+        return aggregation_regression(predictions)
+
+    def _score_metric(self, y_preds, y_true):
+        return score_metric_mse(y_preds, y_true)
+
+    def _activation(self, predictions):
+        return predictions
+
+    @staticmethod
+    def _best_score_idx(array):
+        return np.argmin(array)
+
+    @staticmethod
+    def _worse_score_idx(array):
+        return np.argmax(array)
+
+class SimilarityForestClassifier(SimilarityForest, metaclass=ABCMeta):
+
+    @staticmethod
+    def init_estimator(model_parameters):
+        return RandomForestClassifier(**model_parameters.hyperparameters,
+                                    random_state=model_parameters.seed, n_jobs=-1)
+
+    def _aggregate(self, predictions):
+        return aggregation_classification(predictions)
+
+    def _score_metric(self, y_preds, y_true):
+        return score_metric_indicator(y_preds, y_true)
+
+    def _activation(self, predictions):
+        return np.sign(predictions)
+
+    def _selected_tree_predictions(self, X):
+        predictions_0_1 = super()._selected_tree_predictions(X)
+        predictions = (predictions_0_1 - 0.5) * 2
+        return predictions
+
+    def _base_estimator_predictions(self, X):
+        predictions_0_1 = super()._base_estimator_predictions(X)
+        predictions = (predictions_0_1 - 0.5) * 2
+        return predictions
+
+    @staticmethod
+    def _best_score_idx(array):
+        return np.argmax(array)
+
+    @staticmethod
+    def _worse_score_idx(array):
+        return np.argmin(array)
diff --git a/code/bolsonaro/models/utils.py b/code/bolsonaro/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a054407ea17b88d95328458ad8d86aa6030ff202
--- /dev/null
+++ b/code/bolsonaro/models/utils.py
@@ -0,0 +1,29 @@
+import numpy as np
+
+def score_metric_mse(y_preds, y_true):
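+    # per-tree MSE: each row of y_preds is one predictor; returns one score per row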
+    if len(y_true.shape) == 1:
+        y_true = y_true[np.newaxis, :]
+    if len(y_preds.shape) == 1:
+        y_preds = y_preds[np.newaxis, :]
+    assert y_preds.shape[1] == y_true.shape[1], "Number of examples to compare should be the same in y_preds and y_true"
+
+    diff = y_preds - y_true
+    squared_diff = diff ** 2
+    mean_squared_diff = np.mean(squared_diff, axis=1)
+    return mean_squared_diff
+
+def score_metric_indicator(y_preds, y_true):
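+    # per-tree accuracy: fraction of correct predictions for each row of y_preds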
+    if len(y_true.shape) == 1:
+        y_true = y_true[np.newaxis, :]
+    if len(y_preds.shape) == 1:
+        y_preds = y_preds[np.newaxis, :]
+    assert y_preds.shape[1] == y_true.shape[1], "Number of examples to compare should be the same in y_preds and y_true"
+
+    bool_arr_correct_predictions = y_preds == y_true
+    return np.average(bool_arr_correct_predictions, axis=1)
+
+def aggregation_classification(predictions):
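+    # majority vote over {-1,+1} predictions: sign of the per-sample sum of votes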
+    return np.sign(np.sum(predictions, axis=0))
+
+def aggregation_regression(predictions):
+    return np.mean(predictions, axis=0)
diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py
index 3327bee92401c3a993383c2d8b83a0ef80c206ba..78f2c082e4a9c20dfe7b6b5dfa2d5d49aca99cc2 100644
--- a/code/bolsonaro/trainer.py
+++ b/code/bolsonaro/trainer.py
@@ -1,15 +1,16 @@
 from bolsonaro.models.model_raw_results import ModelRawResults
 from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
 from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier
-from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor
-from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor
-from bolsonaro.models.ensemble_selection_forest_regressor import EnsembleSelectionForestRegressor
+from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor, SimilarityForestClassifier
+from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor, KMeansForestClassifier
+from bolsonaro.models.ensemble_selection_forest_regressor import EnsembleSelectionForestRegressor, EnsembleSelectionForestClassifier
 from bolsonaro.error_handling.logger_factory import LoggerFactory
 from bolsonaro.data.task import Task
 from . import LOG_PATH
 
 from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
 from sklearn.metrics import mean_squared_error, accuracy_score
+from sklearn.preprocessing import normalize
 import time
 import datetime
 import numpy as np
@@ -38,7 +39,6 @@ class Trainer(object):
             else classification_score_metric.__name__
         self._base_score_metric_name = base_regression_score_metric.__name__ if dataset.task == Task.REGRESSION \
             else base_classification_score_metric.__name__
-        self._selected_trees = ''
 
     @property
     def score_metric_name(self):
@@ -77,7 +77,7 @@ class Trainer(object):
         else:
             raise ValueError("Unknown specified subsets_used parameter '{}'".format(model.models_parameters.subsets_used))
 
-    def train(self, model, extracted_forest_size=None):
+    def train(self, model, extracted_forest_size=None, seed=None, use_distillation=False):
         """
         :param model: An instance of either RandomForestRegressor, RandomForestClassifier, OmpForestRegressor,
             OmpForestBinaryClassifier, OmpForestMulticlassClassifier.
@@ -88,6 +88,7 @@ class Trainer(object):
         if type(model) in [RandomForestRegressor, RandomForestClassifier]:
             if extracted_forest_size is not None:
                 estimators_index = np.arange(len(model.estimators_))
+                np.random.seed(seed)
                 np.random.shuffle(estimators_index)
                 choosen_estimators = estimators_index[:extracted_forest_size]
                 model.estimators_ = np.array(model.estimators_)[choosen_estimators]
@@ -96,14 +97,23 @@ class Trainer(object):
                     X=self._X_forest,
                     y=self._y_forest
                 )
-            self._selected_trees = model.estimators_
         else:
-            model.fit(
-                self._X_forest,
-                self._y_forest,
-                self._X_omp,
-                self._y_omp
-            )
+            if type(model) in [OmpForestRegressor, OmpForestBinaryClassifier, OmpForestMulticlassClassifier] and \
+                use_distillation:
+                model.fit(
+                    self._X_forest, # X_train or X_train+X_dev
+                    self._y_forest,
+                    self._X_omp, # X_train+X_dev or X_dev
+                    self._y_omp,
+                    use_distillation=use_distillation
+                )
+            else:
+                model.fit(
+                    self._X_forest, # X_train or X_train+X_dev
+                    self._y_forest,
+                    self._X_omp, # X_train+X_dev or X_dev
+                    self._y_omp
+                )
         self._end_time = time.time()
 
     def __score_func(self, model, X, y_true, weights=True):
@@ -122,7 +132,8 @@ class Trainer(object):
                 y_pred = np.sign(y_pred)
                 y_pred = np.where(y_pred == 0, 1, y_pred)
             result = self._classification_score_metric(y_true, y_pred)
-        elif type(model) in [SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor]:
+        elif type(model) in [SimilarityForestRegressor, SimilarityForestClassifier, KMeansForestRegressor, EnsembleSelectionForestRegressor, KMeansForestClassifier,
+            EnsembleSelectionForestClassifier]:
             result = model.score(X, y_true)
         return result
 
@@ -130,7 +141,7 @@ class Trainer(object):
         if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor]:
             y_pred = model.predict_base_estimator(X)
             result = self._base_regression_score_metric(y_true, y_pred)
-        elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
+        elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier, KMeansForestClassifier, SimilarityForestClassifier, EnsembleSelectionForestClassifier]:
             y_pred = model.predict_base_estimator(X)
             result = self._base_classification_score_metric(y_true, y_pred)
         elif type(model) == RandomForestClassifier:
@@ -141,7 +152,17 @@ class Trainer(object):
             result = self._base_regression_score_metric(y_true, y_pred)
         return result
 
-    def compute_results(self, model, models_dir):
+    def _evaluate_predictions(self, X, aggregation_function, selected_trees):
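+        # pairwise |cosine similarity| between the selected trees' prediction vectors;
+        # aggregation_function=np.max gives the coherence, np.mean the correlation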
+        predictions = np.array([tree.predict(X) for tree in selected_trees])
+
+        predictions = normalize(predictions)
+
+        return aggregation_function(np.abs((predictions @ predictions.T - np.eye(len(predictions)))))
+
+    def _compute_forest_strength(self, X, y, metric_function, selected_trees):
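+        # strength: mean individual score of the selected trees on (X, y)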
+        return np.mean([metric_function(y, tree.predict(X)) for tree in selected_trees])
+
+    def compute_results(self, model, models_dir, subsets_used='train+dev,train+dev'):
         """
         :param model: Object with
         :param models_dir: Where the results will be saved
@@ -155,25 +176,72 @@ class Trainer(object):
         elif type(model) == OmpForestBinaryClassifier:
             model_weights = model._omp
 
-        if type(model) in [SimilarityForestRegressor, EnsembleSelectionForestRegressor, KMeansForestRegressor]:
-            self._selected_trees = model.selected_trees
+        if type(model) in [SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor, 
+            SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]:
+            selected_trees = model.selected_trees
+        elif type(model) in [OmpForestRegressor, OmpForestMulticlassClassifier, OmpForestBinaryClassifier]:
+            selected_trees = np.asarray(model.forest)[model._omp.coef_ != 0]
+        elif type(model) in [RandomForestRegressor, RandomForestClassifier]:
+            selected_trees = model.estimators_
 
-        if len(self._selected_trees) > 0:
+        if len(selected_trees) > 0:
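+            # the last component of models_dir encodes the expected number of selected trees; sanity-check it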
+            target_selected_tree = int(os.path.split(models_dir)[-1])
+            if target_selected_tree != len(selected_trees):
+                raise ValueError(f'Invalid selected tree number target_selected_tree:{target_selected_tree} - len(selected_trees):{len(selected_trees)}')
             with open(os.path.join(models_dir, 'selected_trees.pickle'), 'wb') as output_file:
-                pickle.dump(self._selected_trees, output_file)
+                pickle.dump(selected_trees, output_file)
+
+        strength_metric = self._regression_score_metric if self._dataset.task == Task.REGRESSION else self._classification_score_metric
+
+        # Really dirty to put this here, but otherwise it's not thread-safe: rebuild the train/dev subsets used for scoring
+        if type(model) in [RandomForestRegressor, RandomForestClassifier]:
+            if subsets_used == 'train,dev':
+                X_forest = self._dataset.X_train
+                y_forest = self._dataset.y_train
+            else:
+                X_forest = np.concatenate([self._dataset.X_train, self._dataset.X_dev])
+                y_forest = np.concatenate([self._dataset.y_train, self._dataset.y_dev])
+            X_omp = self._dataset.X_dev
+            y_omp = self._dataset.y_dev
+        elif model.models_parameters.subsets_used == 'train,dev':
+            X_forest = self._dataset.X_train
+            y_forest = self._dataset.y_train
+            X_omp = self._dataset.X_dev
+            y_omp = self._dataset.y_dev
+        elif model.models_parameters.subsets_used == 'train+dev,train+dev':
+            X_forest = np.concatenate([self._dataset.X_train, self._dataset.X_dev])
+            X_omp = X_forest
+            y_forest = np.concatenate([self._dataset.y_train, self._dataset.y_dev])
+            y_omp = y_forest
+        elif model.models_parameters.subsets_used == 'train,train+dev':
+            X_forest = self._dataset.X_train
+            y_forest = self._dataset.y_train
+            X_omp = np.concatenate([self._dataset.X_train, self._dataset.X_dev])
+            y_omp = np.concatenate([self._dataset.y_train, self._dataset.y_dev])
+        else:
+            raise ValueError("Unknown specified subsets_used parameter '{}'".format(model.models_parameters.subsets_used))
 
         results = ModelRawResults(
             model_weights=model_weights,
             training_time=self._end_time - self._begin_time,
             datetime=datetime.datetime.now(),
-            train_score=self.__score_func(model, self._dataset.X_train, self._dataset.y_train),
-            dev_score=self.__score_func(model, self._dataset.X_dev, self._dataset.y_dev),
+            train_score=self.__score_func(model, X_forest, y_forest),
+            dev_score=self.__score_func(model, X_omp, y_omp),
             test_score=self.__score_func(model, self._dataset.X_test, self._dataset.y_test),
-            train_score_base=self.__score_func_base(model, self._dataset.X_train, self._dataset.y_train),
-            dev_score_base=self.__score_func_base(model, self._dataset.X_dev, self._dataset.y_dev),
+            train_score_base=self.__score_func_base(model, X_forest, y_forest),
+            dev_score_base=self.__score_func_base(model, X_omp, y_omp),
             test_score_base=self.__score_func_base(model, self._dataset.X_test, self._dataset.y_test),
             score_metric=self._score_metric_name,
-            base_score_metric=self._base_score_metric_name
+            base_score_metric=self._base_score_metric_name,
+            train_coherence=self._evaluate_predictions(X_forest, aggregation_function=np.max, selected_trees=selected_trees),
+            dev_coherence=self._evaluate_predictions(X_omp, aggregation_function=np.max, selected_trees=selected_trees),
+            test_coherence=self._evaluate_predictions(self._dataset.X_test, aggregation_function=np.max, selected_trees=selected_trees),
+            train_correlation=self._evaluate_predictions(X_forest, aggregation_function=np.mean, selected_trees=selected_trees),
+            dev_correlation=self._evaluate_predictions(X_omp, aggregation_function=np.mean, selected_trees=selected_trees),
+            test_correlation=self._evaluate_predictions(self._dataset.X_test, aggregation_function=np.mean, selected_trees=selected_trees),
+            train_strength=self._compute_forest_strength(X_forest, y_forest, strength_metric, selected_trees),
+            dev_strength=self._compute_forest_strength(X_omp, y_omp, strength_metric, selected_trees),
+            test_strength=self._compute_forest_strength(self._dataset.X_test, self._dataset.y_test, strength_metric, selected_trees)
         )
         results.save(models_dir)
         self._logger.info("Base performance on test: {}".format(results.test_score_base))
@@ -185,26 +253,30 @@ class Trainer(object):
         self._logger.info("Base performance on dev: {}".format(results.dev_score_base))
         self._logger.info("Performance on dev: {}".format(results.dev_score))
 
+        self._logger.info(f'test_coherence: {results.test_coherence}')
+        self._logger.info(f'test_correlation: {results.test_correlation}')
+        self._logger.info(f'test_strength: {results.test_strength}')
+
         if type(model) not in [RandomForestRegressor, RandomForestClassifier]:
             results = ModelRawResults(
                 model_weights='',
                 training_time=self._end_time - self._begin_time,
                 datetime=datetime.datetime.now(),
-                train_score=self.__score_func(model, self._dataset.X_train, self._dataset.y_train, False),
-                dev_score=self.__score_func(model, self._dataset.X_dev, self._dataset.y_dev, False),
+                train_score=self.__score_func(model, X_forest, y_forest, False),
+                dev_score=self.__score_func(model, X_omp, y_omp, False),
                 test_score=self.__score_func(model, self._dataset.X_test, self._dataset.y_test, False),
-                train_score_base=self.__score_func_base(model, self._dataset.X_train, self._dataset.y_train),
-                dev_score_base=self.__score_func_base(model, self._dataset.X_dev, self._dataset.y_dev),
+                train_score_base=self.__score_func_base(model, X_forest, y_forest),
+                dev_score_base=self.__score_func_base(model, X_omp, y_omp),
                 test_score_base=self.__score_func_base(model, self._dataset.X_test, self._dataset.y_test),
                 score_metric=self._score_metric_name,
                 base_score_metric=self._base_score_metric_name
             )
             results.save(models_dir+'_no_weights')
             self._logger.info("Base performance on test without weights: {}".format(results.test_score_base))
-            self._logger.info("Performance on test: {}".format(results.test_score))
+            self._logger.info("Performance on test without weights: {}".format(results.test_score))
 
             self._logger.info("Base performance on train without weights: {}".format(results.train_score_base))
-            self._logger.info("Performance on train: {}".format(results.train_score))
+            self._logger.info("Performance on train without weights: {}".format(results.train_score))
 
             self._logger.info("Base performance on dev without weights: {}".format(results.dev_score_base))
-            self._logger.info("Performance on dev: {}".format(results.dev_score))
+            self._logger.info("Performance on dev without weights: {}".format(results.dev_score))
diff --git a/code/bolsonaro/visualization/plotter.py b/code/bolsonaro/visualization/plotter.py
index 5a5f72ad9fade836dcfed3c2ef6f452653dcf3d1..7d3154e1b15f153c85ef0be360fd990f4395fde5 100644
--- a/code/bolsonaro/visualization/plotter.py
+++ b/code/bolsonaro/visualization/plotter.py
@@ -51,6 +51,7 @@ class Plotter(object):
 
     @staticmethod
     def plot_mean_and_CI(ax, mean, lb, ub, x_value, color_mean=None, facecolor=None, label=None):
+        #print(x_value, mean, lb, ub)
         # plot the shaded range of the confidence intervals
         ax.fill_between(x_value, ub, lb, facecolor=facecolor, alpha=.5)
         # plot the mean on top
@@ -105,7 +106,7 @@ class Plotter(object):
 
     @staticmethod
     def plot_stage2_losses(file_path, all_experiment_scores, x_value,
-        xlabel, ylabel, all_labels, title):
+        xlabel, ylabel, all_labels, title, filter_num=-1):
 
         fig, ax = plt.subplots()
 
@@ -124,13 +125,14 @@ class Plotter(object):
             # Compute the mean and the std for the CI
             mean_experiment_scores = np.average(experiment_scores, axis=0)
             std_experiment_scores = np.std(experiment_scores, axis=0)
+
             # Plot the score curve with the CI
             Plotter.plot_mean_and_CI(
                 ax=ax,
                 mean=mean_experiment_scores,
                 lb=mean_experiment_scores + std_experiment_scores,
                 ub=mean_experiment_scores - std_experiment_scores,
-                x_value=x_value,
+                x_value=x_value[:filter_num] if len(mean_experiment_scores) == filter_num else x_value,
                 color_mean=colors[i],
                 facecolor=colors[i],
                 label=all_labels[i]
diff --git a/code/compute_results.py b/code/compute_results.py
index d77779e82e295b5e76c0347551c20b8ef258a546..23e3db3ad7c95e5f5732b4d09e945ce53dfd4467 100644
--- a/code/compute_results.py
+++ b/code/compute_results.py
@@ -2,11 +2,49 @@ from bolsonaro.models.model_raw_results import ModelRawResults
 from bolsonaro.visualization.plotter import Plotter
 from bolsonaro import LOG_PATH
 from bolsonaro.error_handling.logger_factory import LoggerFactory
+from bolsonaro.data.dataset_parameters import DatasetParameters
+from bolsonaro.data.dataset_loader import DatasetLoader
 
 import argparse
 import pathlib
 from dotenv import find_dotenv, load_dotenv
 import os
+import numpy as np
+import pickle
+from tqdm import tqdm
+from scipy.stats import rankdata
+from pyrsa.vis.colors import rdm_colormap
+from pyrsa.rdm.calc import calc_rdm
+from pyrsa.data.dataset import Dataset
+import matplotlib.pyplot as plt
+from sklearn.manifold import MDS
+from sklearn.preprocessing import normalize
+
+
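+# vect2triu/triu2full rebuild a full symmetric RDM from its strict upper triangle:
+# a vector v of length d*(d-1)/2 becomes the d x d matrix triu2full(vect2triu(v, d))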
+def vect2triu(dsm_vect, dim=None):
+    if not dim:
+        # the strict upper triangle holds d*(d-1)/2 values, so d = ceil(sqrt(2 * n_values))
+        dim = int(np.ceil(np.sqrt(dsm_vect.shape[1] * 2)))
+    dsm = np.zeros((dim,dim))
+    ind_up = np.triu_indices(dim, 1)
+    dsm[ind_up] = dsm_vect
+    return dsm
+
+def triu2full(dsm_triu):
+    dsm_full = np.copy(dsm_triu)
+    ind_low = np.tril_indices(dsm_full.shape[0], -1)
+    dsm_full[ind_low] = dsm_full.T[ind_low]
+    return dsm_full
+
+def plot_RDM(rdm, file_path, condition_number):
+    rdm = triu2full(vect2triu(rdm, condition_number))
+    fig = plt.figure()
+    cols = rdm_colormap(condition_number)
+    plt.imshow(rdm, cmap=cols)
+    plt.colorbar()
+    plt.savefig(file_path, dpi=200)
+    plt.close()
+
 
 
 def retreive_extracted_forest_sizes_number(models_dir, experiment_id):
@@ -17,7 +55,7 @@ def retreive_extracted_forest_sizes_number(models_dir, experiment_id):
     extracted_forest_sizes_root_path = experiment_seed_path + os.sep + 'extracted_forest_sizes'
     return len(os.listdir(extracted_forest_sizes_root_path))
 
-def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_dir, experiment_id, weights=True):
+def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_dir, experiment_id, weights=True, extracted_forest_sizes=list()):
     experiment_id_path = models_dir + os.sep + str(experiment_id) # models/{experiment_id}
     experiment_seed_root_path = experiment_id_path + os.sep + 'seeds' # models/{experiment_id}/seeds
 
@@ -45,10 +83,11 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d
         experiment_dev_scores[seed] = list()
         experiment_test_scores[seed] = list()
 
-        # List the forest sizes in models/{experiment_id}/seeds/{seed}/extracted_forest_sizes
-        extracted_forest_sizes = os.listdir(extracted_forest_sizes_root_path)
-        extracted_forest_sizes = [nb_tree for nb_tree in extracted_forest_sizes if not 'no_weights' in nb_tree ]
-        extracted_forest_sizes.sort(key=int)
+        if len(extracted_forest_sizes) == 0:
+            # List the forest sizes in models/{experiment_id}/seeds/{seed}/extracted_forest_sizes
+            extracted_forest_sizes = os.listdir(extracted_forest_sizes_root_path)
+            extracted_forest_sizes = [nb_tree for nb_tree in extracted_forest_sizes if not 'no_weights' in nb_tree ]
+            extracted_forest_sizes.sort(key=int)
         all_extracted_forest_sizes.append(list(map(int, extracted_forest_sizes)))
         for extracted_forest_size in extracted_forest_sizes:
             # models/{experiment_id}/seeds/{seed}/extracted_forest_sizes/{extracted_forest_size}
@@ -148,6 +187,180 @@ def extract_weights_across_seeds(models_dir, results_dir, experiment_id):
 
     return experiment_weights
 
+def extract_correlations_across_seeds(models_dir, results_dir, experiment_id):
+    experiment_id_path = models_dir + os.sep + str(experiment_id) # models/{experiment_id}
+    experiment_seed_root_path = experiment_id_path + os.sep + 'seeds' # models/{experiment_id}/seeds
+    experiment_correlations = dict()
+
+    # For each seed results stored in models/{experiment_id}/seeds
+    seeds = os.listdir(experiment_seed_root_path)
+    seeds.sort(key=int)
+    for seed in seeds:
+        experiment_seed_path = experiment_seed_root_path + os.sep + seed # models/{experiment_id}/seeds/{seed}
+        extracted_forest_sizes_root_path = experiment_seed_path + os.sep + 'extracted_forest_sizes' # models/{experiment_id}/seeds/{seed}/forest_size
+
+        # {{seed}:[]}
+        experiment_correlations[seed] = list()
+
+        # List the forest sizes in models/{experiment_id}/seeds/{seed}/extracted_forest_sizes
+        extracted_forest_sizes = os.listdir(extracted_forest_sizes_root_path)
+        extracted_forest_sizes = [nb_tree for nb_tree in extracted_forest_sizes if not 'no_weights' in nb_tree ]
+        extracted_forest_sizes.sort(key=int)
+        for extracted_forest_size in extracted_forest_sizes:
+            # models/{experiment_id}/seeds/{seed}/extracted_forest_sizes/{extracted_forest_size}
+            extracted_forest_size_path = extracted_forest_sizes_root_path + os.sep + extracted_forest_size
+            # Load models/{experiment_id}/seeds/{seed}/extracted_forest_sizes/{extracted_forest_size}/model_raw_results.pickle file
+            model_raw_results = ModelRawResults.load(extracted_forest_size_path)
+            experiment_correlations[seed].append(model_raw_results.correlation)
+
+    return experiment_correlations
+
+def extract_coherences_across_seeds(models_dir, results_dir, experiment_id):
+    experiment_id_path = models_dir + os.sep + str(experiment_id) # models/{experiment_id}
+    experiment_seed_root_path = experiment_id_path + os.sep + 'seeds' # models/{experiment_id}/seeds
+    experiment_coherences = dict()
+
+    # For each seed results stored in models/{experiment_id}/seeds
+    seeds = os.listdir(experiment_seed_root_path)
+    seeds.sort(key=int)
+    for seed in seeds:
+        experiment_seed_path = experiment_seed_root_path + os.sep + seed # models/{experiment_id}/seeds/{seed}
+        extracted_forest_sizes_root_path = experiment_seed_path + os.sep + 'extracted_forest_sizes' # models/{experiment_id}/seeds/{seed}/forest_size
+
+        # {{seed}:[]}
+        experiment_coherences[seed] = list()
+
+        # List the forest sizes in models/{experiment_id}/seeds/{seed}/extracted_forest_sizes
+        extracted_forest_sizes = os.listdir(extracted_forest_sizes_root_path)
+        extracted_forest_sizes = [nb_tree for nb_tree in extracted_forest_sizes if not 'no_weights' in nb_tree ]
+        extracted_forest_sizes.sort(key=int)
+        for extracted_forest_size in extracted_forest_sizes:
+            # models/{experiment_id}/seeds/{seed}/extracted_forest_sizes/{extracted_forest_size}
+            extracted_forest_size_path = extracted_forest_sizes_root_path + os.sep + extracted_forest_size
+            # Load models/{experiment_id}/seeds/{seed}/extracted_forest_sizes/{extracted_forest_size}/model_raw_results.pickle file
+            model_raw_results = ModelRawResults.load(extracted_forest_size_path)
+            experiment_coherences[seed].append(model_raw_results.coherence)
+
+    return experiment_coherences
+
+def extract_selected_trees_scores_across_seeds(models_dir, results_dir, experiment_id, weighted=False):
+    experiment_id_path = models_dir + os.sep + str(experiment_id) # models/{experiment_id}
+    experiment_seed_root_path = experiment_id_path + os.sep + 'seeds' # models/{experiment_id}/seeds
+    experiment_selected_trees_scores = dict()
+
+    print(f'[extract_selected_trees_scores_across_seeds] experiment_id: {experiment_id}')
+
+    # For each seed results stored in models/{experiment_id}/seeds
+    seeds = os.listdir(experiment_seed_root_path)
+    seeds.sort(key=int)
+    with tqdm(seeds) as seed_bar:
+        for seed in seed_bar:
+            seed_bar.set_description(f'seed: {seed}')
+            experiment_seed_path = experiment_seed_root_path + os.sep + seed # models/{experiment_id}/seeds/{seed}
+            extracted_forest_sizes_root_path = experiment_seed_path + os.sep + 'extracted_forest_sizes' # models/{experiment_id}/seeds/{seed}/forest_size
+
+            dataset_parameters = DatasetParameters.load(experiment_seed_path, experiment_id)
+            dataset = DatasetLoader.load(dataset_parameters)
+
+            # {{seed}:[]}
+            experiment_selected_trees_scores[seed] = list()
+
+            # List the forest sizes in models/{experiment_id}/seeds/{seed}/extracted_forest_sizes
+            extracted_forest_sizes = os.listdir(extracted_forest_sizes_root_path)
+            extracted_forest_sizes = [nb_tree for nb_tree in extracted_forest_sizes if not 'no_weights' in nb_tree]
+            extracted_forest_sizes.sort(key=int)
+            with tqdm(extracted_forest_sizes) as extracted_forest_size_bar:
+                for extracted_forest_size in extracted_forest_size_bar:
+                    # models/{experiment_id}/seeds/{seed}/extracted_forest_sizes/{extracted_forest_size}
+                    extracted_forest_size_path = extracted_forest_sizes_root_path + os.sep + extracted_forest_size
+                    selected_trees = None
+                    with open(os.path.join(extracted_forest_size_path, 'selected_trees.pickle'), 'rb') as file:
+                        selected_trees = pickle.load(file)
+                    selected_trees_test_scores = np.array([tree.score(dataset.X_test, dataset.y_test) for tree in selected_trees])
+
+                    if weighted:
+                        model_raw_results = ModelRawResults.load(extracted_forest_size_path)
+                        weights = model_raw_results.model_weights
+                        if type(weights) != str:
+                            weights = weights[weights != 0]
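+                            # weighted case: square the per-tree test scores scaled by their non-zero model weights before averaging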
+                            score = np.mean(np.square(selected_trees_test_scores * weights))
+                        else:
+                            score = np.mean(np.square(selected_trees_test_scores))
+                    else:
+                        score = np.mean(selected_trees_test_scores)
+                    experiment_selected_trees_scores[seed].append(score)
+                    extracted_forest_size_bar.set_description(f'extracted_forest_size: {extracted_forest_size} - test_score: {round(score, 2)}')
+                    extracted_forest_size_bar.update(1)
+            seed_bar.update(1)
+
+    return experiment_selected_trees_scores
+
+def extract_selected_trees_across_seeds(models_dir, results_dir, experiment_id):
+    experiment_id_path = models_dir + os.sep + str(experiment_id) # models/{experiment_id}
+    experiment_seed_root_path = experiment_id_path + os.sep + 'seeds' # models/{experiment_id}/seeds
+    experiment_selected_trees = dict()
+
+    # For each seed results stored in models/{experiment_id}/seeds
+    seeds = os.listdir(experiment_seed_root_path)
+    seeds.sort(key=int)
+    with tqdm(seeds) as seed_bar:
+        for seed in seed_bar:
+            seed_bar.set_description(f'seed: {seed}')
+            experiment_seed_path = experiment_seed_root_path + os.sep + seed # models/{experiment_id}/seeds/{seed}
+            extracted_forest_sizes_root_path = experiment_seed_path + os.sep + 'extracted_forest_sizes' # models/{experiment_id}/seeds/{seed}/forest_size
+
+            dataset_parameters = DatasetParameters.load(experiment_seed_path, experiment_id)
+            dataset = DatasetLoader.load(dataset_parameters)
+
+            # {{seed}:[]}
+            experiment_selected_trees[seed] = list()
+
+            # List the forest sizes in models/{experiment_id}/seeds/{seed}/extracted_forest_sizes
+            extracted_forest_sizes = os.listdir(extracted_forest_sizes_root_path)
+            extracted_forest_sizes = [nb_tree for nb_tree in extracted_forest_sizes if not 'no_weights' in nb_tree ]
+            extracted_forest_sizes.sort(key=int)
+            all_selected_trees_predictions = list()
+            with tqdm(extracted_forest_sizes) as extracted_forest_size_bar:
+                for extracted_forest_size in extracted_forest_size_bar:
+                    # models/{experiment_id}/seeds/{seed}/extracted_forest_sizes/{extracted_forest_size}
+                    extracted_forest_size_path = extracted_forest_sizes_root_path + os.sep + extracted_forest_size
+                    selected_trees = None
+                    with open(os.path.join(extracted_forest_size_path, 'selected_trees.pickle'), 'rb') as file:
+                        selected_trees = pickle.load(file)
+                    #test_score = np.mean([tree.score(dataset.X_test, dataset.y_test) for tree in selected_trees])
+                    #selected_trees_predictions = np.array([tree.score(dataset.X_test, dataset.y_test) for tree in selected_trees])
+                    selected_trees_predictions = [tree.predict(dataset.X_test) for tree in selected_trees]
+                    extracted_forest_size_bar.set_description(f'extracted_forest_size: {extracted_forest_size}')
+                    #experiment_selected_trees[seed].append(test_score)
+                    extracted_forest_size_bar.update(1)
+                    selected_trees_predictions = np.array(selected_trees_predictions)
+                    selected_trees_predictions = normalize(selected_trees_predictions)
+
+                    """mds = MDS(len(selected_trees_predictions))
+                    Y = mds.fit_transform(selected_trees_predictions)
+                    plt.scatter(Y[:, 0], Y[:, 1])
+                    plt.savefig(f'test_mds_{experiment_id}.png')"""
+
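+                    # each row of selected_trees_predictions is one tree's L2-normalised test-set predictions;
+                    # below, their pairwise (euclidean) RDM is computed and rank-transformed
+                    # (capped at 267 trees, presumably to keep the RDM/dendrogram computation tractable)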
+                    if int(extracted_forest_size) <= 267:
+                        forest_RDM = calc_rdm(Dataset(selected_trees_predictions), method='euclidean').get_vectors()
+                        ranked_forest_RDM = np.apply_along_axis(rankdata, 1, forest_RDM.reshape(1, -1))
+
+                        from scipy.cluster import hierarchy
+                        RDM = triu2full(vect2triu(ranked_forest_RDM, int(extracted_forest_size)))
+                        Z = hierarchy.linkage(RDM, 'average')
+                        fig = plt.figure(figsize=(15, 8))
+                        dn = hierarchy.dendrogram(Z)
+                        plt.savefig(f'test_dendrogram_scores_id:{experiment_id}_seed:{seed}_size:{extracted_forest_size}.png')
+                        plt.close()
+
+                        plot_RDM(
+                            rdm=ranked_forest_RDM,
+                            file_path=f'test_scores_ranked_forest_RDM_id:{experiment_id}_seed:{seed}_size:{extracted_forest_size}.png',
+                            condition_number=len(selected_trees_predictions)
+                        )
+            seed_bar.update(1)
+            # only the first seed is processed for this RDM analysis
+            break
+    return experiment_selected_trees
 
 if __name__ == "__main__":
     # get environment variables in .env
@@ -158,6 +371,8 @@ if __name__ == "__main__":
     DEFAULT_PLOT_WEIGHT_DENSITY = False
     DEFAULT_WO_LOSS_PLOTS = False
     DEFAULT_PLOT_PREDS_COHERENCE = False
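+    DEFAULT_PLOT_PREDS_CORRELATION = False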
+    DEFAULT_PLOT_FOREST_STRENGTH = False
+    DEFAULT_COMPUTE_SELECTED_TREES_RDMS = False
 
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument('--stage', nargs='?', type=int, required=True, help='Specify the stage number among [1, 5].')
@@ -172,6 +387,9 @@ if __name__ == "__main__":
     parser.add_argument('--plot_weight_density', action='store_true', default=DEFAULT_PLOT_WEIGHT_DENSITY, help='Plot the weight density. Only working for regressor models for now.')
     parser.add_argument('--wo_loss_plots', action='store_true', default=DEFAULT_WO_LOSS_PLOTS, help='Do not compute the loss plots.')
     parser.add_argument('--plot_preds_coherence', action='store_true', default=DEFAULT_PLOT_PREDS_COHERENCE, help='Plot the coherence of the prediction trees.')
+    parser.add_argument('--plot_preds_correlation', action='store_true', default=DEFAULT_PLOT_PREDS_CORRELATION, help='Plot the correlation of the prediction trees.')
+    parser.add_argument('--plot_forest_strength', action='store_true', default=DEFAULT_PLOT_FOREST_STRENGTH, help='Plot the strength of the extracted forest.')
+    parser.add_argument('--compute_selected_trees_rdms', action='store_true', default=DEFAULT_COMPUTE_SELECTED_TREES_RDMS, help='Representation similarity analysis of the selected trees')
     args = parser.parse_args()
 
     if args.stage not in list(range(1, 6)):
@@ -437,6 +655,16 @@ if __name__ == "__main__":
         all_labels = list()
         all_scores = list()
 
+        """extracted_forest_sizes = np.unique(np.around(1000 *
+            np.linspace(0, 1.0,
+            30 + 1,
+            endpoint=True)[1:]).astype(np.int)).tolist()"""
+
+        #extracted_forest_sizes = [4, 7, 11, 14, 18, 22, 25, 29, 32, 36, 40, 43, 47, 50, 54, 58, 61, 65, 68, 72, 76, 79, 83, 86, 90, 94, 97, 101, 104, 108]
+
+        #extracted_forest_sizes = [str(forest_size) for forest_size in extracted_forest_sizes]
+        extracted_forest_sizes = list()
+
         # base_with_params
         logger.info('Loading base_with_params experiment scores...')
         base_with_params_train_scores, base_with_params_dev_scores, base_with_params_test_scores, \
@@ -447,53 +675,74 @@ if __name__ == "__main__":
         logger.info('Loading random_with_params experiment scores...')
         random_with_params_train_scores, random_with_params_dev_scores, random_with_params_test_scores, \
             with_params_extracted_forest_sizes, random_with_params_experiment_score_metric = \
-            extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, int(args.experiment_ids[1]))
+            extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, int(args.experiment_ids[1]),
+            extracted_forest_sizes=extracted_forest_sizes)
         # omp_with_params
         logger.info('Loading omp_with_params experiment scores...')
         omp_with_params_train_scores, omp_with_params_dev_scores, omp_with_params_test_scores, _, \
             omp_with_params_experiment_score_metric = extract_scores_across_seeds_and_extracted_forest_sizes(
-                args.models_dir, args.results_dir, int(args.experiment_ids[2]))
+                args.models_dir, args.results_dir, int(args.experiment_ids[2]), extracted_forest_sizes=extracted_forest_sizes)
         #omp_with_params_without_weights
         logger.info('Loading omp_with_params without weights experiment scores...')
         omp_with_params_without_weights_train_scores, omp_with_params_without_weights_dev_scores, omp_with_params_without_weights_test_scores, _, \
             omp_with_params_experiment_score_metric = extract_scores_across_seeds_and_extracted_forest_sizes(
-                args.models_dir, args.results_dir, int(args.experiment_ids[2]), weights=False)
-
-        all_labels = ['base', 'random', 'omp', 'omp_without_weights']
+                args.models_dir, args.results_dir, int(args.experiment_ids[2]), weights=False, extracted_forest_sizes=extracted_forest_sizes)
+
+        """print(omp_with_params_dev_scores)
+        import sys
+        sys.exit(0)"""
+
+        all_labels = ['base', 'random', 'omp', 'omp_wo_weights']
+        #all_labels = ['base', 'random', 'omp']
+        omp_with_params_test_scores_new = dict()
+        filter_num = -1
+        """filter_num = 9
+        for key, value in omp_with_params_test_scores.items():
+            omp_with_params_test_scores_new[key] = value[:filter_num]"""
         all_scores = [base_with_params_test_scores, random_with_params_test_scores, omp_with_params_test_scores,
             omp_with_params_without_weights_test_scores]
+        #all_scores = [base_with_params_dev_scores, random_with_params_dev_scores, omp_with_params_dev_scores,
+        #    omp_with_params_without_weights_dev_scores]
+        #all_scores = [base_with_params_train_scores, random_with_params_train_scores, omp_with_params_train_scores,
+        #    omp_with_params_without_weights_train_scores]
 
         for i in range(3, len(args.experiment_ids)):
             if 'kmeans' in args.experiment_ids[i]:
                 label = 'kmeans'
-            elif 'similarity' in args.experiment_ids[i]:
-                label = 'similarity'
+            elif 'similarity_similarities' in args.experiment_ids[i]:
+                label = 'similarity_similarities'
+            elif 'similarity_predictions' in args.experiment_ids[i]:
+                label = 'similarity_predictions'
             elif 'ensemble' in args.experiment_ids[i]:
                 label = 'ensemble'
+            elif 'omp_distillation' in args.experiment_ids[i]:
+                label = 'omp_distillation'
             else:
                 logger.error('Invalid value encountered')
                 continue
 
             logger.info(f'Loading {label} experiment scores...')
             current_experiment_id = int(args.experiment_ids[i].split('=')[1])
-            _, _, current_test_scores, _, _ = extract_scores_across_seeds_and_extracted_forest_sizes(
+            current_train_scores, current_dev_scores, current_test_scores, _, _ = extract_scores_across_seeds_and_extracted_forest_sizes(
                 args.models_dir, args.results_dir, current_experiment_id)
             all_labels.append(label)
             all_scores.append(current_test_scores)
+            #all_scores.append(current_train_scores)
+            #all_scores.append(current_dev_scores)
 
-        output_path = os.path.join(args.results_dir, args.dataset_name, 'stage5')
+        output_path = os.path.join(args.results_dir, args.dataset_name, 'stage5_test_train,dev')
         pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)
 
         Plotter.plot_stage2_losses(
-            file_path=output_path + os.sep + f"losses_{'-'.join(all_labels)}.png",
+            file_path=output_path + os.sep + f"losses_{'-'.join(all_labels)}_test_train,dev.png",
             all_experiment_scores=all_scores,
             all_labels=all_labels,
             x_value=with_params_extracted_forest_sizes,
             xlabel='Number of trees extracted',
             ylabel=base_with_params_experiment_score_metric,
-            title='Loss values of {}\nusing best params of previous stages'.format(args.dataset_name))
+            title='Loss values of {}\nusing best params of previous stages'.format(args.dataset_name), filter_num=filter_num)
 
-    if args.plot_weight_density:
+    """if args.plot_weight_density:
         root_output_path = os.path.join(args.results_dir, args.dataset_name, f'stage{args.stage}')
 
         if args.stage == 1:
@@ -528,6 +777,118 @@ if __name__ == "__main__":
         for (experiment_label, experiment_id) in omp_experiment_ids:
             logger.info(f'Computing weight density plot for experiment {experiment_label}...')
             experiment_weights = extract_weights_across_seeds(args.models_dir, args.results_dir, experiment_id)
-            Plotter.weight_density(experiment_weights, os.path.join(root_output_path, f'weight_density_{experiment_label}.png'))
+            Plotter.weight_density(experiment_weights, os.path.join(root_output_path, f'weight_density_{experiment_label}.png'))"""
+
+    if args.plot_weight_density:
+        logger.info(f'Computing weight density plot for experiment {experiment_label}...')
+        experiment_weights = extract_weights_across_seeds(args.models_dir, args.results_dir, experiment_id)
+        Plotter.weight_density(experiment_weights, os.path.join(root_output_path, f'weight_density_{experiment_label}.png'))
+    if args.plot_preds_coherence:
+        root_output_path = os.path.join(args.results_dir, args.dataset_name, f'stage5_new')
+        pathlib.Path(root_output_path).mkdir(parents=True, exist_ok=True)
+        all_labels = ['random', 'omp', 'kmeans', 'similarity_similarities', 'similarity_predictions', 'ensemble']
+        _, _, _, with_params_extracted_forest_sizes, _ = \
+            extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, 2)
+        coherence_values = [extract_coherences_across_seeds(args.models_dir, args.results_dir, i) for i in args.experiment_ids]
+        Plotter.plot_stage2_losses(
+            file_path=root_output_path + os.sep + f"coherences_{'-'.join(all_labels)}.png",
+            all_experiment_scores=coherence_values,
+            all_labels=all_labels,
+            x_value=with_params_extracted_forest_sizes,
+            xlabel='Number of trees extracted',
+            ylabel='Coherence',
+            title='Coherence values of {}'.format(args.dataset_name))
+        logger.info(f'Computing preds coherence plot...')
+
+    if args.plot_preds_correlation:
+        root_output_path = os.path.join(args.results_dir, args.dataset_name, f'stage5_new')
+        pathlib.Path(root_output_path).mkdir(parents=True, exist_ok=True)
+        all_labels = ['none', 'random', 'omp', 'kmeans', 'similarity_similarities', 'similarity_predictions', 'ensemble']
+        _, _, _, with_params_extracted_forest_sizes, _ = \
+            extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, 2)
+        correlation_values = [extract_correlations_across_seeds(args.models_dir, args.results_dir, i) for i in args.experiment_ids]
+        Plotter.plot_stage2_losses(
+            file_path=root_output_path + os.sep + f"correlations_{'-'.join(all_labels)}.png",
+            all_experiment_scores=correlation_values,
+            all_labels=all_labels,
+            x_value=with_params_extracted_forest_sizes,
+            xlabel='Number of trees extracted',
+            ylabel='Correlation',
+            title='Correlation values of {}'.format(args.dataset_name))
+        logger.info(f'Computing preds correlation plot...')
+
+    if args.plot_forest_strength:
+        root_output_path = os.path.join(args.results_dir, args.dataset_name, f'stage5_strength')
+        pathlib.Path(root_output_path).mkdir(parents=True, exist_ok=True)
+
+        _, _, _, with_params_extracted_forest_sizes, _ = \
+                extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, 2)
+        #all_selected_trees_scores = list()
+        #all_selected_trees_weighted_scores = list()
+        """with tqdm(args.experiment_ids) as experiment_id_bar:
+            for experiment_id in experiment_id_bar:
+                experiment_id_bar.set_description(f'experiment_id: {experiment_id}')
+                selected_trees_scores, selected_trees_weighted_scores = extract_selected_trees_scores_across_seeds(
+                    args.models_dir, args.results_dir, experiment_id)
+                all_selected_trees_scores.append(selected_trees_scores)
+                all_selected_trees_weighted_scores.append(selected_trees_weighted_scores)
+                experiment_id_bar.update(1)"""
+
+        random_selected_trees_scores = extract_selected_trees_scores_across_seeds(
+            args.models_dir, args.results_dir, 2, weighted=True)
+
+        omp_selected_trees_scores = extract_selected_trees_scores_across_seeds(
+            args.models_dir, args.results_dir, 3, weighted=True)
+
+        similarity_similarities_selected_trees_scores = extract_selected_trees_scores_across_seeds(
+            args.models_dir, args.results_dir, 6, weighted=True)
+
+        #similarity_predictions_selected_trees_scores = extract_selected_trees_scores_across_seeds(
+        #    args.models_dir, args.results_dir, 7)
+
+        ensemble_selected_trees_scores = extract_selected_trees_scores_across_seeds(
+            args.models_dir, args.results_dir, 8, weighted=True)
+
+        # kmeans=5
+        # similarity_similarities=6
+        # similarity_predictions=7
+        # ensemble=8
+
+        all_selected_trees_scores = [random_selected_trees_scores, omp_selected_trees_scores, similarity_similarities_selected_trees_scores,
+            ensemble_selected_trees_scores]
+
+        with open('california_housing_forest_strength_scores.pickle', 'wb') as file:
+            pickle.dump(all_selected_trees_scores, file)
+
+        """with open('forest_strength_scores.pickle', 'rb') as file:
+            all_selected_trees_scores = pickle.load(file)"""
+
+        all_labels = ['random', 'omp', 'similarity_similarities', 'ensemble']
+
+        Plotter.plot_stage2_losses(
+            file_path=root_output_path + os.sep + f"forest_strength_{'-'.join(all_labels)}_v2_sota.png",
+            all_experiment_scores=all_selected_trees_scores,
+            all_labels=all_labels,
+            x_value=with_params_extracted_forest_sizes,
+            xlabel='Number of trees extracted',
+            ylabel='Mean of selected tree scores on test set',
+            title='Forest strength of {}'.format(args.dataset_name))
+
+    if args.compute_selected_trees_rdms:
+        root_output_path = os.path.join(args.results_dir, args.dataset_name, f'stage5_strength')
+        pathlib.Path(root_output_path).mkdir(parents=True, exist_ok=True)
+
+        _, _, _, with_params_extracted_forest_sizes, _ = \
+                extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, 2)
+        all_selected_trees_scores = list()
+        with tqdm([2, 3, 8]) as experiment_id_bar:
+            for experiment_id in experiment_id_bar:
+                experiment_id_bar.set_description(f'experiment_id: {experiment_id}')
+                all_selected_trees_scores.append(extract_selected_trees_across_seeds(
+                    args.models_dir, args.results_dir, experiment_id))
+                experiment_id_bar.update(1)
+
+        with open('forest_strength_scores.pickle', 'rb') as file:
+            all_selected_trees_scores = pickle.load(file)
 
     logger.info('Done.')
diff --git a/code/playground/nn_omp.py b/code/playground/nn_omp.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/code/prepare_models.py b/code/prepare_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cd9ea37033063652e15e0e1c84432b831b6562e
--- /dev/null
+++ b/code/prepare_models.py
@@ -0,0 +1,32 @@
+import pathlib
+import glob2
+import os
+import shutil
+from tqdm import tqdm
+
+
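+# Copies every model_raw_results.pickle found under models/<dataset>/stage5_new/**
+# into the destination directory, preserving the relative sub-directory structure.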
+if __name__ == "__main__":
+    models_source_path = 'models'
+    models_destination_path = 'bolsonaro_models_25-03-20'
+    #datasets = ['boston', 'diabetes', 'linnerud', 'breast_cancer', 'california_housing', 'diamonds',
+    #    'steel-plates', 'kr-vs-kp', 'kin8nm', 'spambase', 'gamma', 'lfw_pairs']
+    datasets = ['kin8nm']
+
+    pathlib.Path(models_destination_path).mkdir(parents=True, exist_ok=True)
+
+    with tqdm(datasets) as dataset_bar:
+        for dataset in dataset_bar:
+            dataset_bar.set_description(dataset)
+            found_paths = glob2.glob(os.path.join(models_source_path, dataset, 'stage5_new',
+                '**', 'model_raw_results.pickle'), recursive=True)
+            pathlib.Path(os.path.join(models_destination_path, dataset)).mkdir(parents=True, exist_ok=True)
+            with tqdm(found_paths) as found_paths_bar:
+                for path in found_paths_bar:
+                    found_paths_bar.set_description(path)
+                    new_path = path.replace(f'{models_source_path}/{dataset}/stage5_new/', '')
+                    (new_path, filename) = os.path.split(new_path)
+                    new_path = os.path.join(models_destination_path, dataset, new_path)
+                    pathlib.Path(new_path).mkdir(parents=True, exist_ok=True)
+                    shutil.copyfile(src=path, dst=os.path.join(new_path, filename))
+                    found_paths_bar.update(1)
+            dataset_bar.update(1)
diff --git a/code/train.py b/code/train.py
index 95498cdf03a894ca8c8cf91d6702acc6aef1a799..10dbf7354837cab803202a8307c671f0def0f274 100644
--- a/code/train.py
+++ b/code/train.py
@@ -55,11 +55,6 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb
 
     trainer = Trainer(dataset)
 
-    if parameters['extraction_strategy'] == 'ensemble':
-        library = EnsembleSelectionForestRegressor.generate_library(dataset.X_train, dataset.y_train, random_state=seed)
-    else:
-        library = None
-
     if parameters['extraction_strategy'] == 'random':
         pretrained_model_parameters = ModelParameters(
             extracted_forest_size=parameters['forest_size'],
@@ -70,12 +65,12 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb
             hyperparameters=hyperparameters,
             extraction_strategy=parameters['extraction_strategy']
         )
-        pretrained_estimator = ModelFactory.build(dataset.task, pretrained_model_parameters, library=library)
-        pretraned_trainer = Trainer(dataset)
-        pretraned_trainer.init(pretrained_estimator, subsets_used=parameters['subsets_used'])
+        pretrained_estimator = ModelFactory.build(dataset.task, pretrained_model_parameters)
+        pretrained_trainer = Trainer(dataset)
+        pretrained_trainer.init(pretrained_estimator, subsets_used=parameters['subsets_used'])
         pretrained_estimator.fit(
-            X=pretraned_trainer._X_forest,
-            y=pretraned_trainer._y_forest
+            X=pretrained_trainer._X_forest,
+            y=pretrained_trainer._y_forest
         )
     else:
         pretrained_estimator = None
@@ -84,8 +79,9 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb
     if parameters['extraction_strategy'] != 'none':
         with tqdm_joblib(tqdm(total=len(parameters['extracted_forest_size']), disable=not verbose)) as extracted_forest_size_job_pb:
             Parallel(n_jobs=-1)(delayed(extracted_forest_size_job)(extracted_forest_size_job_pb, parameters['extracted_forest_size'][i],
-                models_dir, seed, parameters, dataset, hyperparameters, experiment_id, trainer, library,
-                pretrained_estimator=pretrained_estimator, pretrained_model_parameters=pretrained_model_parameters)
+                models_dir, seed, parameters, dataset, hyperparameters, experiment_id, trainer,
+                pretrained_estimator=pretrained_estimator, pretrained_model_parameters=pretrained_model_parameters,
+                use_distillation=parameters['extraction_strategy'] == 'omp_distillation')
                 for i in range(len(parameters['extracted_forest_size'])))
     else:
         forest_size = hyperparameters['n_estimators']
@@ -97,11 +93,11 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb
         if os.path.isdir(sub_models_dir):
             sub_models_dir_files = os.listdir(sub_models_dir)
             for file_name in sub_models_dir_files:
-                if '.pickle' != os.path.splitext(file_name)[1]:
-                    continue
-                else:
+                if file_name == 'model_raw_results.pickle':
                     already_exists = os.path.getsize(os.path.join(sub_models_dir, file_name)) > 0
                     break
+                else:
+                    continue
         if already_exists:
             logger.info('Base forest result already exists. Skipping...')
         else:
@@ -117,7 +113,7 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb
             )
             model_parameters.save(sub_models_dir, experiment_id)
 
-            model = ModelFactory.build(dataset.task, model_parameters, library=library)
+            model = ModelFactory.build(dataset.task, model_parameters)
 
             trainer.init(model, subsets_used=parameters['subsets_used'])
             trainer.train(model)
@@ -126,8 +122,8 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb
     seed_job_pb.update(1)
 
 def extracted_forest_size_job(extracted_forest_size_job_pb, extracted_forest_size, models_dir,
-    seed, parameters, dataset, hyperparameters, experiment_id, trainer, library,
-    pretrained_estimator=None, pretrained_model_parameters=None):
+    seed, parameters, dataset, hyperparameters, experiment_id, trainer,
+    pretrained_estimator=None, pretrained_model_parameters=None, use_distillation=False):
 
     logger = LoggerFactory.create(LOG_PATH, 'training_seed{}_extracted_forest_size{}_ti{}'.format(
         seed, extracted_forest_size, threading.get_ident()))
@@ -140,11 +136,11 @@ def extracted_forest_size_job(extracted_forest_size_job_pb, extracted_forest_siz
     if os.path.isdir(sub_models_dir):
         sub_models_dir_files = os.listdir(sub_models_dir)
         for file_name in sub_models_dir_files:
-            if '.pickle' != os.path.splitext(file_name)[1]:
-                continue
-            else:
+            if file_name == 'model_raw_results.pickle':
                 already_exists = os.path.getsize(os.path.join(sub_models_dir, file_name)) > 0
                 break
+            else:
+                continue
     if already_exists:
         logger.info(f'Extracted forest {extracted_forest_size} result already exists. Skipping...')
         return
@@ -162,13 +158,14 @@ def extracted_forest_size_job(extracted_forest_size_job_pb, extracted_forest_siz
             extraction_strategy=parameters['extraction_strategy']
         )
         model_parameters.save(sub_models_dir, experiment_id)
-        model = ModelFactory.build(dataset.task, model_parameters, library=library)
+        model = ModelFactory.build(dataset.task, model_parameters)
     else:
         model = copy.deepcopy(pretrained_estimator)
         pretrained_model_parameters.save(sub_models_dir, experiment_id)
 
     trainer.init(model, subsets_used=parameters['subsets_used'])
-    trainer.train(model, extracted_forest_size=extracted_forest_size)
+    trainer.train(model, extracted_forest_size=extracted_forest_size, seed=seed,
+        use_distillation=use_distillation)
     trainer.compute_results(model, sub_models_dir)
 
 """
@@ -235,7 +232,7 @@ if __name__ == "__main__":
     parser.add_argument('--skip_best_hyperparams', action='store_true', default=DEFAULT_SKIP_BEST_HYPERPARAMS, help='Do not use the best hyperparameters if there exist.')
     parser.add_argument('--save_experiment_configuration', nargs='+', default=None, help='Save the experiment parameters specified in the command line in a file. Args: {{stage_num}} {{name}}')
     parser.add_argument('--job_number', nargs='?', type=int, default=DEFAULT_JOB_NUMBER, help='Specify the number of job used during the parallelisation across seeds.')
-    parser.add_argument('--extraction_strategy', nargs='?', type=str, default=DEFAULT_EXTRACTION_STRATEGY, help='Specify the strategy to apply to extract the trees from the forest. Either omp, random, none, similarity, kmeans, ensemble.')
+    parser.add_argument('--extraction_strategy', nargs='?', type=str, default=DEFAULT_EXTRACTION_STRATEGY, help='Specify the strategy to apply to extract the trees from the forest. Either omp, random, none, similarity_similarities, similarity_predictions, kmeans, ensemble.')
     parser.add_argument('--overwrite', action='store_true', default=DEFAULT_OVERWRITE, help='Overwrite the experiment id')
     args = parser.parse_args()
 
@@ -246,8 +243,8 @@ if __name__ == "__main__":
     else:
         parameters = args.__dict__
 
-    if parameters['extraction_strategy'] not in ['omp', 'random', 'none', 'similarity', 'kmeans', 'ensemble']:
-        raise ValueError('Specified extraction strategy {} is not supported.'.format(parameters.extraction_strategy))
+    if parameters['extraction_strategy'] not in ['omp', 'omp_distillation', 'random', 'none', 'similarity_similarities', 'similarity_predictions', 'kmeans', 'ensemble']:
+        raise ValueError('Specified extraction strategy {} is not supported.'.format(parameters['extraction_strategy']))
 
     pathlib.Path(parameters['models_dir']).mkdir(parents=True, exist_ok=True)
 
diff --git a/code/vizualisation/csv_to_figure.py b/code/vizualisation/csv_to_figure.py
new file mode 100644
index 0000000000000000000000000000000000000000..244314ba9fcbb49ab48762c09dd63e4d69df9cf5
--- /dev/null
+++ b/code/vizualisation/csv_to_figure.py
@@ -0,0 +1,172 @@
+from dotenv import load_dotenv, find_dotenv
+from pathlib import Path
+import os
+import pandas as pd
+import numpy as np
+import plotly.graph_objects as go
+import plotly.io as pio
+
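+# Reads results.csv and, for each (task, dataset) pair, plots the mean +/- std of the task
+# value against the number of selected trees for every (non-skipped) pruning strategy.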
+
+lst_skip_strategy = ["None", "OMP Distillation", "OMP Distillation w/o weights"]
+# lst_skip_subset = ["train/dev"]
+lst_task_train_dev = ["coherence", "correlation"]
+
+tasks = [
+    # "train_score",
+    # "dev_score",
+    # "test_score",
+    "coherence",
+    "correlation",
+    # "negative-percentage"
+]
+
+dct_score_metric_fancy = {
+    "accuracy_score": "% Accuracy",
+    "mean_squared_error": "MSE"
+}
+
+pio.templates.default = "plotly_white"
+
+dct_color_by_strategy = {
+    "OMP": (255, 0, 0), # red
+    "OMP Distillation": (255, 0, 0), # red
+    "OMP Distillation w/o weights": (255, 128, 0), # orange
+    "OMP w/o weights": (255, 128, 0), # orange
+    "Random": (0, 0, 0), # black
+    "Zhang Similarities": (255, 255, 0), # jaune
+    'Zhang Predictions': (128, 0, 128), # turquoise
+    'Ensemble': (0, 0, 255), # blue
+    "Kmeans": (0, 255, 0) # red
+}
+
+dct_dash_by_strategy = {
+    "OMP": None,
+    "OMP Distillation": "dash",
+    "OMP Distillation w/o weights": "dash",
+    "OMP w/o weights": None,
+    "Random": "dot",
+    "Zhang Similarities": "dash",
+    'Zhang Predictions': "dash",
+    'Ensemble': "dash",
+    "Kmeans": "dash"
+}
+
+def add_trace_from_df(df, fig):
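+    # NOTE: relies on the module-level globals `task`, `strat` and `tpl_transparency` set by the caller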
+    df.sort_values(by="forest_size", inplace=True)
+    df_groupby_forest_size = df.groupby(['forest_size'])
+    forest_sizes = list(df_groupby_forest_size["forest_size"].mean().values)
+    mean_value = df_groupby_forest_size[task].mean().values
+    std_value = df_groupby_forest_size[task].std().values
+    std_value_upper = list(mean_value + std_value)
+    std_value_lower = list(mean_value - std_value)
+    # print(df_strat)
+    fig.add_trace(go.Scatter(x=forest_sizes, y=mean_value,
+                             mode='lines',
+                             name=strat,
+                             line=dict(dash=dct_dash_by_strategy[strat], color="rgb{}".format(dct_color_by_strategy[strat]))
+                             ))
+
+    fig.add_trace(go.Scatter(
+        x=forest_sizes + forest_sizes[::-1],
+        y=std_value_upper + std_value_lower[::-1],
+        fill='toself',
+        showlegend=False,
+        fillcolor='rgba{}'.format(dct_color_by_strategy[strat] + tpl_transparency),
+        line_color='rgba(255,255,255,0)',
+        name=strat
+    ))
+
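+# alpha appended to the RGB colour tuples to build the translucent rgba fill of the std band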
+tpl_transparency = (0.1,)
+
+if __name__ == "__main__":
+
+    load_dotenv(find_dotenv('.env'))
+    dir_name = "bolsonaro_models_25-03-20"
+    dir_path = Path(os.environ["project_dir"]) / "results" / dir_name
+
+    out_dir = Path(os.environ["project_dir"]) / "reports/figures" / dir_name
+
+    input_dir_file = dir_path / "results.csv"
+    df_results = pd.read_csv(open(input_dir_file, 'rb'))
+
+    datasets = set(df_results["dataset"].values)
+    strategies = set(df_results["strategy"].values)
+    subsets = set(df_results["subset"].values)
+
+    for task in tasks:
+        for data_name in datasets:
+            df_data = df_results[df_results["dataset"] == data_name]
+            score_metric_name = df_data["score_metric"].values[0]
+
+            fig = go.Figure()
+
+            ##################
+            # all techniques #
+            ##################
+            for strat in strategies:
+                if strat in lst_skip_strategy:
+                    continue
+                df_strat = df_data[df_data["strategy"] == strat]
+                df_strat = df_strat[df_strat["subset"] == "train+dev/train+dev"]
+
+                if "OMP" in strat:
+                    ###########################
+                    # processing with weights #
+                    ###########################
+                    df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
+                    if data_name == "Boston":
+                        df_strat_wo_weights = df_strat_wo_weights[df_strat_wo_weights["forest_size"] < 400]
+                    add_trace_from_df(df_strat_wo_weights, fig)
+
+                #################################
+                # general processing wo_weights #
+                #################################
+                if "OMP" in strat:
+                    df_strat_wo_weights = df_strat[df_strat["wo_weights"] == True]
+                else:
+                    df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
+
+                if "OMP" in strat:
+                    strat = "{} w/o weights".format(strat)
+
+                add_trace_from_df(df_strat_wo_weights, fig)
+
+            title = "{} {}".format(task, data_name)
+            yaxis_title = "% negative weights" if task == "negative-percentage" else dct_score_metric_fancy[score_metric_name]
+            fig.update_layout(barmode='group',
+                              title=title,
+                              xaxis_title="# Selected Trees",
+                              yaxis_title=yaxis_title,
+                              font=dict(
+                                  # family="Courier New, monospace",
+                                  size=24,
+                                  color="black"
+                              ),
+                                showlegend = False,
+                                margin = dict(
+                                    l=1,
+                                    r=1,
+                                    b=1,
+                                    t=1,
+                                    # pad=4
+                                ),
+                              legend=dict(
+                                  traceorder="normal",
+                                  font=dict(
+                                      family="sans-serif",
+                                      size=24,
+                                      color="black"
+                                  ),
+                                  # bgcolor="LightSteelBlue",
+                                  # bordercolor="Black",
+                                  borderwidth=1,
+                              )
+                              )
+            # fig.show()
+            sanitize = lambda x: x.replace(" ", "_").replace("/", "_").replace("+", "_")
+            filename = sanitize(title)
+            output_dir = out_dir / sanitize(task)
+            output_dir.mkdir(parents=True, exist_ok=True)
+            fig.write_image(str((output_dir / filename).absolute()) + ".png")
+
+            # exit()
diff --git a/code/vizualisation/csv_to_table.py b/code/vizualisation/csv_to_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..440e5fc8454732e40af3cce667a04d4677d032af
--- /dev/null
+++ b/code/vizualisation/csv_to_table.py
@@ -0,0 +1,310 @@
+import copy
+
+from dotenv import load_dotenv, find_dotenv
+from pathlib import Path
+import os
+import pandas as pd
+import numpy as np
+from pprint import pprint
+import plotly.graph_objects as go
+import plotly.io as pio
+from collections import defaultdict
+
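+# Builds a LaTeX results table: for each dataset and pruning strategy, the best mean test score
+# (over seeds) among the 10 smallest forest sizes is reported with its std and the forest size reaching it.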
+lst_skip_strategy = ["None", "OMP Distillation", "OMP Distillation w/o weights"]
+lst_skip_task = ["correlation", "coherence"]
+# lst_skip_task = []
+lst_skip_subset = ["train/dev"]
+# lst_skip_subset = []
+
+tasks = [
+    # "train_score",
+    # "dev_score",
+    "test_score",
+    # "coherence",
+    # "correlation"
+]
+
+dct_score_metric_fancy = {
+    "accuracy_score": "% Accuracy",
+    "mean_squared_error": "MSE"
+}
+dct_score_metric_best_fct = {
+    "accuracy_score": np.argmax,
+    "mean_squared_error": np.argmin
+}
+
+dct_data_short = {
+    "Spambase": "Spambase",
+    "Diamonds": "Diamonds",
+    "Diabetes": "Diabetes",
+    "Steel Plates": "Steel P.",
+    "KR-VS-KP": "KR-VS-KP",
+    "Breast Cancer": "Breast C.",
+    "Kin8nm": "Kin8nm",
+    "LFW Pairs": "LFW P.",
+    "Gamma": "Gamma",
+    "California Housing": "California H.",
+    "Boston": "Boston",
+}
+
+dct_data_best = {
+    "Spambase": np.max,
+    "Diamonds": np.min,
+    "Diabetes": np.min,
+    "Steel Plates": np.max,
+    "KR-VS-KP": np.max,
+    "Breast Cancer": np.max,
+    "Kin8nm": np.min,
+    "LFW Pairs": np.max,
+    "Gamma": np.max,
+    "California Housing": np.min,
+    "Boston": np.min,
+}
+dct_data_metric = {
+    "Spambase": "Acc.",
+    "Diamonds": "MSE",
+    "Diabetes": "MSE",
+    "Steel Plates": "Acc.",
+    "KR-VS-KP": "Acc.",
+    "Breast Cancer": "Acc.",
+    "Kin8nm": "MSE",
+    "LFW Pairs": "Acc.",
+    "Gamma": "Acc.",
+    "California Housing": "MSE",
+    "Boston": "MSE",
+}
+
+
+
+def get_max_from_df(df, best_fct):
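+    # NOTE: `task`, `strat`, `data_name` and `subset_name` are module-level globals set in the main loop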
+    nb_to_consider = 10
+    df.sort_values(by="forest_size", inplace=True)
+    df_groupby_forest_size = df.groupby(['forest_size'])
+    forest_sizes = list(df_groupby_forest_size["forest_size"].mean().values)[:nb_to_consider]
+    mean_value = df_groupby_forest_size[task].mean().values[:nb_to_consider]
+    std_value = df_groupby_forest_size[task].std().values[:nb_to_consider]
+
+    try:
+        argmax = best_fct(mean_value)
+    except ValueError:
+        print("no results", strat, data_name, task, subset_name)
+        return -1, -1, -1
+
+    max_mean = mean_value[argmax]
+    max_std = std_value[argmax]
+    max_forest_size = forest_sizes[argmax]
+
+    return max_forest_size, max_mean, max_std
+
+
+
+if __name__ == "__main__":
+
+    load_dotenv(find_dotenv('.env'))
+    dir_name = "bolsonaro_models_25-03-20"
+    dir_path = Path(os.environ["project_dir"]) / "results" / dir_name
+
+    out_dir = Path(os.environ["project_dir"]) / "reports/figures" / dir_name
+
+    input_dir_file = dir_path / "results.csv"
+    df_results = pd.read_csv(open(input_dir_file, 'rb'))
+
+    datasets = set(df_results["dataset"].values)
+    strategies = sorted(list(set(df_results["strategy"].values) - set(lst_skip_strategy)))
+    subsets = set(df_results["subset"].values)
+
+    r"""
+    \begin{table}[!h]
+    \centering
+    \begin{tabular}{l{}}
+    \toprule
+    \multicolumn{1}{c}{\textbf{Dataset}} & \textbf{Data dim.} $\datadim$        & \textbf{\# classes} & \textbf{Train size} $\nexamples$ & \textbf{Test size} $\nexamples'$ \\ \midrule
+    \texttt{MNIST}~\cite{lecun-mnisthandwrittendigit-2010}                      & 784    & 10        & 60 000    & 10 000               \\ %\hline
+    \texttt{Kddcup99}~\cite{Dua:2019}                                           & 116    & 23      & 4 893 431      & 5 000               \\ 
+    \bottomrule
+    \end{tabular}
+    \caption{Main features of the datasets. Discrete, unordered attributes for dataset Kddcup99 have been encoded as one-hot attributes.}
+    \label{table:data}
+    \end{table}
+    """
+
+
+    for task in tasks:
+        if task in lst_skip_task:
+            continue
+
+        dct_data_lst_tpl_results = defaultdict(lambda: [])
+
+        lst_strats = []
+        for data_name in datasets:
+            df_data = df_results[df_results["dataset"] == data_name]
+            score_metric_name = df_data["score_metric"].values[0]
+
+            for subset_name in subsets:
+                if subset_name in lst_skip_subset:
+                    continue
+                df_subset = df_data[df_data["subset"] == subset_name]
+
+                ##################
+                # all techniques #
+                ##################
+                for strat in strategies:
+                    if strat in lst_skip_strategy:
+                        continue
+                    df_strat = df_subset[df_subset["strategy"] == strat]
+
+                    if "OMP" in strat:
+                        ###########################
+                        # processing with weights #
+                        ###########################
+                        df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
+                        if data_name == "Boston" and subset_name == "train+dev/train+dev":
+                            df_strat_wo_weights = df_strat_wo_weights[df_strat_wo_weights["forest_size"] < 400]
+                        dct_data_lst_tpl_results[data_name].append(get_max_from_df(df_strat_wo_weights, dct_score_metric_best_fct[score_metric_name]))
+                        if strat not in lst_strats: lst_strats.append(strat)
+
+                    if "OMP" in strat and subset_name == "train/dev":
+                        continue
+                    elif "Random" not in strat and subset_name == "train/dev":
+                        continue
+
+                    #################################
+                    # general processing wo_weights #
+                    #################################
+                    if "Random" in strat:
+                        df_strat_wo_weights = df_strat[df_strat["wo_weights"] == False]
+                    else:
+                        df_strat_wo_weights = df_strat[df_strat["wo_weights"] == True]
+
+                    if "OMP" in strat:
+                        strat = "{} w/o weights".format(strat)
+
+                    dct_data_lst_tpl_results[data_name].append(get_max_from_df(df_strat_wo_weights, dct_score_metric_best_fct[score_metric_name]))
+                    if strat not in lst_strats: lst_strats.append(strat)
+
+                title = "{} {} {}".format(task, data_name, subset_name)
+
+                # fig.show()
+                sanitize = lambda x: x.replace(" ", "_").replace("/", "_").replace("+", "_")
+                filename = sanitize(title)
+                # output_dir = out_dir / sanitize(subset_name) / sanitize(task)
+                # output_dir.mkdir(parents=True, exist_ok=True)
+                # fig.write_image(str((output_dir / filename).absolute()) + ".png")
+
+
+        # pprint(dct_data_lst_tpl_results)
+
+        lst_data_ordered = [
+            "Diamonds",
+            "Diabetes",
+            "Kin8nm",
+            "California Housing",
+            "Boston",
+            "Spambase",
+            "Steel Plates",
+            "KR-VS-KP",
+            "Breast Cancer",
+            "LFW Pairs",
+            "Gamma"
+        ]
+
+
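+        # arr_results_str: rows = strategies (plus a header row), columns = datasets (plus a header
+        # column); it is transposed before printing so that each printed LaTeX line is one dataset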
+        arr_results_str = np.empty((len(lst_strats) + 1, len(datasets) + 1), dtype="object")
+        nb_spaces = 25
+        dct_strat_str = defaultdict(lambda: [])
+        s_empty = "{}" + " "*(nb_spaces-2) + " & "
+        arr_results_str[0][0] = s_empty
+        # arr_results_str[0][1] = s_empty
+        for idx_data, data_name in enumerate(lst_data_ordered):
+            lst_tpl_results = dct_data_lst_tpl_results[data_name]
+            data_name_short = dct_data_short[data_name]
+            s_data_tmp = "{}".format(data_name_short)
+            s_data_tmp += "({})".format(dct_data_metric[data_name])
+            # s_data_tmp = "\\texttt{{ {} }}".format(data_name_short)
+            # s_data_tmp = "\\multicolumn{{2}}{{c}}{{ \\texttt{{ {} }} }}".format(data_name)
+            s_data_tmp += " "*(nb_spaces - len(data_name_short))
+            s_data_tmp += " & "
+            arr_results_str[0, idx_data + 1] = s_data_tmp
+
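+            # locate the best and second-best mean scores (np.min for error metrics, np.max for
+            # accuracy, as given by dct_data_best) and the smallest pruned forest size, so that
+            # they can be highlighted in the table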
+            array_results = np.array(lst_tpl_results)
+            best_result_perf = dct_data_best[data_name](array_results[:, 1])
+            best_result_perf_indexes = np.argwhere(array_results[:, 1] == best_result_perf)
+
+            copy_array_results = copy.deepcopy(array_results)
+            if dct_data_best[data_name] is np.min:
+                copy_array_results[best_result_perf_indexes] = np.inf
+            else:
+                copy_array_results[best_result_perf_indexes] = -np.inf
+
+            best_result_perf_2 = dct_data_best[data_name](copy_array_results[:, 1])
+            best_result_perf_indexes_2 = np.argwhere(copy_array_results[:, 1] == best_result_perf_2)
+
+            best_result_prune = np.min(array_results[:, 0])
+            best_result_prune_indexes = np.argwhere(array_results[:, 0] == best_result_prune)
+
+            for idx_strat, tpl_results in enumerate(array_results):
+                # str_strat = "\\texttt{{ {} }}".format(lst_strats[idx_strat])
+                # str_strat = "\\multicolumn{{2}}{{c}}{{ \\texttt{{ {} }} }}".format(lst_strats[idx_strat])
+                # str_strat = "\\multicolumn{{2}}{{c}}{{ \\thead{{ \\texttt{{ {} }} }} }}".format("}\\\\ \\texttt{".join(lst_strats[idx_strat].split(" ", 1)))
+                str_strat = "\\multicolumn{{2}}{{c}}{{ \\thead{{ {} }} }} ".format("\\\\".join(lst_strats[idx_strat].split(" ", 1)))
+                str_strat += " " * (nb_spaces - len(str_strat)) + " & "
+                arr_results_str[idx_strat+1, 0] =  str_strat
+
+                # str_header = " {} & #tree &".format(dct_data_metric[data_name])
+                # arr_results_str[idx_strat + 1, 1] = str_header
+
+                best_forest_size = tpl_results[0]
+                best_mean = tpl_results[1]
+                best_std = tpl_results[2]
+                if dct_data_metric[data_name] == "Acc.":
+                    str_perf = "{:.2f}\\%".format(best_mean * 100)
+                else:
+                    str_perf = "{:.3E}".format(best_mean)
+
+                str_prune = "{:d}".format(int(best_forest_size))
+
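+                # cell formatting: bold = best score across strategies, underline = second best,
+                # italics = smallest pruned forest size (possibly combined with bold/underline)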
+                if idx_strat in best_result_perf_indexes:
+                    # str_formating = "\\textbf{{ {} }}".format(str_result_loc)
+                    str_formating = "\\textbf[{}]"
+                    # str_formating = "\\textbf{{ {:.3E} }}(\\~{:.3E})".format(best_mean, best_std)
+                elif idx_strat in best_result_perf_indexes_2:
+                    str_formating = "\\underline[{}]"
+                    # str_formating = "\\underline{{ {:.3E} }}(\\~{:.3E})".format(best_mean, best_std)
+                else:
+                    str_formating = "{}"
+                    # str_formating = "{:.3E}(~{:.3E})".format(best_mean, best_std)
+
+                if idx_strat in best_result_prune_indexes:
+                    str_formating = str_formating.format("\\textit[{}]")
+                    # str_prune = " & \\textit{{ {:d} }}".format(int(best_forest_size))
+                # else:
+                #     str_prune = " & {:d}".format(int(best_forest_size))
+                str_result = str_formating.format(str_perf) + " & " + str_formating.format(str_prune)
+                str_result += " "*(nb_spaces - len(str_result))
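+                # square brackets stand in for LaTeX braces so that str.format does not treat them
+                # as replacement fields; swap them back to braces now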
+                str_result = str_result.replace("[", "{").replace("]", "}")
+
+                arr_results_str[idx_strat+1, idx_data+1] = str_result + " & "
+                dct_strat_str[lst_strats[idx_strat]].append(str_result)
+
+        arr_results_str = arr_results_str.T
+        for idx_lin, lin in enumerate(arr_results_str):
+            if idx_lin == 1:
+                print("\\midrule")
+            if idx_lin == 6:
+                print("\\midrule")
+            if idx_lin > 0 and lst_data_ordered[idx_lin - 1] == "Diamonds":
+                print("%", end="")
+            line_print = " ".join(list(lin))
+            line_print = line_print.rstrip(" &") + "\\\\"
+            print(line_print)
+
+        # s_data = s_data.rstrip(" &") + "\\\\"
+        # print(s_data)
+        # for strat, lst_str_results in dct_strat_str.items():
+        #     str_strat = "\\texttt{{ {} }}".format(strat)
+        #     str_strat += " "*(nb_spaces - len(str_strat))
+        #     str_strat += " & " + " & ".join(lst_str_results)
+        #     str_strat += "\\\\"
+        #     print(str_strat)
+
diff --git a/code/vizualisation/results_to_csv.py b/code/vizualisation/results_to_csv.py
new file mode 100644
index 0000000000000000000000000000000000000000..669451b1f812f7f83584670790196601f1a5f40e
--- /dev/null
+++ b/code/vizualisation/results_to_csv.py
@@ -0,0 +1,154 @@
+from pathlib import Path
+import os
+import pandas as pd
+from pprint import pprint
+import pickle
+from collections import defaultdict
+import numpy as np
+
+from dotenv import load_dotenv, find_dotenv
+
+
+dct_experiment_id_subset = dict((str(idx), "train+dev/train+dev") for idx in range(1, 9))
+dct_experiment_id_subset.update(dict((str(idx), "train/dev") for idx in range(9, 17)))
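+# experiment ids 1-8 use the train+dev/train+dev subsets, ids 9-16 the train/dev subsets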
+
+NONE = 'None'
+Random = 'Random'
+OMP = 'OMP'
+OMP_Distillation = 'OMP Distillation'
+Kmeans = 'Kmeans'
+Zhang_Similarities = 'Zhang Similarities'
+Zhang_Predictions = 'Zhang Predictions'
+Ensemble = 'Ensemble'
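+# experiment id -> pruning strategy; the ids match those launched in scripts/run_stage5_experiments.sh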
+dct_experiment_id_technique = {"1": NONE,
+                               "2": Random,
+                               "3": OMP,
+                               "4": OMP_Distillation,
+                               "5": Kmeans,
+                               "6": Zhang_Similarities,
+                               "7": Zhang_Predictions,
+                               "8": Ensemble,
+                               "9": NONE,
+                               "10": Random,
+                               "11": OMP,
+                               "12": OMP_Distillation,
+                               "13": Kmeans,
+                               "14": Zhang_Similarities,
+                               "15": Zhang_Predictions,
+                               "16": Ensemble
+                               }
+
+
+dct_dataset_fancy = {
+    "boston": "Boston",
+    "breast_cancer": "Breast Cancer",
+    "california_housing": "California Housing",
+    "diabetes": "Diabetes",
+    "diamonds": "Diamonds",
+    "digits": "Digits",
+    "iris": "Iris",
+    "kin8nm": "Kin8nm",
+    "kr-vs-kp": "KR-VS-KP",
+    "olivetti_faces": "Olivetti Faces",
+    "spambase": "Spambase",
+    "steel-plates": "Steel Plates",
+    "wine": "Wine",
+    "gamma": "Gamma",
+    "lfw_pairs": "LFW Pairs"
+}
+
+skip_attributes = ["datetime"]
+set_no_coherence = set()
+set_no_corr = set()
+
+if __name__ == "__main__":
+
+    load_dotenv(find_dotenv('.env'))
+    dir_name = "results/bolsonaro_models_25-03-20"
+    dir_path = Path(os.environ["project_dir"]) / dir_name
+
+    output_dir_file = dir_path / "results.csv"
+
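+    # one list per output column; each result pickle file contributes one row to the final dataframe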
+    dct_results = defaultdict(lambda: [])
+
+    for root, dirs, files in os.walk(dir_path, topdown=False):
+        for file_str in files:
+            if file_str == "results.csv":
+                continue
+            path_dir = Path(root)
+            path_file = path_dir / file_str
+            print(path_file)
+            try:
+                with open(path_file, 'rb') as pickle_file:
+                    obj_results = pickle.load(pickle_file)
+            except Exception:
+                print("problem loading pickle file {}".format(path_file))
+                continue
+
+            path_dir_split = str(path_dir).split("/")
+
+            bool_wo_weights = "no_weights" in str(path_file)
+
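+            # the pruned forest size is encoded in the last directory component; for "no_weights"
+            # runs it is the prefix before the first underscore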
+            if bool_wo_weights:
+                forest_size = int(path_dir_split[-1].split("_")[0])
+            else:
+                forest_size = int(path_dir_split[-1])
+
+            seed = int(path_dir_split[-3])
+            id_xp = str(path_dir_split[-5])
+            dataset = str(path_dir_split[-6])
+
+            dct_results["forest_size"].append(forest_size)
+            dct_results["seed"].append(seed)
+            dct_results["dataset"].append(dct_dataset_fancy[dataset])
+            dct_results["subset"].append(dct_experiment_id_subset[id_xp])
+            dct_results["strategy"].append(dct_experiment_id_technique[id_xp])
+            dct_results["wo_weights"].append(bool_wo_weights)
+
+            for key_result, val_result in obj_results.items():
+                if key_result in skip_attributes:
+                    continue
+                if key_result == "model_weights":
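+                    # record the fraction of negative weights among the non-zero OMP weights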
+                    if val_result == "":
+                        dct_results["negative-percentage"].append(None)
+                    else:
+                        lt_zero = val_result < 0
+                        gt_zero = val_result > 0
+
+                        nb_lt_zero = np.sum(lt_zero)
+                        nb_gt_zero = np.sum(gt_zero)
+
+                        percentage_lt_zero = nb_lt_zero / (nb_gt_zero + nb_lt_zero)
+                        dct_results["negative-percentage"].append(percentage_lt_zero)
+                if val_result == "":
+                    # print(key_result, val_result)
+                    val_result = None
+                if key_result == "coherence" and val_result is None:
+                    set_no_coherence.add(id_xp)
+                if key_result == "correlation" and val_result is None:
+                    set_no_corr.add(id_xp)
+
+                dct_results[key_result].append(val_result)
+
+                # example content of obj_results:
+                # {'model_weights': '',
+                #                 'training_time': 0.0032033920288085938,
+                #                 'datetime': datetime.datetime(2020, 3, 25, 0, 28, 34, 938400),
+                #                 'train_score': 1.0,
+                #                 'dev_score': 0.978021978021978,
+                #                 'test_score': 0.9736842105263158,
+                #                 'train_score_base': 1.0,
+                #                 'dev_score_base': 0.978021978021978,
+                #                 'test_score_base': 0.9736842105263158,
+                #                 'score_metric': 'accuracy_score',
+                #                 'base_score_metric': 'accuracy_score',
+                #                 'coherence': 0.9892031711775613,
+                #                 'correlation': 0.9510700193340448}
+
+            # print(path_file)
+
+    print("coh", set_no_coherence, len(set_no_coherence))
+    print("cor", set_no_corr, len(set_no_corr))
+
+
+    final_df = pd.DataFrame.from_dict(dct_results)
+    final_df.to_csv(output_dir_file)
+    print(final_df)
diff --git a/requirements.txt b/requirements.txt
index 38a47c2beeff7ee073c27b9dd7ed9cabfbc12c4f..ef5021d7e1d513be852d7af1bbfae18e95ca08ac 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,15 +1,79 @@
-# local package
--e .
-
-# external requirements
-click
-Sphinx
-coverage
-awscli
-flake8
-pytest
-scikit-learn
-git+git://github.com/darenr/scikit-optimize@master
-python-dotenv
-matplotlib
-pandas
+alabaster==0.7.12
+attrs==19.3.0
+awscli==1.16.272
+Babel==2.7.0
+backcall==0.1.0
+-e git+git@gitlab.lis-lab.fr:luc.giffon/bolsonaro.git@bbad0e522d6b4b392f1926fa935f2a7fac093411#egg=bolsonaro
+botocore==1.13.8
+certifi==2019.11.28
+chardet==3.0.4
+Click==7.0
+colorama==0.4.1
+coverage==4.5.4
+cycler==0.10.0
+decorator==4.4.2
+docutils==0.15.2
+entrypoints==0.3
+flake8==3.7.9
+idna==2.8
+imagesize==1.1.0
+importlib-metadata==1.5.0
+ipython==7.13.0
+ipython-genutils==0.2.0
+jedi==0.16.0
+Jinja2==2.10.3
+jmespath==0.9.4
+joblib==0.14.0
+kiwisolver==1.1.0
+MarkupSafe==1.1.1
+matplotlib==3.1.1
+mccabe==0.6.1
+mkl-fft==1.0.14
+mkl-random==1.1.0
+mkl-service==2.3.0
+more-itertools==8.2.0
+numpy==1.17.3
+packaging==20.3
+pandas==0.25.3
+parso==0.6.2
+pexpect==4.8.0
+pickleshare==0.7.5
+plotly==4.5.2
+pluggy==0.13.1
+prompt-toolkit==3.0.3
+psutil==5.7.0
+ptyprocess==0.6.0
+py==1.8.1
+pyaml==20.3.1
+pyasn1==0.4.7
+pycodestyle==2.5.0
+pyflakes==2.1.1
+Pygments==2.6.1
+pyparsing==2.4.5
+pytest==5.4.1
+python-dateutil==2.8.1
+python-dotenv==0.10.3
+pytz==2019.3
+PyYAML==5.1.2
+requests==2.22.0
+retrying==1.3.3
+rsa==3.4.2
+s3transfer==0.2.1
+scikit-learn==0.21.3
+scikit-optimize==0.7.4
+scipy==1.3.1
+six==1.12.0
+snowballstemmer==2.0.0
+Sphinx==2.2.1
+sphinxcontrib-applehelp==1.0.1
+sphinxcontrib-devhelp==1.0.1
+sphinxcontrib-htmlhelp==1.0.2
+sphinxcontrib-jsmath==1.0.1
+sphinxcontrib-qthelp==1.0.2
+sphinxcontrib-serializinghtml==1.1.3
+tornado==6.0.3
+tqdm==4.43.0
+traitlets==4.3.3
+urllib3==1.25.6
+wcwidth==0.1.8
+zipp==2.2.0
diff --git a/scripts/run_stage5_experiments.sh b/scripts/run_stage5_experiments.sh
index 4bd371187b186004585640645e9da738aa752c47..fe8c3d64d4d196b1718ab1fced20390004c85a61 100755
--- a/scripts/run_stage5_experiments.sh
+++ b/scripts/run_stage5_experiments.sh
@@ -1,16 +1,27 @@
 #!/bin/bash
 core_number=5
-core_number_sota=50
-walltime=1:00
+core_number_sota=5
+walltime=5:00
 walltime_sota=5:00
 seeds='1 2 3 4 5'
 
-for dataset in kin8nm kr-vs-kp spambase steel-plates diabetes diamonds boston california_housing
+for dataset in boston diabetes linnerud breast_cancer california_housing diamonds steel-plates kr-vs-kp kin8nm spambase musk gamma lfw_pairs
 do
-    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=$walltime "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=none --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=1 --models_dir=models/$dataset/stage5 --subsets_used train+dev,train+dev"
-    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=$walltime "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=random --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=2 --models_dir=models/$dataset/stage5 --subsets_used train+dev,train+dev"
-    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=$walltime "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=omp --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=3 --models_dir=models/$dataset/stage5 --subsets_used train+dev,train+dev"
-    oarsub -p "(gpu is null)" -l /core=$core_number_sota,walltime=$walltime_sota "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=similarity --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=4 --models_dir=models/$dataset/stage5 --subsets_used train+dev,train+dev"
-    oarsub -p "(gpu is null)" -l /core=$core_number_sota,walltime=$walltime_sota "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=kmeans --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=5 --models_dir=models/$dataset/stage5 --subsets_used train+dev,train+dev"
-    oarsub -p "(gpu is null)" -l /core=$core_number_sota,walltime=$walltime_sota "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=ensemble --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=6 --models_dir=models/$dataset/stage5 --subsets_used train+dev,train+dev"
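+    # experiments 1-8: one job per pruning strategy, all trained with --subsets_used train+dev,train+dev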
+    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=$walltime "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=none --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=1 --models_dir=models/$dataset/stage5_new --subsets_used train+dev,train+dev"
+    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=$walltime "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=random --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=2 --models_dir=models/$dataset/stage5_new --subsets_used train+dev,train+dev"
+    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=$walltime "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=omp --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=3 --models_dir=models/$dataset/stage5_new --subsets_used train+dev,train+dev --normalize_D"
+    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=$walltime "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=omp_distillation --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=4 --models_dir=models/$dataset/stage5_new --subsets_used train+dev,train+dev --normalize_D"
+    oarsub -p "(gpu is null)" -l /core=$core_number_sota,walltime=$walltime_sota "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=kmeans --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=5 --models_dir=models/$dataset/stage5_new --subsets_used train+dev,train+dev"
+    oarsub -p "(gpu is null)" -l /core=$core_number_sota,walltime=$walltime_sota "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=similarity_similarities --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=6 --models_dir=models/$dataset/stage5_new --subsets_used train+dev,train+dev"
+    oarsub -p "(gpu is null)" -l /core=$core_number_sota,walltime=$walltime_sota "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=similarity_predictions --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=7 --models_dir=models/$dataset/stage5_new --subsets_used train+dev,train+dev"
+    oarsub -p "(gpu is null)" -l /core=$core_number_sota,walltime=$walltime_sota "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=ensemble --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=8 --models_dir=models/$dataset/stage5_new --subsets_used train+dev,train+dev"
+
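+    # experiments 9-16: same strategies, but with --subsets_used train,dev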
+    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=$walltime "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=none --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=9 --models_dir=models/$dataset/stage5_new --subsets_used train,dev"
+    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=$walltime "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=random --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=10 --models_dir=models/$dataset/stage5_new --subsets_used train,dev"
+    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=$walltime "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=omp --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=11 --models_dir=models/$dataset/stage5_new --subsets_used train,dev --normalize_D"
+    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=$walltime "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=omp_distillation --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=12 --models_dir=models/$dataset/stage5_new --subsets_used train,dev --normalize_D"
+    oarsub -p "(gpu is null)" -l /core=$core_number_sota,walltime=$walltime_sota "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=kmeans --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=13 --models_dir=models/$dataset/stage5_new --subsets_used train,dev"
+    oarsub -p "(gpu is null)" -l /core=$core_number_sota,walltime=$walltime_sota "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=similarity_similarities --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=14 --models_dir=models/$dataset/stage5_new --subsets_used train,dev"
+    oarsub -p "(gpu is null)" -l /core=$core_number_sota,walltime=$walltime_sota "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=similarity_predictions --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=15 --models_dir=models/$dataset/stage5_new --subsets_used train,dev"
+    oarsub -p "(gpu is null)" -l /core=$core_number_sota,walltime=$walltime_sota "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=ensemble --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=16 --models_dir=models/$dataset/stage5_new --subsets_used train,dev"
 done