diff --git a/code/bolsonaro/data/dataset_loader.py b/code/bolsonaro/data/dataset_loader.py
index 0bebeaaf6c9f0dac2258a384a8b2bcff50c4c5b3..ed438d04f032d965513f028baa06a281d3f4ff4a 100644
--- a/code/bolsonaro/data/dataset_loader.py
+++ b/code/bolsonaro/data/dataset_loader.py
@@ -81,7 +81,7 @@ class DatasetLoader(object):
         elif name == 'lfw_pairs':
             dataset = fetch_lfw_pairs()
             X, y = dataset.data, dataset.target
-            task = Task.MULTICLASSIFICATION
+            task = Task.BINARYCLASSIFICATION
         elif name == 'covtype':
             X, y = fetch_covtype(random_state=dataset_parameters.random_state, shuffle=True, return_X_y=True)
             task = Task.MULTICLASSIFICATION
diff --git a/code/bolsonaro/models/model_factory.py b/code/bolsonaro/models/model_factory.py
index 07799ceb966e2a40b10e98fdd134fe458674cf8b..d11af3b09b2538557f140d885c5f88ee1c8c97e7 100644
--- a/code/bolsonaro/models/model_factory.py
+++ b/code/bolsonaro/models/model_factory.py
@@ -29,7 +29,7 @@ class ModelFactory(object):
                     random_state=model_parameters.seed)
             elif model_parameters.extraction_strategy == 'kmeans':
                 return KMeansForestClassifier(model_parameters)
-            elif model_parameters.extraction_strategy == 'similarity':
+            elif model_parameters.extraction_strategy in ['similarity_similarities', 'similarity_predictions']:
                 return SimilarityForestClassifier(model_parameters)
             else:
                 raise ValueError('Invalid extraction strategy')
@@ -39,7 +39,7 @@ class ModelFactory(object):
             elif model_parameters.extraction_strategy == 'random':
                 return RandomForestRegressor(**model_parameters.hyperparameters,
                     random_state=model_parameters.seed)
-            elif model_parameters.extraction_strategy == 'similarity':
+            elif model_parameters.extraction_strategy in ['similarity_similarities', 'similarity_predictions']:
                 return SimilarityForestRegressor(model_parameters)
             elif model_parameters.extraction_strategy == 'kmeans':
                 return KMeansForestRegressor(model_parameters)
diff --git a/code/compute_results.py b/code/compute_results.py
index d77779e82e295b5e76c0347551c20b8ef258a546..c534bb0c354be48f148ad7e81d3252f20116603e 100644
--- a/code/compute_results.py
+++ b/code/compute_results.py
@@ -7,6 +7,7 @@ import argparse
 import pathlib
 from dotenv import find_dotenv, load_dotenv
 import os
+import numpy as np
 
 
 def retreive_extracted_forest_sizes_number(models_dir, experiment_id):
@@ -17,7 +18,7 @@ def retreive_extracted_forest_sizes_number(models_dir, experiment_id):
     extracted_forest_sizes_root_path = experiment_seed_path + os.sep + 'extracted_forest_sizes'
     return len(os.listdir(extracted_forest_sizes_root_path))
 
-def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_dir, experiment_id, weights=True):
+def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_dir, experiment_id, weights=True, extracted_forest_sizes=list()):
     experiment_id_path = models_dir + os.sep + str(experiment_id) # models/{experiment_id}
     experiment_seed_root_path = experiment_id_path + os.sep + 'seeds' # models/{experiment_id}/seeds
 
