From 87fb436ccf50ef1b76d5e74a07d85422eb000349 Mon Sep 17 00:00:00 2001
From: Charly Lamothe <charly.lamothe@univ-amu.fr>
Date: Wed, 25 Mar 2020 00:25:49 +0100
Subject: [PATCH] Fix flw_pairs loading. Prepare all new exps:
 omp_distillation, preds coherence, preds correlation, normalize_D when OMP,
 n_jobs=-1 in SOTA. In exps script, test both train+dev,train+dev and
 train,dev

---
 code/bolsonaro/data/dataset_loader.py         |  4 +-
 .../ensemble_selection_forest_regressor.py    | 33 +-----------
 code/bolsonaro/models/forest_pruning_sota.py  |  2 -
 .../models/kmeans_forest_regressor.py         |  4 +-
 code/bolsonaro/models/model_factory.py        |  8 +--
 code/bolsonaro/models/model_raw_results.py    | 13 ++++-
 code/bolsonaro/models/omp_forest.py           |  5 +-
 .../models/similarity_forest_regressor.py     |  4 +-
 code/bolsonaro/trainer.py                     | 50 +++++++++++++----
 code/compute_results.py                       | 54 ++++++++++++++++++-
 code/train.py                                 | 28 +++++-----
 scripts/run_stage5_experiments.sh             | 29 ++++++----
 12 files changed, 151 insertions(+), 83 deletions(-)

diff --git a/code/bolsonaro/data/dataset_loader.py b/code/bolsonaro/data/dataset_loader.py
index ed438d0..5e38b1e 100644
--- a/code/bolsonaro/data/dataset_loader.py
+++ b/code/bolsonaro/data/dataset_loader.py
@@ -1,7 +1,7 @@
 from bolsonaro.data.dataset import Dataset
 from bolsonaro.data.dataset_parameters import DatasetParameters
 from bolsonaro.data.task import Task
-from bolsonaro.utils import change_binary_func_load, change_binary_func_openml
+from bolsonaro.utils import change_binary_func_load, change_binary_func_openml, binarize_class_data
 
 from sklearn.datasets import load_boston, load_iris, load_diabetes, \
     load_digits, load_linnerud, load_wine, load_breast_cancer
@@ -81,6 +81,8 @@ class DatasetLoader(object):
         elif name == 'lfw_pairs':
             dataset = fetch_lfw_pairs()
             X, y = dataset.data, dataset.target
+            possible_classes = sorted(set(y))
+            y = binarize_class_data(y, possible_classes[-1])
             task = Task.BINARYCLASSIFICATION
         elif name == 'covtype':
             X, y = fetch_covtype(random_state=dataset_parameters.random_state, shuffle=True, return_X_y=True)
diff --git a/code/bolsonaro/models/ensemble_selection_forest_regressor.py b/code/bolsonaro/models/ensemble_selection_forest_regressor.py
index fb399d1..aa649fa 100644
--- a/code/bolsonaro/models/ensemble_selection_forest_regressor.py
+++ b/code/bolsonaro/models/ensemble_selection_forest_regressor.py
@@ -55,7 +55,7 @@ class EnsembleSelectionForestClassifier(EnsembleSelectionForest, metaclass=ABCMe
     @staticmethod
     def init_estimator(model_parameters):
         return RandomForestClassifier(**model_parameters.hyperparameters,
-                                    random_state=model_parameters.seed, n_jobs=2)
+                                    random_state=model_parameters.seed, n_jobs=-1)
 
     def _aggregate(self, predictions):
         return aggregation_classification(predictions)
@@ -90,7 +90,7 @@ class EnsembleSelectionForestRegressor(EnsembleSelectionForest, metaclass=ABCMet
     @staticmethod
     def init_estimator(model_parameters):
         return RandomForestRegressor(**model_parameters.hyperparameters,
-                              random_state=model_parameters.seed, n_jobs=2)
+                              random_state=model_parameters.seed, n_jobs=-1)
 
     def _aggregate(self, predictions):
         return aggregation_regression(predictions)
