diff --git a/code/bolsonaro/models/model_raw_results.py b/code/bolsonaro/models/model_raw_results.py
index 7d6106513fd62830974006b8349f8d1d90e67a7a..3f7af5fcd31c1eb105a3dd39695e1ddc69f38676 100644
--- a/code/bolsonaro/models/model_raw_results.py
+++ b/code/bolsonaro/models/model_raw_results.py
@@ -10,7 +10,9 @@ class ModelRawResults(object):
         datetime, train_score, dev_score, test_score,
         train_score_base, dev_score_base,
         test_score_base, score_metric, base_score_metric,
-        coherence='', correlation=''):
+        train_coherence='', dev_coherence='', test_coherence='',
+        train_correlation='', dev_correlation='', test_correlation='',
+        train_strength='', dev_strength='', test_strength=''):
 
         self._model_weights = model_weights
         self._training_time = training_time
@@ -23,8 +25,17 @@ class ModelRawResults(object):
         self._test_score_base = test_score_base
         self._score_metric = score_metric
         self._base_score_metric = base_score_metric
-        self._coherence = coherence
-        self._correlation = correlation
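+        # Tree-diversity (coherence, correlation) and strength statistics,
+        # stored separately for the train, dev and test splits.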
+        self._train_coherence = train_coherence
+        self._dev_coherence = dev_coherence
+        self._test_coherence = test_coherence
+        self._train_correlation = train_correlation
+        self._dev_correlation = dev_correlation
+        self._test_correlation = test_correlation
+        self._train_strength = train_strength
+        self._dev_strength = dev_strength
+        self._test_strength = test_strength
 
     @property
     def model_weights(self):
@@ -70,13 +81,41 @@
     def base_score_metric(self):
         return self._base_score_metric
 
-    @property
-    def coherence(self):
-        return self._coherence
-
-    @property
-    def correlation(self):
-        return self._correlation
+    @property
+    def train_coherence(self):
+        return self._train_coherence
+
+    @property
+    def dev_coherence(self):
+        return self._dev_coherence
+
+    @property
+    def test_coherence(self):
+        return self._test_coherence
+
+    @property
+    def train_correlation(self):
+        return self._train_correlation
+
+    @property
+    def dev_correlation(self):
+        return self._dev_correlation
+
+    @property
+    def test_correlation(self):
+        return self._test_correlation
+
+    @property
+    def train_strength(self):
+        return self._train_strength
+
+    @property
+    def dev_strength(self):
+        return self._dev_strength
+
+    @property
+    def test_strength(self):
+        return self._test_strength
 
     def save(self, models_dir):
         if not os.path.exists(models_dir):
diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py
index 7761c534f69e415b4708bea76cf4917017a6dbe0..78f2c082e4a9c20dfe7b6b5dfa2d5d49aca99cc2 100644
--- a/code/bolsonaro/trainer.py
+++ b/code/bolsonaro/trainer.py
@@ -39,7 +39,6 @@ class Trainer(object):
             else classification_score_metric.__name__
         self._base_score_metric_name = base_regression_score_metric.__name__ if dataset.task == Task.REGRESSION \
             else base_classification_score_metric.__name__
-        self._selected_trees = ''
 
     @property
     def score_metric_name(self):
@@ -98,7 +97,6 @@ class Trainer(object):
                     X=self._X_forest,
                     y=self._y_forest
                 )
-            self._selected_trees = model.estimators_
         else:
             if type(model) in [OmpForestRegressor, OmpForestBinaryClassifier, OmpForestMulticlassClassifier] and \
                 use_distillation:
@@ -154,14 +152,24 @@
             result = self._base_regression_score_metric(y_true, y_pred)
         return result
 
-    def _evaluate_predictions(self, model, X, aggregation_function):
-        predictions = np.array([tree.predict(X) for tree in self._selected_trees])
+    def _evaluate_predictions(self, X, aggregation_function, selected_trees):
+        predictions = np.array([tree.predict(X) for tree in selected_trees])
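+        # Each tree's prediction vector is L2-normalized below, so that
+        # predictions @ predictions.T holds the pairwise cosine similarities;
+        # with the diagonal zeroed, np.max gives the coherence, np.mean the mean correlation.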
 
         predictions = normalize(predictions)
 
         return aggregation_function(np.abs((predictions @ predictions.T - np.eye(len(predictions)))))
 
-    def compute_results(self, model, models_dir):
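+    # The strength of the extracted forest is measured as the average individual
+    # score of its selected trees, each tree being evaluated on its own with the
+    # task's score metric (no ensemble aggregation and no tree weighting).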
+    def _compute_forest_strength(self, X, y, metric_function, selected_trees):
+        return np.mean([metric_function(y, tree.predict(X)) for tree in selected_trees])
+
+    def compute_results(self, model, models_dir, subsets_used='train+dev,train+dev'):
         """
-        :param model: Object with
+        :param model: The trained model whose scores and diversity statistics are computed
+        :param subsets_used: Which subsets were used to fit the forest and the weights
         :param models_dir: Where the results will be saved
@@ -177,30 +185,72 @@
 
         if type(model) in [SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor, 
             SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]:
-            self._selected_trees = model.selected_trees
+            selected_trees = model.selected_trees
         elif type(model) in [OmpForestRegressor, OmpForestMulticlassClassifier, OmpForestBinaryClassifier]:
-            self._selected_trees = np.asarray(model.forest)[model._omp.coef_ != 0]
+            selected_trees = np.asarray(model.forest)[model._omp.coef_ != 0]
         elif type(model) in [RandomForestRegressor, RandomForestClassifier]:
-            self._selected_trees = model.estimators_
+            selected_trees = model.estimators_
 
-        if len(self._selected_trees) > 0:
+        if len(selected_trees) > 0:
+            target_selected_tree = int(os.path.split(models_dir)[-1])
+            if target_selected_tree != len(selected_trees):
+                raise ValueError(f'Expected {target_selected_tree} selected trees (from the models_dir name) but got {len(selected_trees)}')
             with open(os.path.join(models_dir, 'selected_trees.pickle'), 'wb') as output_file:
-                pickle.dump(self._selected_trees, output_file)
+                pickle.dump(selected_trees, output_file)
+
+        strength_metric = self._regression_score_metric if self._dataset.task == Task.REGRESSION else self._classification_score_metric
+
+        # Recomputing the evaluation subsets here (rather than caching them on the trainer) keeps compute_results thread-safe.
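+        # X_forest/y_forest: samples used to fit the forest; X_omp/y_omp: samples
+        # used to fit the tree weights (mirroring the subsets selection of Trainer.init).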
+        if type(model) in [RandomForestRegressor, RandomForestClassifier]:
+            if subsets_used == 'train,dev':
+                X_forest = self._dataset.X_train
+                y_forest = self._dataset.y_train
+            else:
+                X_forest = np.concatenate([self._dataset.X_train, self._dataset.X_dev])
+                y_forest = np.concatenate([self._dataset.y_train, self._dataset.y_dev])
+            X_omp = self._dataset.X_dev
+            y_omp = self._dataset.y_dev
+        elif model.models_parameters.subsets_used == 'train,dev':
+            X_forest = self._dataset.X_train
+            y_forest = self._dataset.y_train
+            X_omp = self._dataset.X_dev
+            y_omp = self._dataset.y_dev
+        elif model.models_parameters.subsets_used == 'train+dev,train+dev':
+            X_forest = np.concatenate([self._dataset.X_train, self._dataset.X_dev])
+            X_omp = X_forest
+            y_forest = np.concatenate([self._dataset.y_train, self._dataset.y_dev])
+            y_omp = y_forest
+        elif model.models_parameters.subsets_used == 'train,train+dev':
+            X_forest = self._dataset.X_train
+            y_forest = self._dataset.y_train
+            X_omp = np.concatenate([self._dataset.X_train, self._dataset.X_dev])
+            y_omp = np.concatenate([self._dataset.y_train, self._dataset.y_dev])
+        else:
+            raise ValueError("Unknown specified subsets_used parameter '{}'".format(model.models_parameters.subsets_used))
 
         results = ModelRawResults(
             model_weights=model_weights,
             training_time=self._end_time - self._begin_time,
             datetime=datetime.datetime.now(),
-            train_score=self.__score_func(model, self._dataset.X_train, self._dataset.y_train),
-            dev_score=self.__score_func(model, self._dataset.X_dev, self._dataset.y_dev),
+            train_score=self.__score_func(model, X_forest, y_forest),
+            dev_score=self.__score_func(model, X_omp, y_omp),
             test_score=self.__score_func(model, self._dataset.X_test, self._dataset.y_test),
-            train_score_base=self.__score_func_base(model, self._dataset.X_train, self._dataset.y_train),
-            dev_score_base=self.__score_func_base(model, self._dataset.X_dev, self._dataset.y_dev),
+            train_score_base=self.__score_func_base(model, X_forest, y_forest),
+            dev_score_base=self.__score_func_base(model, X_omp, y_omp),
             test_score_base=self.__score_func_base(model, self._dataset.X_test, self._dataset.y_test),
             score_metric=self._score_metric_name,
             base_score_metric=self._base_score_metric_name,
-            coherence=self._evaluate_predictions(model, self._dataset.X_train, aggregation_function=np.max),
-            correlation=self._evaluate_predictions(model, self._dataset.X_train, aggregation_function=np.mean)
+            train_coherence=self._evaluate_predictions(X_forest, aggregation_function=np.max, selected_trees=selected_trees),
+            dev_coherence=self._evaluate_predictions(X_omp, aggregation_function=np.max, selected_trees=selected_trees),
+            test_coherence=self._evaluate_predictions(self._dataset.X_test, aggregation_function=np.max, selected_trees=selected_trees),
+            train_correlation=self._evaluate_predictions(X_forest, aggregation_function=np.mean, selected_trees=selected_trees),
+            dev_correlation=self._evaluate_predictions(X_omp, aggregation_function=np.mean, selected_trees=selected_trees),
+            test_correlation=self._evaluate_predictions(self._dataset.X_test, aggregation_function=np.mean, selected_trees=selected_trees),
+            train_strength=self._compute_forest_strength(X_forest, y_forest, strength_metric, selected_trees),
+            dev_strength=self._compute_forest_strength(X_omp, y_omp, strength_metric, selected_trees),
+            test_strength=self._compute_forest_strength(self._dataset.X_test, self._dataset.y_test, strength_metric, selected_trees)
         )
         results.save(models_dir)
         self._logger.info("Base performance on test: {}".format(results.test_score_base))
@@ -212,16 +262,20 @@
         self._logger.info("Base performance on dev: {}".format(results.dev_score_base))
         self._logger.info("Performance on dev: {}".format(results.dev_score))
 
+        self._logger.info(f'test_coherence: {results.test_coherence}')
+        self._logger.info(f'test_correlation: {results.test_correlation}')
+        self._logger.info(f'test_strength: {results.test_strength}')
+
         if type(model) not in [RandomForestRegressor, RandomForestClassifier]:
             results = ModelRawResults(
                 model_weights='',
                 training_time=self._end_time - self._begin_time,
                 datetime=datetime.datetime.now(),
-                train_score=self.__score_func(model, self._dataset.X_train, self._dataset.y_train, False),
-                dev_score=self.__score_func(model, self._dataset.X_dev, self._dataset.y_dev, False),
+                train_score=self.__score_func(model, X_forest, y_forest, False),
+                dev_score=self.__score_func(model, X_omp, y_omp, False),
                 test_score=self.__score_func(model, self._dataset.X_test, self._dataset.y_test, False),
-                train_score_base=self.__score_func_base(model, self._dataset.X_train, self._dataset.y_train),
-                dev_score_base=self.__score_func_base(model, self._dataset.X_dev, self._dataset.y_dev),
+                train_score_base=self.__score_func_base(model, X_forest, y_forest),
+                dev_score_base=self.__score_func_base(model, X_omp, y_omp),
                 test_score_base=self.__score_func_base(model, self._dataset.X_test, self._dataset.y_test),
                 score_metric=self._score_metric_name,
                 base_score_metric=self._base_score_metric_name
diff --git a/code/compute_results.py b/code/compute_results.py
index 0d3850931c83f1a1ef349e519501b4d3955cb684..23e3db3ad7c95e5f5732b4d09e945ce53dfd4467 100644
--- a/code/compute_results.py
+++ b/code/compute_results.py
@@ -2,12 +2,52 @@ from bolsonaro.models.model_raw_results import ModelRawResults
 from bolsonaro.visualization.plotter import Plotter
 from bolsonaro import LOG_PATH
 from bolsonaro.error_handling.logger_factory import LoggerFactory
+from bolsonaro.data.dataset_parameters import DatasetParameters
+from bolsonaro.data.dataset_loader import DatasetLoader
 
 import argparse
 import pathlib
 from dotenv import find_dotenv, load_dotenv
 import os
 import numpy as np
+import pickle
+from tqdm import tqdm
+from scipy.stats import rankdata
+from pyrsa.vis.colors import rdm_colormap
+from pyrsa.rdm.calc import calc_rdm
+from pyrsa.data.dataset import Dataset
+import matplotlib.pyplot as plt
+from sklearn.manifold import MDS
+from sklearn.preprocessing import normalize
+
+
+def vect2triu(dsm_vect, dim=None):
+    if not dim:
+        # the vector holds n = dim*(dim-1)/2 upper-triangle entries, hence dim = ceil(sqrt(2*n))
+        dim = int(np.ceil(np.sqrt(dsm_vect.shape[1] * 2)))
+    dsm = np.zeros((dim,dim))
+    ind_up = np.triu_indices(dim, 1)
+    dsm[ind_up] = dsm_vect
+    return dsm
+
+def triu2full(dsm_triu):
+    dsm_full = np.copy(dsm_triu)
+    ind_low = np.tril_indices(dsm_full.shape[0], -1)
+    dsm_full[ind_low] = dsm_full.T[ind_low]
+    return dsm_full
+
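+# Example round trip for the two helpers above (hypothetical 3x3 RDM whose
+# upper-triangle vector is [[1, 2, 3]]):
+#   triu2full(vect2triu(np.array([[1, 2, 3]]))) -> [[0, 1, 2], [1, 0, 3], [2, 3, 0]]
+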
+def plot_RDM(rdm, file_path, condition_number):
+    rdm = triu2full(vect2triu(rdm, condition_number))
+    fig = plt.figure()
+    cols = rdm_colormap(condition_number)
+    plt.imshow(rdm, cmap=cols)
+    plt.colorbar()
+    plt.savefig(file_path, dpi=200)
+    plt.close()
 
 
 def retreive_extracted_forest_sizes_number(models_dir, experiment_id):
@@ -174,7 +214,6 @@ def extract_correlations_across_seeds(models_dir, results_dir, experiment_id):
             extracted_forest_size_path = extracted_forest_sizes_root_path + os.sep + extracted_forest_size
             # Load models/{experiment_id}/seeds/{seed}/extracted_forest_sizes/{extracted_forest_size}/model_raw_results.pickle file
             model_raw_results = ModelRawResults.load(extracted_forest_size_path)
-            # Save the weights
-            experiment_correlations[seed].append(model_raw_results.correlation)
+            experiment_correlations[seed].append(model_raw_results.train_correlation)
 
     return experiment_correlations
@@ -203,11 +242,135 @@ def extract_coherences_across_seeds(models_dir, results_dir, experiment_id):
             extracted_forest_size_path = extracted_forest_sizes_root_path + os.sep + extracted_forest_size
             # Load models/{experiment_id}/seeds/{seed}/extracted_forest_sizes/{extracted_forest_size}/model_raw_results.pickle file
             model_raw_results = ModelRawResults.load(extracted_forest_size_path)
-            # Save the weights
-            experiment_coherences[seed].append(model_raw_results.coherence)
+            experiment_coherences[seed].append(model_raw_results.train_coherence)
 
     return experiment_coherences
 
+def extract_selected_trees_scores_across_seeds(models_dir, results_dir, experiment_id, weighted=False):
+    experiment_id_path = models_dir + os.sep + str(experiment_id) # models/{experiment_id}
+    experiment_seed_root_path = experiment_id_path + os.sep + 'seeds' # models/{experiment_id}/seeds
+    experiment_selected_trees_scores = dict()
+
+    print(f'[extract_selected_trees_scores_across_seeds] experiment_id: {experiment_id}')
+
+    # For each seed results stored in models/{experiment_id}/seeds
+    seeds = os.listdir(experiment_seed_root_path)
+    seeds.sort(key=int)
+    with tqdm(seeds) as seed_bar:
+        for seed in seed_bar:
+            seed_bar.set_description(f'seed: {seed}')
+            experiment_seed_path = experiment_seed_root_path + os.sep + seed # models/{experiment_id}/seeds/{seed}
+            extracted_forest_sizes_root_path = experiment_seed_path + os.sep + 'extracted_forest_sizes' # models/{experiment_id}/seeds/{seed}/extracted_forest_sizes
+
+            dataset_parameters = DatasetParameters.load(experiment_seed_path, experiment_id)
+            dataset = DatasetLoader.load(dataset_parameters)
+
+            # {{seed}:[]}
+            experiment_selected_trees_scores[seed] = list()
+
+            # List the forest sizes in models/{experiment_id}/seeds/{seed}/extracted_forest_sizes
+            extracted_forest_sizes = os.listdir(extracted_forest_sizes_root_path)
+            extracted_forest_sizes = [nb_tree for nb_tree in extracted_forest_sizes if 'no_weights' not in nb_tree]
+            extracted_forest_sizes.sort(key=int)
+            with tqdm(extracted_forest_sizes) as extracted_forest_size_bar:
+                for extracted_forest_size in extracted_forest_size_bar:
+                    # models/{experiment_id}/seeds/{seed}/extracted_forest_sizes/{extracted_forest_size}
+                    extracted_forest_size_path = extracted_forest_sizes_root_path + os.sep + extracted_forest_size
+                    selected_trees = None
+                    with open(os.path.join(extracted_forest_size_path, 'selected_trees.pickle'), 'rb') as file:
+                        selected_trees = pickle.load(file)
+                    selected_trees_test_scores = np.array([tree.score(dataset.X_test, dataset.y_test) for tree in selected_trees])
+
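+                    # Aggregate the per-tree test scores into a single strength value;
+                    # when OMP weights are available, each score is scaled by its
+                    # nonzero coefficient before taking the mean of squares.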
+                    if weighted:
+                        model_raw_results = ModelRawResults.load(extracted_forest_size_path)
+                        weights = model_raw_results.model_weights
+                        if not isinstance(weights, str):
+                            weights = weights[weights != 0]
+                            score = np.mean(np.square(selected_trees_test_scores * weights))
+                        else:
+                            score = np.mean(np.square(selected_trees_test_scores))
+                    else:
+                        score = np.mean(selected_trees_test_scores)
+                    experiment_selected_trees_scores[seed].append(score)
+                    extracted_forest_size_bar.set_description(f'extracted_forest_size: {extracted_forest_size} - test_score: {round(score, 2)}')
+                    extracted_forest_size_bar.update(1)
+            seed_bar.update(1)
+
+    return experiment_selected_trees_scores
+
+def extract_selected_trees_across_seeds(models_dir, results_dir, experiment_id):
+    experiment_id_path = models_dir + os.sep + str(experiment_id) # models/{experiment_id}
+    experiment_seed_root_path = experiment_id_path + os.sep + 'seeds' # models/{experiment_id}/seeds
+    experiment_selected_trees = dict()
+
+    # For each seed results stored in models/{experiment_id}/seeds
+    seeds = os.listdir(experiment_seed_root_path)
+    seeds.sort(key=int)
+    with tqdm(seeds) as seed_bar:
+        for seed in seed_bar:
+            seed_bar.set_description(f'seed: {seed}')
+            experiment_seed_path = experiment_seed_root_path + os.sep + seed # models/{experiment_id}/seeds/{seed}
+            extracted_forest_sizes_root_path = experiment_seed_path + os.sep + 'extracted_forest_sizes' # models/{experiment_id}/seeds/{seed}/extracted_forest_sizes
+
+            dataset_parameters = DatasetParameters.load(experiment_seed_path, experiment_id)
+            dataset = DatasetLoader.load(dataset_parameters)
+
+            # {{seed}:[]}
+            experiment_selected_trees[seed] = list()
+
+            # List the forest sizes in models/{experiment_id}/seeds/{seed}/extracted_forest_sizes
+            extracted_forest_sizes = os.listdir(extracted_forest_sizes_root_path)
+            extracted_forest_sizes = [nb_tree for nb_tree in extracted_forest_sizes if 'no_weights' not in nb_tree]
+            extracted_forest_sizes.sort(key=int)
+            all_selected_trees_predictions = list()
+            with tqdm(extracted_forest_sizes) as extracted_forest_size_bar:
+                for extracted_forest_size in extracted_forest_size_bar:
+                    # models/{experiment_id}/seeds/{seed}/extracted_forest_sizes/{extracted_forest_size}
+                    extracted_forest_size_path = extracted_forest_sizes_root_path + os.sep + extracted_forest_size
+                    selected_trees = None
+                    with open(os.path.join(extracted_forest_size_path, 'selected_trees.pickle'), 'rb') as file:
+                        selected_trees = pickle.load(file)
+                    # Collect each selected tree's raw predictions on the test set;
+                    # they are the observations from which the RDM is computed below.
+                    selected_trees_predictions = [tree.predict(dataset.X_test) for tree in selected_trees]
+                    extracted_forest_size_bar.set_description(f'extracted_forest_size: {extracted_forest_size}')
+                    # note: experiment_selected_trees stays empty; this pass only produces plots
+                    extracted_forest_size_bar.update(1)
+                    selected_trees_predictions = np.array(selected_trees_predictions)
+                    selected_trees_predictions = normalize(selected_trees_predictions)
+
+                    # Disabled alternative visualization: a 2D MDS embedding of the
+                    # normalized per-tree predictions, e.g.
+                    #   Y = MDS(len(selected_trees_predictions)).fit_transform(selected_trees_predictions)
+                    #   plt.scatter(Y[:, 0], Y[:, 1]); plt.savefig(f'test_mds_{experiment_id}.png')
+
+                    if int(extracted_forest_size) <= 267:
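+                        # Build the representational dissimilarity matrix (RDM) of
+                        # pairwise euclidean distances between the trees' prediction
+                        # vectors, then rank-transform its entries (standard RSA practice).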
+                        forest_RDM = calc_rdm(Dataset(selected_trees_predictions), method='euclidean').get_vectors()
+                        ranked_forest_RDM = np.apply_along_axis(rankdata, 1, forest_RDM.reshape(1, -1))
+
+                        from scipy.cluster import hierarchy
+                        RDM = triu2full(vect2triu(ranked_forest_RDM, int(extracted_forest_size)))
+                        Z = hierarchy.linkage(RDM, 'average')
+                        fig = plt.figure(figsize=(15, 8))
+                        dn = hierarchy.dendrogram(Z)
+                        plt.savefig(f'test_dendrogram_scores_id:{experiment_id}_seed:{seed}_size:{extracted_forest_size}.png')
+                        plt.close()
+
+                        plot_RDM(
+                            rdm=ranked_forest_RDM,
+                            file_path=f'test_scores_ranked_forest_RDM_id:{experiment_id}_seed:{seed}_size:{extracted_forest_size}.png',
+                            condition_number=len(selected_trees_predictions)
+                        )
+            # Only the first seed is analyzed; the remaining seeds are skipped.
+            break
+    return experiment_selected_trees
+
 if __name__ == "__main__":
     # get environment variables in .env
     load_dotenv(find_dotenv('.env'))
@@ -217,6 +380,8 @@
     DEFAULT_PLOT_WEIGHT_DENSITY = False
     DEFAULT_WO_LOSS_PLOTS = False
     DEFAULT_PLOT_PREDS_COHERENCE = False
+    DEFAULT_PLOT_FOREST_STRENGTH = False
+    DEFAULT_COMPUTE_SELECTED_TREES_RDMS = False
 
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument('--stage', nargs='?', type=int, required=True, help='Specify the stage number among [1, 5].')
@@ -232,6 +397,8 @@
     parser.add_argument('--wo_loss_plots', action='store_true', default=DEFAULT_WO_LOSS_PLOTS, help='Do not compute the loss plots.')
     parser.add_argument('--plot_preds_coherence', action='store_true', default=DEFAULT_PLOT_PREDS_COHERENCE, help='Plot the coherence of the prediction trees.')
     parser.add_argument('--plot_preds_correlation', action='store_true', default=DEFAULT_PLOT_PREDS_COHERENCE, help='Plot the correlation of the prediction trees.')
+    parser.add_argument('--plot_forest_strength', action='store_true', default=DEFAULT_PLOT_FOREST_STRENGTH, help='Plot the strength of the extracted forest.')
+    parser.add_argument('--compute_selected_trees_rdms', action='store_true', default=DEFAULT_COMPUTE_SELECTED_TREES_RDMS, help='Representation similarity analysis of the selected trees')
     args = parser.parse_args()
 
     if args.stage not in list(range(1, 6)):
@@ -534,16 +701,17 @@
         import sys
         sys.exit(0)"""
 
-        #all_labels = ['base', 'random', 'omp', 'omp_wo_weights']
-        all_labels = ['base', 'random', 'omp']
+        all_labels = ['base', 'random', 'omp', 'omp_wo_weights']
+        #all_labels = ['base', 'random', 'omp']
         omp_with_params_test_scores_new = dict()
         filter_num = -1
         """filter_num = 9
         for key, value in omp_with_params_test_scores.items():
             omp_with_params_test_scores_new[key] = value[:filter_num]"""
-        #all_scores = [base_with_params_test_scores, random_with_params_test_scores, omp_with_params_test_scores,
-        #    omp_with_params_without_weights_test_scores]
-        all_scores = [base_with_params_dev_scores, random_with_params_dev_scores, omp_with_params_dev_scores]
+        all_scores = [base_with_params_test_scores, random_with_params_test_scores, omp_with_params_test_scores,
+            omp_with_params_without_weights_test_scores]
+        #all_scores = [base_with_params_dev_scores, random_with_params_dev_scores, omp_with_params_dev_scores,
+        #    omp_with_params_without_weights_dev_scores]
         #all_scores = [base_with_params_train_scores, random_with_params_train_scores, omp_with_params_train_scores,
         #    omp_with_params_without_weights_train_scores]
 
@@ -567,15 +726,15 @@ if __name__ == "__main__":
             current_train_scores, current_dev_scores, current_test_scores, _, _ = extract_scores_across_seeds_and_extracted_forest_sizes(
                 args.models_dir, args.results_dir, current_experiment_id)
             all_labels.append(label)
-            #all_scores.append(current_test_scores)
+            all_scores.append(current_test_scores)
             #all_scores.append(current_train_scores)
-            all_scores.append(current_dev_scores)
+            #all_scores.append(current_dev_scores)
 
-        output_path = os.path.join(args.results_dir, args.dataset_name, 'stage5_new')
+        output_path = os.path.join(args.results_dir, args.dataset_name, 'stage5_test_train,dev')
         pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)
 
         Plotter.plot_stage2_losses(
-            file_path=output_path + os.sep + f"losses_{'-'.join(all_labels)}_dev_clean.png",
+            file_path=output_path + os.sep + f"losses_{'-'.join(all_labels)}_test_train,dev.png",
             all_experiment_scores=all_scores,
             all_labels=all_labels,
             x_value=with_params_extracted_forest_sizes,
@@ -630,7 +798,7 @@
         all_labels = ['random', 'omp', 'kmeans', 'similarity_similarities', 'similarity_predictions', 'ensemble']
         _, _, _, with_params_extracted_forest_sizes, _ = \
             extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, 2)
-        coherence_values = [extract_coherences_across_seeds(args.models_dir, args.results_dir, i) for i in [2, 3, 5, 6, 7, 8]]
+        coherence_values = [extract_coherences_across_seeds(args.models_dir, args.results_dir, i) for i in args.experiment_ids]
         Plotter.plot_stage2_losses(
             file_path=root_output_path + os.sep + f"coherences_{'-'.join(all_labels)}.png",
             all_experiment_scores=coherence_values,
@@ -640,13 +808,14 @@
             ylabel='Coherence',
             title='Coherence values of {}'.format(args.dataset_name))
         logger.info(f'Computing preds coherence plot...')
+
     if args.plot_preds_correlation:
         root_output_path = os.path.join(args.results_dir, args.dataset_name, f'stage5_new')
         pathlib.Path(root_output_path).mkdir(parents=True, exist_ok=True)
-        all_labels = ['random', 'omp', 'kmeans', 'similarity_similarities', 'similarity_predictions', 'ensemble']
+        all_labels = ['none', 'random', 'omp', 'kmeans', 'similarity_similarities', 'similarity_predictions', 'ensemble']
         _, _, _, with_params_extracted_forest_sizes, _ = \
             extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, 2)
-        correlation_values = [extract_correlations_across_seeds(args.models_dir, args.results_dir, i) for i in [2, 3, 5, 6, 7, 8]]
+        correlation_values = [extract_correlations_across_seeds(args.models_dir, args.results_dir, i) for i in args.experiment_ids]
         Plotter.plot_stage2_losses(
             file_path=root_output_path + os.sep + f"correlations_{'-'.join(all_labels)}.png",
             all_experiment_scores=correlation_values,
@@ -657,4 +826,65 @@
             title='correlation values of {}'.format(args.dataset_name))
         logger.info(f'Computing preds correlation plot...')
 
+    if args.plot_forest_strength:
+        root_output_path = os.path.join(args.results_dir, args.dataset_name, 'stage5_strength')
+        pathlib.Path(root_output_path).mkdir(parents=True, exist_ok=True)
+
+        _, _, _, with_params_extracted_forest_sizes, _ = \
+                extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, 2)
+
+        random_selected_trees_scores = extract_selected_trees_scores_across_seeds(
+            args.models_dir, args.results_dir, 2, weighted=True)
+
+        omp_selected_trees_scores = extract_selected_trees_scores_across_seeds(
+            args.models_dir, args.results_dir, 3, weighted=True)
+
+        similarity_similarities_selected_trees_scores = extract_selected_trees_scores_across_seeds(
+            args.models_dir, args.results_dir, 6, weighted=True)
+
+        #similarity_predictions_selected_trees_scores = extract_selected_trees_scores_across_seeds(
+        #    args.models_dir, args.results_dir, 7)
+
+        ensemble_selected_trees_scores = extract_selected_trees_scores_across_seeds(
+            args.models_dir, args.results_dir, 8, weighted=True)
+
+        # kmeans=5
+        # similarity_similarities=6
+        # similarity_predictions=7
+        # ensemble=8
+
+        all_selected_trees_scores = [random_selected_trees_scores, omp_selected_trees_scores, similarity_similarities_selected_trees_scores,
+            ensemble_selected_trees_scores]
+
+        with open('california_housing_forest_strength_scores.pickle', 'wb') as file:
+            pickle.dump(all_selected_trees_scores, file)
+
+        # (These scores can later be reloaded from the pickle above to re-plot
+        # them without recomputing.)
+
+        all_labels = ['random', 'omp', 'similarity_similarities', 'ensemble']
+
+        Plotter.plot_stage2_losses(
+            file_path=root_output_path + os.sep + f"forest_strength_{'-'.join(all_labels)}_v2_sota.png",
+            all_experiment_scores=all_selected_trees_scores,
+            all_labels=all_labels,
+            x_value=with_params_extracted_forest_sizes,
+            xlabel='Number of trees extracted',
+            ylabel='Mean of selected tree scores on test set',
+            title='Forest strength of {}'.format(args.dataset_name))
+
+    if args.compute_selected_trees_rdms:
+        root_output_path = os.path.join(args.results_dir, args.dataset_name, 'stage5_strength')
+        pathlib.Path(root_output_path).mkdir(parents=True, exist_ok=True)
+
+        _, _, _, with_params_extracted_forest_sizes, _ = \
+                extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, 2)
+        all_selected_trees_scores = list()
+        with tqdm([2, 3, 8]) as experiment_id_bar:
+            for experiment_id in experiment_id_bar:
+                experiment_id_bar.set_description(f'experiment_id: {experiment_id}')
+                all_selected_trees_scores.append(extract_selected_trees_across_seeds(
+                    args.models_dir, args.results_dir, experiment_id))
+                experiment_id_bar.update(1)
+
     logger.info('Done.')
diff --git a/code/prepare_models.py b/code/prepare_models.py
index 04a2ec3f4446b48268bddaa2e452ef31af9d7eb3..3cd9ea37033063652e15e0e1c84432b831b6562e 100644
--- a/code/prepare_models.py
+++ b/code/prepare_models.py
@@ -8,8 +8,9 @@ from tqdm import tqdm
 if __name__ == "__main__":
     models_source_path = 'models'
     models_destination_path = 'bolsonaro_models_25-03-20'
-    datasets = ['boston', 'diabetes', 'linnerud', 'breast_cancer', 'california_housing', 'diamonds',
-        'steel-plates', 'kr-vs-kp', 'kin8nm', 'spambase', 'gamma', 'lfw_pairs']
+    #datasets = ['boston', 'diabetes', 'linnerud', 'breast_cancer', 'california_housing', 'diamonds',
+    #    'steel-plates', 'kr-vs-kp', 'kin8nm', 'spambase', 'gamma', 'lfw_pairs']
+    datasets = ['kin8nm']
 
     pathlib.Path(models_destination_path).mkdir(parents=True, exist_ok=True)
 
diff --git a/code/train.py b/code/train.py
index 189aac1c7c9ccf68b2da57af7a0eebc3680d365b..10dbf7354837cab803202a8307c671f0def0f274 100644
--- a/code/train.py
+++ b/code/train.py
@@ -66,11 +66,11 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb
             extraction_strategy=parameters['extraction_strategy']
         )
         pretrained_estimator = ModelFactory.build(dataset.task, pretrained_model_parameters)
-        pretraned_trainer = Trainer(dataset)
-        pretraned_trainer.init(pretrained_estimator, subsets_used=parameters['subsets_used'])
+        pretrained_trainer = Trainer(dataset)
+        pretrained_trainer.init(pretrained_estimator, subsets_used=parameters['subsets_used'])
         pretrained_estimator.fit(
-            X=pretraned_trainer._X_forest,
-            y=pretraned_trainer._y_forest
+            X=pretrained_trainer._X_forest,
+            y=pretrained_trainer._y_forest
         )
     else:
         pretrained_estimator = None