@@ -45,10 +46,11 @@ def extract_scores_across_seeds_and_extracted_forest_sizes(models_dir, results_d
         experiment_dev_scores[seed] = list()
         experiment_test_scores[seed] = list()
 
-        # List the forest sizes in models/{experiment_id}/seeds/{seed}/extracted_forest_sizes
-        extracted_forest_sizes = os.listdir(extracted_forest_sizes_root_path)
-        extracted_forest_sizes = [nb_tree for nb_tree in extracted_forest_sizes if not 'no_weights' in nb_tree ]
-        extracted_forest_sizes.sort(key=int)
+        if len(extracted_forest_sizes) == 0:
+            # List the forest sizes in models/{experiment_id}/seeds/{seed}/extracted_forest_sizes
+            extracted_forest_sizes = os.listdir(extracted_forest_sizes_root_path)
+            extracted_forest_sizes = [nb_tree for nb_tree in extracted_forest_sizes if not 'no_weights' in nb_tree ]
+            extracted_forest_sizes.sort(key=int)
         all_extracted_forest_sizes.append(list(map(int, extracted_forest_sizes)))
         for extracted_forest_size in extracted_forest_sizes:
             # models/{experiment_id}/seeds/{seed}/extracted_forest_sizes/{extracted_forest_size}
@@ -437,6 +439,15 @@ if __name__ == "__main__":
         all_labels = list()
         all_scores = list()
 
+        """extracted_forest_sizes = np.unique(np.around(1000 *
+            np.linspace(0, 1.0,
+            30 + 1,
+            endpoint=True)[1:]).astype(np.int)).tolist()"""
+
+        extracted_forest_sizes = [4, 7, 11, 14, 18, 22, 25, 29, 32, 36, 40, 43, 47, 50, 54, 58, 61, 65, 68, 72, 76, 79, 83, 86, 90, 94, 97, 101, 104, 108]
+
+        extracted_forest_sizes = [str(forest_size) for forest_size in extracted_forest_sizes]
+
         # base_with_params
         logger.info('Loading base_with_params experiment scores...')
         base_with_params_train_scores, base_with_params_dev_scores, base_with_params_test_scores, \
@@ -447,21 +458,23 @@ if __name__ == "__main__":
         logger.info('Loading random_with_params experiment scores...')
         random_with_params_train_scores, random_with_params_dev_scores, random_with_params_test_scores, \
             with_params_extracted_forest_sizes, random_with_params_experiment_score_metric = \
-            extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, int(args.experiment_ids[1]))
+            extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, int(args.experiment_ids[1]),
+            extracted_forest_sizes=extracted_forest_sizes)
         # omp_with_params
         logger.info('Loading omp_with_params experiment scores...')
         omp_with_params_train_scores, omp_with_params_dev_scores, omp_with_params_test_scores, _, \
             omp_with_params_experiment_score_metric = extract_scores_across_seeds_and_extracted_forest_sizes(
-                args.models_dir, args.results_dir, int(args.experiment_ids[2]))
+                args.models_dir, args.results_dir, int(args.experiment_ids[2]), extracted_forest_sizes=extracted_forest_sizes)
         #omp_with_params_without_weights
         logger.info('Loading omp_with_params without weights experiment scores...')
         omp_with_params_without_weights_train_scores, omp_with_params_without_weights_dev_scores, omp_with_params_without_weights_test_scores, _, \
             omp_with_params_experiment_score_metric = extract_scores_across_seeds_and_extracted_forest_sizes(
-                args.models_dir, args.results_dir, int(args.experiment_ids[2]), weights=False)
+                args.models_dir, args.results_dir, int(args.experiment_ids[2]), weights=False, extracted_forest_sizes=extracted_forest_sizes)
 
-        all_labels = ['base', 'random', 'omp', 'omp_without_weights']
-        all_scores = [base_with_params_test_scores, random_with_params_test_scores, omp_with_params_test_scores,
-            omp_with_params_without_weights_test_scores]
+        all_labels = ['base', 'random', 'omp']
+        all_scores = [base_with_params_test_scores, random_with_params_test_scores, omp_with_params_test_scores]
+        #all_scores = [base_with_params_train_scores, random_with_params_train_scores, omp_with_params_train_scores,
+        #    omp_with_params_without_weights_train_scores]
 
         for i in range(3, len(args.experiment_ids)):
             if 'kmeans' in args.experiment_ids[i]:
@@ -476,16 +489,17 @@ if __name__ == "__main__":
 
             logger.info(f'Loading {label} experiment scores...')
             current_experiment_id = int(args.experiment_ids[i].split('=')[1])