@@ -108,32 +108,3 @@ class EnsembleSelectionForestRegressor(EnsembleSelectionForest, metaclass=ABCMet
     @staticmethod
     def _worse_score_idx(array):
         return np.argmax(array)
-
-
-
-    # @staticmethod
-    # def generate_library(X_train, y_train, random_state=None):
-    #     criterion_arr = ["mse"]#, "friedman_mse", "mae"]
-    #     splitter_arr = ["best"]#, "random"]
-    #     depth_arr = [i for i in range(5, 20, 1)]
-    #     min_samples_split_arr = [i for i in range(2, 20, 1)]
-    #     min_samples_leaf_arr = [i for i in range(2, 20, 1)]
-    #     max_features_arr = ["sqrt"]#["auto", "sqrt", "log2"]
-    #
-    #     library = list()
-    #     with tqdm(total=len(criterion_arr) * len(splitter_arr) * \
-    #         len(depth_arr) * len(min_samples_split_arr) * len(min_samples_leaf_arr) * \
-    #         len(max_features_arr)) as bar:
-    #         bar.set_description('Generating library')
-    #         for criterion in criterion_arr:
-    #             for splitter in splitter_arr:
-    #                 for depth in depth_arr:
-    #                     for min_samples_split in min_samples_split_arr:
-    #                         for min_samples_leaf in min_samples_leaf_arr:
-    #                             for max_features in max_features_arr:
-    #                                 t = DecisionTreeRegressor(criterion=criterion, splitter=splitter, max_depth=depth, min_samples_split=min_samples_split,
-    #                                     min_samples_leaf=min_samples_leaf, max_features=max_features, random_state=random_state)
-    #                                 t.fit(X_train, y_train)
-    #                                 library.append(t)
-    #                                 bar.update(1)
-    #     return library
diff --git a/code/bolsonaro/models/forest_pruning_sota.py b/code/bolsonaro/models/forest_pruning_sota.py
index 80d6fe7..79bc006 100644
--- a/code/bolsonaro/models/forest_pruning_sota.py
+++ b/code/bolsonaro/models/forest_pruning_sota.py
@@ -1,5 +1,3 @@
-import time
-
 from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
 from sklearn.metrics import mean_squared_error
 from sklearn.base import BaseEstimator
diff --git a/code/bolsonaro/models/kmeans_forest_regressor.py b/code/bolsonaro/models/kmeans_forest_regressor.py
index dc47884..2c9be47 100644
--- a/code/bolsonaro/models/kmeans_forest_regressor.py
+++ b/code/bolsonaro/models/kmeans_forest_regressor.py
@@ -49,7 +49,7 @@ class KMeansForestRegressor(KmeansForest, metaclass=ABCMeta):
     @staticmethod
     def init_estimator(model_parameters):
         return RandomForestRegressor(**model_parameters.hyperparameters,
-                              random_state=model_parameters.seed, n_jobs=2)
+                              random_state=model_parameters.seed, n_jobs=-1)
 
     def _aggregate(self, predictions):
         return aggregation_regression(predictions)
@@ -70,7 +70,7 @@ class KMeansForestClassifier(KmeansForest, metaclass=ABCMeta):
     @staticmethod
     def init_estimator(model_parameters):
         return RandomForestClassifier(**model_parameters.hyperparameters,
-                                                random_state=model_parameters.seed, n_jobs=2)
+                                                random_state=model_parameters.seed, n_jobs=-1)
 
     def _aggregate(self, predictions):
         return aggregation_classification(predictions)
diff --git a/code/bolsonaro/models/model_factory.py b/code/bolsonaro/models/model_factory.py
index 4a70a1c..6f18362 100644
--- a/code/bolsonaro/models/model_factory.py
+++ b/code/bolsonaro/models/model_factory.py
@@ -14,12 +14,12 @@ import pickle
 class ModelFactory(object):
 
     @staticmethod
-    def build(task, model_parameters, library=None):
+    def build(task, model_parameters):
         if task not in [Task.BINARYCLASSIFICATION, Task.REGRESSION, Task.MULTICLASSIFICATION]:
             raise ValueError("Unsupported task '{}'".format(task))
 
         if task == Task.BINARYCLASSIFICATION:
-            if model_parameters.extraction_strategy == 'omp':
+            if model_parameters.extraction_strategy in ['omp', 'omp_distillation']:
                 return OmpForestBinaryClassifier(model_parameters)
             elif model_parameters.extraction_strategy == 'random':
                 return RandomForestClassifier(**model_parameters.hyperparameters,
@@ -36,7 +36,7 @@ class ModelFactory(object):
             else:
                 raise ValueError('Invalid extraction strategy')
         elif task == Task.REGRESSION:
-            if model_parameters.extraction_strategy == 'omp':
+            if model_parameters.extraction_strategy in ['omp', 'omp_distillation']:
                 return OmpForestRegressor(model_parameters)
             elif model_parameters.extraction_strategy == 'random':
                 return RandomForestRegressor(**model_parameters.hyperparameters,
@@ -53,7 +53,7 @@ class ModelFactory(object):
             else:
                 raise ValueError('Invalid extraction strategy')
         elif task == Task.MULTICLASSIFICATION:
-            if model_parameters.extraction_strategy == 'omp':
+            if model_parameters.extraction_strategy in ['omp', 'omp_distillation']:
                 return OmpForestMulticlassClassifier(model_parameters)
             elif model_parameters.extraction_strategy == 'random':
                 return RandomForestClassifier(**model_parameters.hyperparameters,
diff --git a/code/bolsonaro/models/model_raw_results.py b/code/bolsonaro/models/model_raw_results.py
index fbb80a5..7d61065 100644
--- a/code/bolsonaro/models/model_raw_results.py
+++ b/code/bolsonaro/models/model_raw_results.py
@@ -9,7 +9,8 @@ class ModelRawResults(object):
     def __init__(self, model_weights, training_time,
         datetime, train_score, dev_score, test_score,
         train_score_base, dev_score_base,
-        test_score_base, score_metric, base_score_metric):
+        test_score_base, score_metric, base_score_metric,
+        coherence='', correlation=''):
 
         self._model_weights = model_weights
         self._training_time = training_time
@@ -22,6 +23,8 @@ class ModelRawResults(object):
         self._test_score_base = test_score_base
         self._score_metric = score_metric
         self._base_score_metric = base_score_metric
+        self._coherence = coherence
+        self._correlation = correlation
 
     @property
     def model_weights(self):
@@ -67,6 +70,14 @@ class ModelRawResults(object):
     def base_score_metric(self):
         return self._base_score_metric
 
+    @property
+    def coherence(self):
+        return self._coherence
+
+    @property
+    def correlation(self):
+        return self._correlation
+
     def save(self, models_dir):
         if not os.path.exists(models_dir):
             os.mkdir(models_dir)
diff --git a/code/bolsonaro/models/omp_forest.py b/code/bolsonaro/models/omp_forest.py
index 90a6fc3..e4830f0 100644
--- a/code/bolsonaro/models/omp_forest.py
+++ b/code/bolsonaro/models/omp_forest.py
@@ -36,11 +36,12 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta):
         return self._base_forest_estimator.estimators_
 
     # sklearn baseestimator api methods
-    def fit(self, X_forest, y_forest, X_omp, y_omp):
+    def fit(self, X_forest, y_forest, X_omp, y_omp, use_distillation=False):
         # print(y_forest.shape)
         # print(set([type(y) for y in y_forest]))
         self._base_forest_estimator.fit(X_forest, y_forest)
-        self._extract_subforest(X_omp, y_omp) # type: OrthogonalMatchingPursuit
+        self._extract_subforest(X_omp,
+            self.predict_base_estimator(X_forest) if use_distillation else y_omp) # type: OrthogonalMatchingPursuit
         return self
 
     def _extract_subforest(self, X, y):
diff --git a/code/bolsonaro/models/similarity_forest_regressor.py b/code/bolsonaro/models/similarity_forest_regressor.py
index 95a035d..edbb8ad 100644
--- a/code/bolsonaro/models/similarity_forest_regressor.py
+++ b/code/bolsonaro/models/similarity_forest_regressor.py
@@ -87,7 +87,7 @@ class SimilarityForestRegressor(SimilarityForest, metaclass=ABCMeta):
     @staticmethod
     def init_estimator(model_parameters):
         return RandomForestRegressor(**model_parameters.hyperparameters,
-                              random_state=model_parameters.seed, n_jobs=2)
+                              random_state=model_parameters.seed, n_jobs=-1)
 
     def _aggregate(self, predictions):
         return aggregation_regression(predictions)
@@ -111,7 +111,7 @@ class SimilarityForestClassifier(SimilarityForest, metaclass=ABCMeta):
     @staticmethod
     def init_estimator(model_parameters):
         return RandomForestClassifier(**model_parameters.hyperparameters,
-                                    random_state=model_parameters.seed, n_jobs=2)
+                                    random_state=model_parameters.seed, n_jobs=-1)
 
     def _aggregate(self, predictions):
         return aggregation_classification(predictions)
diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py
index 8d82b3d..21b7ab1 100644
--- a/code/bolsonaro/trainer.py
+++ b/code/bolsonaro/trainer.py
@@ -10,6 +10,7 @@ from . import LOG_PATH
 
 from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
 from sklearn.metrics import mean_squared_error, accuracy_score
+from sklearn.preprocessing import normalize
 import time
 import datetime
 import numpy as np
@@ -77,7 +78,7 @@ class Trainer(object):
         else:
             raise ValueError("Unknown specified subsets_used parameter '{}'".format(model.models_parameters.subsets_used))
 
-    def train(self, model, extracted_forest_size=None):
+    def train(self, model, extracted_forest_size=None, seed=None, use_distillation=False):
         """
         :param model: An instance of either RandomForestRegressor, RandomForestClassifier, OmpForestRegressor,
             OmpForestBinaryClassifier, OmpForestMulticlassClassifier.
@@ -88,6 +89,7 @@ class Trainer(object):
         if type(model) in [RandomForestRegressor, RandomForestClassifier]:
             if extracted_forest_size is not None:
                 estimators_index = np.arange(len(model.estimators_))
+                np.random.seed(seed)
                 np.random.shuffle(estimators_index)
                 choosen_estimators = estimators_index[:extracted_forest_size]
                 model.estimators_ = np.array(model.estimators_)[choosen_estimators]
@@ -98,12 +100,22 @@ class Trainer(object):
                 )
             self._selected_trees = model.estimators_
         else:
-            model.fit(
-                self._X_forest,
-                self._y_forest,
-                self._X_omp,
-                self._y_omp
-            )
+            if type(model) in [OmpForestRegressor, OmpForestBinaryClassifier, OmpForestMulticlassClassifier] and \
+                use_distillation:
+                model.fit(
+                    self._X_forest, # X_train or X_train+X_dev
+                    self._y_forest,
+                    self._X_omp, # X_train+X_dev or X_dev
+                    self._y_omp,
+                    use_distillation=use_distillation
+                )
+            else:
+                model.fit(
+                    self._X_forest, # X_train or X_train+X_dev
+                    self._y_forest,
+                    self._X_omp, # X_train+X_dev or X_dev
+                    self._y_omp
+                )
         self._end_time = time.time()
 
     def __score_func(self, model, X, y_true, weights=True):
@@ -141,6 +153,20 @@ class Trainer(object):
             result = self._base_regression_score_metric(y_true, y_pred)
         return result
 
+    def _evaluate_predictions(self, model, X, aggregation_function):
+        if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor,
+            OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
+            estimators = model.forest
+            estimators = np.asarray(estimators)[model._omp.coef_ != 0]
+        elif type(model) in [RandomForestRegressor, RandomForestClassifier]:
+            estimators = model.estimators_
+
+        predictions = np.array([tree.predict(X) for tree in estimators])
+
+        predictions = normalize(predictions)
+
+        return aggregation_function(np.abs((predictions @ predictions.T - np.eye(len(predictions)))))
+
     def compute_results(self, model, models_dir):
         """
         :param model: Object with