-            _, _, current_test_scores, _, _ = extract_scores_across_seeds_and_extracted_forest_sizes(
+            current_train_scores, _, current_test_scores, _, _ = extract_scores_across_seeds_and_extracted_forest_sizes(
                 args.models_dir, args.results_dir, current_experiment_id)
             all_labels.append(label)
             all_scores.append(current_test_scores)
+            #all_scores.append(current_train_scores)
 
         output_path = os.path.join(args.results_dir, args.dataset_name, 'stage5')
         pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)
 
         Plotter.plot_stage2_losses(
-            file_path=output_path + os.sep + f"losses_{'-'.join(all_labels)}.png",
+            file_path=output_path + os.sep + f"losses_{'-'.join(all_labels)}_test.png",
             all_experiment_scores=all_scores,
             all_labels=all_labels,
             x_value=with_params_extracted_forest_sizes,
diff --git a/code/train.py b/code/train.py
index 95498cdf03a894ca8c8cf91d6702acc6aef1a799..e7a319de1b1dcf87c51bb96b537d4d2df80499fb 100644
--- a/code/train.py
+++ b/code/train.py
@@ -97,11 +97,11 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb
         if os.path.isdir(sub_models_dir):
             sub_models_dir_files = os.listdir(sub_models_dir)
             for file_name in sub_models_dir_files:
-                if '.pickle' != os.path.splitext(file_name)[1]:
-                    continue
-                else:
+                if file_name == 'model_raw_results.pickle':
                     already_exists = os.path.getsize(os.path.join(sub_models_dir, file_name)) > 0
                     break
+                else:
+                    continue
         if already_exists:
             logger.info('Base forest result already exists. Skipping...')
         else:
@@ -140,11 +140,11 @@ def extracted_forest_size_job(extracted_forest_size_job_pb, extracted_forest_siz
     if os.path.isdir(sub_models_dir):
         sub_models_dir_files = os.listdir(sub_models_dir)
         for file_name in sub_models_dir_files:
-            if '.pickle' != os.path.splitext(file_name)[1]:
-                continue
-            else:
+            if file_name == 'model_raw_results.pickle':
                 already_exists = os.path.getsize(os.path.join(sub_models_dir, file_name)) > 0
                 break
+            else:
+                continue
     if already_exists:
         logger.info(f'Extracted forest {extracted_forest_size} result already exists. Skipping...')
         return
@@ -235,7 +235,7 @@ if __name__ == "__main__":
     parser.add_argument('--skip_best_hyperparams', action='store_true', default=DEFAULT_SKIP_BEST_HYPERPARAMS, help='Do not use the best hyperparameters if there exist.')
     parser.add_argument('--save_experiment_configuration', nargs='+', default=None, help='Save the experiment parameters specified in the command line in a file. Args: {{stage_num}} {{name}}')
     parser.add_argument('--job_number', nargs='?', type=int, default=DEFAULT_JOB_NUMBER, help='Specify the number of job used during the parallelisation across seeds.')
-    parser.add_argument('--extraction_strategy', nargs='?', type=str, default=DEFAULT_EXTRACTION_STRATEGY, help='Specify the strategy to apply to extract the trees from the forest. Either omp, random, none, similarity, kmeans, ensemble.')
+    parser.add_argument('--extraction_strategy', nargs='?', type=str, default=DEFAULT_EXTRACTION_STRATEGY, help='Specify the strategy to apply to extract the trees from the forest. Either omp, random, none, similarity_similarities, similarity_predictions, kmeans, ensemble.')
     parser.add_argument('--overwrite', action='store_true', default=DEFAULT_OVERWRITE, help='Overwrite the experiment id')
     args = parser.parse_args()
 
@@ -246,7 +246,7 @@ if __name__ == "__main__":
     else:
         parameters = args.__dict__
 
-    if parameters['extraction_strategy'] not in ['omp', 'random', 'none', 'similarity', 'kmeans', 'ensemble']:
+    if parameters['extraction_strategy'] not in ['omp', 'random', 'none', 'similarity_similarities', 'similarity_predictions', 'kmeans', 'ensemble']:
         raise ValueError('Specified extraction strategy {} is not supported.'.format(parameters.extraction_strategy))
 
     pathlib.Path(parameters['models_dir']).mkdir(parents=True, exist_ok=True)