@@ -173,7 +199,9 @@ class Trainer(object):
             dev_score_base=self.__score_func_base(model, self._dataset.X_dev, self._dataset.y_dev),
             test_score_base=self.__score_func_base(model, self._dataset.X_test, self._dataset.y_test),
             score_metric=self._score_metric_name,
-            base_score_metric=self._base_score_metric_name
+            base_score_metric=self._base_score_metric_name,
+            coherence=self._evaluate_predictions(model, self._dataset.X_train, aggregation_function=np.max),
+            correlation=self._evaluate_predictions(model, self._dataset.X_train, aggregation_function=np.mean)
         )
         results.save(models_dir)
         self._logger.info("Base performance on test: {}".format(results.test_score_base))
@@ -201,10 +229,10 @@ class Trainer(object):
             )
             results.save(models_dir+'_no_weights')
             self._logger.info("Base performance on test without weights: {}".format(results.test_score_base))
-            self._logger.info("Performance on test: {}".format(results.test_score))
+            self._logger.info("Performance on test without weights: {}".format(results.test_score))
 
             self._logger.info("Base performance on train without weights: {}".format(results.train_score_base))
-            self._logger.info("Performance on train: {}".format(results.train_score))
+            self._logger.info("Performance on train without weights: {}".format(results.train_score))
 
             self._logger.info("Base performance on dev without weights: {}".format(results.dev_score_base))
-            self._logger.info("Performance on dev: {}".format(results.dev_score))
+            self._logger.info("Performance on dev without weights: {}".format(results.dev_score))
diff --git a/code/compute_results.py b/code/compute_results.py
index c534bb0..ab90b85 100644
--- a/code/compute_results.py
+++ b/code/compute_results.py
@@ -150,6 +150,35 @@ def extract_weights_across_seeds(models_dir, results_dir, experiment_id):
 
     return experiment_weights
 
+def extract_coherences_across_seeds(models_dir, results_dir, experiment_id):
+    experiment_id_path = models_dir + os.sep + str(experiment_id) # models/{experiment_id}
+    experiment_seed_root_path = experiment_id_path + os.sep + 'seeds' # models/{experiment_id}/seeds
+    experiment_coherences = dict()
+
+    # For each seed results stored in models/{experiment_id}/seeds
+    seeds = os.listdir(experiment_seed_root_path)
+    seeds.sort(key=int)
+    for seed in seeds:
+        experiment_seed_path = experiment_seed_root_path + os.sep + seed # models/{experiment_id}/seeds/{seed}
+        extracted_forest_sizes_root_path = experiment_seed_path + os.sep + 'extracted_forest_sizes' # models/{experiment_id}/seeds/{seed}/forest_size
+
+        # {{seed}:[]}
+        experiment_coherences[seed] = list()
+
+        # List the forest sizes in models/{experiment_id}/seeds/{seed}/extracted_forest_sizes
+        extracted_forest_sizes = os.listdir(extracted_forest_sizes_root_path)
+        extracted_forest_sizes = [nb_tree for nb_tree in extracted_forest_sizes if not 'no_weights' in nb_tree ]
+        extracted_forest_sizes.sort(key=int)
+        for extracted_forest_size in extracted_forest_sizes:
+            # models/{experiment_id}/seeds/{seed}/extracted_forest_sizes/{extracted_forest_size}
+            extracted_forest_size_path = extracted_forest_sizes_root_path + os.sep + extracted_forest_size
+            # Load models/{experiment_id}/seeds/{seed}/extracted_forest_sizes/{extracted_forest_size}/model_raw_results.pickle file
+            model_raw_results = ModelRawResults.load(extracted_forest_size_path)
+            # Save the weights
+            experiment_coherences[seed].append(model_raw_results.coherence)
+
+    return experiment_coherences
+
 
 if __name__ == "__main__":
     # get environment variables in .env
@@ -507,7 +536,7 @@ if __name__ == "__main__":
             ylabel=base_with_params_experiment_score_metric,
             title='Loss values of {}\nusing best params of previous stages'.format(args.dataset_name))
 
-    if args.plot_weight_density:
+    """if args.plot_weight_density:
         root_output_path = os.path.join(args.results_dir, args.dataset_name, f'stage{args.stage}')
 
         if args.stage == 1:
@@ -542,6 +571,27 @@ if __name__ == "__main__":
         for (experiment_label, experiment_id) in omp_experiment_ids:
             logger.info(f'Computing weight density plot for experiment {experiment_label}...')
             experiment_weights = extract_weights_across_seeds(args.models_dir, args.results_dir, experiment_id)
-            Plotter.weight_density(experiment_weights, os.path.join(root_output_path, f'weight_density_{experiment_label}.png'))
+            Plotter.weight_density(experiment_weights, os.path.join(root_output_path, f'weight_density_{experiment_label}.png'))"""
+
+    if args.plot_weight_density:
+        logger.info(f'Computing weight density plot for experiment {experiment_label}...')
+        experiment_weights = extract_weights_across_seeds(args.models_dir, args.results_dir, experiment_id)
+        Plotter.weight_density(experiment_weights, os.path.join(root_output_path, f'weight_density_{experiment_label}.png'))
+    if args.plot_preds_coherence:
+        root_output_path = os.path.join(args.results_dir, args.dataset_name, f'stage5')
+        all_labels = ['random', 'omp', 'omp_normalize_D']
+        random_with_params_train_scores, random_with_params_dev_scores, random_with_params_test_scores, \
+            with_params_extracted_forest_sizes, random_with_params_experiment_score_metric = \
+            extract_scores_across_seeds_and_extracted_forest_sizes(args.models_dir, args.results_dir, 2)
+        coherence_values = [extract_coherences_across_seeds(args.models_dir, args.results_dir, i) for i in [2, 3, 4]]
+        Plotter.plot_stage2_losses(
+            file_path=root_output_path + os.sep + f"coherences_{'-'.join(all_labels)}_30_all.png",
+            all_experiment_scores=coherence_values,
+            all_labels=all_labels,
+            x_value=with_params_extracted_forest_sizes,
+            xlabel='Number of trees extracted',
+            ylabel='Coherence',
+            title='Coherence values of {}'.format(args.dataset_name))
+        logger.info(f'Computing preds coherence plot...')
 
     logger.info('Done.')
diff --git a/code/train.py b/code/train.py
index 07e3a74..189aac1 100644
--- a/code/train.py
+++ b/code/train.py
@@ -55,12 +55,6 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb
 
     trainer = Trainer(dataset)
 
-    # if parameters['extraction_strategy'] == 'ensemble':
-    if False:
-        library = EnsembleSelectionForestRegressor.generate_library(dataset.X_train, dataset.y_train, random_state=seed)
-    else:
-        library = None
-
     if parameters['extraction_strategy'] == 'random':
         pretrained_model_parameters = ModelParameters(
             extracted_forest_size=parameters['forest_size'],
@@ -71,7 +65,7 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb
             hyperparameters=hyperparameters,
             extraction_strategy=parameters['extraction_strategy']
         )
-        pretrained_estimator = ModelFactory.build(dataset.task, pretrained_model_parameters, library=library)
+        pretrained_estimator = ModelFactory.build(dataset.task, pretrained_model_parameters)
         pretraned_trainer = Trainer(dataset)
         pretraned_trainer.init(pretrained_estimator, subsets_used=parameters['subsets_used'])
         pretrained_estimator.fit(
@@ -85,8 +79,9 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb
     if parameters['extraction_strategy'] != 'none':
         with tqdm_joblib(tqdm(total=len(parameters['extracted_forest_size']), disable=not verbose)) as extracted_forest_size_job_pb:
             Parallel(n_jobs=-1)(delayed(extracted_forest_size_job)(extracted_forest_size_job_pb, parameters['extracted_forest_size'][i],
-                models_dir, seed, parameters, dataset, hyperparameters, experiment_id, trainer, library,
-                pretrained_estimator=pretrained_estimator, pretrained_model_parameters=pretrained_model_parameters)
+                models_dir, seed, parameters, dataset, hyperparameters, experiment_id, trainer,
+                pretrained_estimator=pretrained_estimator, pretrained_model_parameters=pretrained_model_parameters,
+                use_distillation=parameters['extraction_strategy'] == 'omp_distillation')
                 for i in range(len(parameters['extracted_forest_size'])))
     else:
         forest_size = hyperparameters['n_estimators']
@@ -118,7 +113,7 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb
             )
             model_parameters.save(sub_models_dir, experiment_id)
 
-            model = ModelFactory.build(dataset.task, model_parameters, library=library)
+            model = ModelFactory.build(dataset.task, model_parameters)
 
             trainer.init(model, subsets_used=parameters['subsets_used'])
             trainer.train(model)
@@ -127,8 +122,8 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb
     seed_job_pb.update(1)
 
 def extracted_forest_size_job(extracted_forest_size_job_pb, extracted_forest_size, models_dir,
-    seed, parameters, dataset, hyperparameters, experiment_id, trainer, library,
-    pretrained_estimator=None, pretrained_model_parameters=None):
+    seed, parameters, dataset, hyperparameters, experiment_id, trainer,
+    pretrained_estimator=None, pretrained_model_parameters=None, use_distillation=False):
 
     logger = LoggerFactory.create(LOG_PATH, 'training_seed{}_extracted_forest_size{}_ti{}'.format(
         seed, extracted_forest_size, threading.get_ident()))
@@ -163,13 +158,14 @@ def extracted_forest_size_job(extracted_forest_size_job_pb, extracted_forest_siz
             extraction_strategy=parameters['extraction_strategy']
         )
         model_parameters.save(sub_models_dir, experiment_id)
-        model = ModelFactory.build(dataset.task, model_parameters, library=library)
+        model = ModelFactory.build(dataset.task, model_parameters)
     else:
         model = copy.deepcopy(pretrained_estimator)
         pretrained_model_parameters.save(sub_models_dir, experiment_id)
 
     trainer.init(model, subsets_used=parameters['subsets_used'])
-    trainer.train(model, extracted_forest_size=extracted_forest_size)
+    trainer.train(model, extracted_forest_size=extracted_forest_size, seed=seed,
+        use_distillation=use_distillation)
     trainer.compute_results(model, sub_models_dir)
 
 """
@@ -247,8 +243,8 @@ if __name__ == "__main__":
     else:
         parameters = args.__dict__
 
-    if parameters['extraction_strategy'] not in ['omp', 'random', 'none', 'similarity_similarities', 'similarity_predictions', 'kmeans', 'ensemble']:
-        raise ValueError('Specified extraction strategy {} is not supported.'.format(parameters.extraction_strategy))
+    if parameters['extraction_strategy'] not in ['omp', 'omp_distillation', 'random', 'none', 'similarity_similarities', 'similarity_predictions', 'kmeans', 'ensemble']:
+        raise ValueError('Specified extraction strategy {} is not supported.'.format(parameters['extraction_strategy']))
 
     pathlib.Path(parameters['models_dir']).mkdir(parents=True, exist_ok=True)
 
diff --git a/scripts/run_stage5_experiments.sh b/scripts/run_stage5_experiments.sh
index 4bd3711..fe8c3d6 100755
--- a/scripts/run_stage5_experiments.sh
+++ b/scripts/run_stage5_experiments.sh
@@ -1,16 +1,27 @@
 #!/bin/bash
 core_number=5
-core_number_sota=50
-walltime=1:00
+core_number_sota=5
+walltime=5:00
 walltime_sota=5:00
 seeds='1 2 3 4 5'
 
-for dataset in kin8nm kr-vs-kp spambase steel-plates diabetes diamonds boston california_housing
+for dataset in boston diabetes linnerud breast_cancer california_housing diamonds steel-plates kr-vs-kp kin8nm spambase musk gamma lfw_pairs
 do
-    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=$walltime "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=none --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=1 --models_dir=models/$dataset/stage5 --subsets_used train+dev,train+dev"
-    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=$walltime "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=random --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=2 --models_dir=models/$dataset/stage5 --subsets_used train+dev,train+dev"
-    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=$walltime "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=omp --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=3 --models_dir=models/$dataset/stage5 --subsets_used train+dev,train+dev"
-    oarsub -p "(gpu is null)" -l /core=$core_number_sota,walltime=$walltime_sota "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=similarity --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=4 --models_dir=models/$dataset/stage5 --subsets_used train+dev,train+dev"
-    oarsub -p "(gpu is null)" -l /core=$core_number_sota,walltime=$walltime_sota "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=kmeans --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=5 --models_dir=models/$dataset/stage5 --subsets_used train+dev,train+dev"
-    oarsub -p "(gpu is null)" -l /core=$core_number_sota,walltime=$walltime_sota "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=ensemble --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=6 --models_dir=models/$dataset/stage5 --subsets_used train+dev,train+dev"
+    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=$walltime "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=none --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=1 --models_dir=models/$dataset/stage5_new --subsets_used train+dev,train+dev"
+    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=$walltime "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=random --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=2 --models_dir=models/$dataset/stage5_new --subsets_used train+dev,train+dev"
+    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=$walltime "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=omp --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=3 --models_dir=models/$dataset/stage5_new --subsets_used train+dev,train+dev --normalize_D"
+    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=$walltime "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=omp_distillation --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=4 --models_dir=models/$dataset/stage5_new --subsets_used train+dev,train+dev --normalize_D"
+    oarsub -p "(gpu is null)" -l /core=$core_number_sota,walltime=$walltime_sota "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=kmeans --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=5 --models_dir=models/$dataset/stage5_new --subsets_used train+dev,train+dev"
+    oarsub -p "(gpu is null)" -l /core=$core_number_sota,walltime=$walltime_sota "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=similarity_similarities --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=6 --models_dir=models/$dataset/stage5_new --subsets_used train+dev,train+dev"
+    oarsub -p "(gpu is null)" -l /core=$core_number_sota,walltime=$walltime_sota "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=similarity_predictions --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=7 --models_dir=models/$dataset/stage5_new --subsets_used train+dev,train+dev"
+    oarsub -p "(gpu is null)" -l /core=$core_number_sota,walltime=$walltime_sota "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=ensemble --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=8 --models_dir=models/$dataset/stage5_new --subsets_used train+dev,train+dev"
+
+    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=$walltime "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=none --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=9 --models_dir=models/$dataset/stage5_new --subsets_used train,dev"
+    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=$walltime "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=random --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=10 --models_dir=models/$dataset/stage5_new --subsets_used train,dev"
+    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=$walltime "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=omp --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=11 --models_dir=models/$dataset/stage5_new --subsets_used train,dev --normalize_D"
+    oarsub -p "(gpu is null)" -l /core=$core_number,walltime=$walltime "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=omp_distillation --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=12 --models_dir=models/$dataset/stage5_new --subsets_used train,dev --normalize_D"
+    oarsub -p "(gpu is null)" -l /core=$core_number_sota,walltime=$walltime_sota "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=kmeans --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=13 --models_dir=models/$dataset/stage5_new --subsets_used train,dev"
+    oarsub -p "(gpu is null)" -l /core=$core_number_sota,walltime=$walltime_sota "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=similarity_similarities --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=14 --models_dir=models/$dataset/stage5_new --subsets_used train,dev"
+    oarsub -p "(gpu is null)" -l /core=$core_number_sota,walltime=$walltime_sota "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=similarity_predictions --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=15 --models_dir=models/$dataset/stage5_new --subsets_used train,dev"
+    oarsub -p "(gpu is null)" -l /core=$core_number_sota,walltime=$walltime_sota "conda activate test_env && python code/train.py --dataset_name=$dataset --seeds $seeds --extraction_strategy=ensemble --extracted_forest_size_stop=1.0 --extracted_forest_size_samples=30 --experiment_id=16 --models_dir=models/$dataset/stage5_new --subsets_used train,dev"
 done
-- 
GitLab