From b5d7632711ec5e5cedc7568288b5a43e764170f7 Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Thu, 18 Feb 2021 10:21:55 -0500
Subject: [PATCH] Updated hyper params and figs

---
 config_files/config_test.yml                  |   9 +-
 summit/multiview_platform/exec_classif.py     |  17 +-
 .../result_analysis/error_analysis.py         |  15 +-
 .../result_analysis/execution.py              |   2 +-
 .../result_analysis/metric_analysis.py        |  10 +-
 .../multiview_platform/utils/configuration.py |   1 -
 summit/multiview_platform/utils/execution.py  |   2 -
 .../utils/hyper_parameter_search.py           | 223 +++++-------------
 summit/tests/test_utils/test_base.py          |   2 -
 .../test_utils/test_hyper_parameter_search.py |   6 +
 10 files changed, 94 insertions(+), 193 deletions(-)

diff --git a/config_files/config_test.yml b/config_files/config_test.yml
index 4c075335..478e0e08 100644
--- a/config_files/config_test.yml
+++ b/config_files/config_test.yml
@@ -21,8 +21,8 @@ split: 0.8
 nb_folds: 2
 nb_class: 3
 classes:
-type: [ "multiview"]
-algos_monoview: ["decision_tree", "adaboost", ]
+type: [ "monoview"]
+algos_monoview: ["decision_tree", ]
 algos_multiview: ["weighted_linear_late_fusion"]
 stats_iter: 3
 metrics:
@@ -33,6 +33,11 @@ hps_type: "Random"
 hps_args:
   n_iter: 4
   equivalent_draws: False
+  decision_tree:
+    max_depth:
+      Randint:
+        low: 1
+        high: 10
 
 
 weighted_linear_early_fusion:
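
Note on the hps_args block above: with this patch the random search can take per-classifier
distributions straight from the config file. A minimal sketch of what that block amounts to once
the YAML is parsed (plain Python literals, mirroring the values above):

    hps_args = {
        "n_iter": 4,
        "equivalent_draws": False,
        "decision_tree": {"max_depth": {"Randint": {"low": 1, "high": 10}}},
    }
    # n_iter / equivalent_draws steer the search itself; the entry named after the
    # classifier supplies its parameter distributions (see get_random_hps_args in
    # exec_classif.py below).
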
diff --git a/summit/multiview_platform/exec_classif.py b/summit/multiview_platform/exec_classif.py
index ca5e8978..6c75194a 100644
--- a/summit/multiview_platform/exec_classif.py
+++ b/summit/multiview_platform/exec_classif.py
@@ -107,9 +107,7 @@ def init_multiview_exps(classifier_names, views_dictionary, nb_class,
                                                         classifier_name]},
                                                     views_dictionary=views_dictionary)]
         elif hps_method == "Random":
-            hps_kwargs = dict((key, value)
-                              for key, value in hps_kwargs.items()
-                              if key in ["n_iter", "equivalent_draws"])
+            hps_kwargs = get_random_hps_args(hps_kwargs, classifier_name)
             multiview_arguments += [
                 gen_single_multiview_arg_dictionary(classifier_name,
                                                     arguments,
@@ -171,9 +169,7 @@ def init_monoview_exps(classifier_names,
                                                                 hps_kwargs[
                                                                     classifier_name]})
             elif hps_method == "Random":
-                hps_kwargs = dict((key, value)
-                                  for key, value in hps_kwargs.items()
-                                  if key in ["n_iter", "equivalent_draws"])
+                hps_kwargs = get_random_hps_args(hps_kwargs, classifier_name)
                 arguments = gen_single_monoview_arg_dictionary(classifier_name,
                                                                kwargs_init,
                                                                nb_class,
@@ -198,6 +194,15 @@ def init_monoview_exps(classifier_names,
     return monoview_arguments
 
 
+def get_random_hps_args(hps_args, classifier_name):
+    hps_dict = {}
+    for key, value in hps_args.items():
+        if key in ["n_iter", "equivalent_draws"]:
+            hps_dict[key] = value
+        if key == classifier_name:
+            hps_dict["param_distributions"] = value
+    return hps_dict
+
 def gen_single_monoview_arg_dictionary(classifier_name, arguments, nb_class,
                                        view_index, view_name, hps_kwargs):
     if classifier_name in arguments:
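
A short sketch of the new helper's behaviour: only n_iter, equivalent_draws and the entry
matching the requested classifier are kept, the latter renamed to param_distributions. The
adaboost entry below is hypothetical, added only to show that non-matching entries are dropped.

    from summit.multiview_platform.exec_classif import get_random_hps_args

    hps_kwargs = get_random_hps_args(
        {"n_iter": 4,
         "equivalent_draws": False,
         "decision_tree": {"max_depth": {"Randint": {"low": 1, "high": 10}}},
         # hypothetical extra classifier, ignored for classifier_name="decision_tree"
         "adaboost": {"n_estimators": {"Randint": {"low": 1, "high": 500}}}},
        "decision_tree")
    # hps_kwargs == {"n_iter": 4,
    #                "equivalent_draws": False,
    #                "param_distributions": {"max_depth": {"Randint": {"low": 1,
    #                                                                  "high": 10}}}}
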
diff --git a/summit/multiview_platform/result_analysis/error_analysis.py b/summit/multiview_platform/result_analysis/error_analysis.py
index f26671a3..971f68b8 100644
--- a/summit/multiview_platform/result_analysis/error_analysis.py
+++ b/summit/multiview_platform/result_analysis/error_analysis.py
@@ -45,11 +45,11 @@ def get_sample_errors(groud_truth, results):
     return sample_errors
 
 
-def publish_sample_errors(sample_errors, directory, databaseName,
+def publish_sample_errors(sample_errors, directory, database_name,
                           labels_names, sample_ids, labels):  # pragma: no cover
     logging.info("Start:\t Label analysis figure generation")
 
-    base_file_name = os.path.join(directory, databaseName + "-")
+    base_file_name = os.path.join(directory, database_name + "-")
 
     nb_classifiers, nb_samples, classifiers_names, \
         data_2d, error_on_samples = gen_error_data(sample_errors)
@@ -58,7 +58,7 @@ def publish_sample_errors(sample_errors, directory, databaseName,
     np.savetxt(base_file_name + "bar_plot_data.csv", error_on_samples,
                delimiter=",")
 
-    plot_2d(data_2d, classifiers_names, nb_classifiers, base_file_name,
+    plot_2d(data_2d, classifiers_names, nb_classifiers, base_file_name, database_name,
             sample_ids=sample_ids, labels=labels)
 
     plot_errors_bar(error_on_samples, nb_samples,
@@ -69,7 +69,7 @@ def publish_sample_errors(sample_errors, directory, databaseName,
 
 def publish_all_sample_errors(iter_results, directory,
                               stats_iter,
-                              sample_ids, labels):  # pragma: no cover
+                              sample_ids, labels, data_base_name):  # pragma: no cover
     logging.info(
         "Start:\t Global label analysis figure generation")
 
@@ -82,7 +82,7 @@ def publish_all_sample_errors(iter_results, directory,
                delimiter=",")
 
     plot_2d(data, classifier_names, nb_classifiers,
-            os.path.join(directory, ""), stats_iter=stats_iter,
+            os.path.join(directory, ""), data_base_name, stats_iter=stats_iter,
             sample_ids=sample_ids, labels=labels)
     plot_errors_bar(error_on_samples, nb_samples, os.path.join(directory, ""),
                     sample_ids=sample_ids)
@@ -151,7 +151,7 @@ def gen_error_data_glob(iter_results, stats_iter):
         classifier_names
 
 
-def plot_2d(data, classifiers_names, nb_classifiers, file_name, labels=None,
+def plot_2d(data, classifiers_names, nb_classifiers, file_name, dataset_name, labels=None,
             stats_iter=1, use_plotly=True, sample_ids=None):  # pragma: no cover
     r"""Used to generate a 2D plot of the errors.
 
@@ -218,6 +218,9 @@ def plot_2d(data, classifiers_names, nb_classifiers, file_name, labels=None,
                           ticktext=["Always Wrong", "Always Right"]),
             reversescale=True), )
         fig.update_yaxes(title_text="Examples", showticklabels=True)
+        fig.update_layout(
+            title="Dataset : {} <br> Errors for each classifier <br> Generated on <a href='https://baptiste.bauvin.pages.lis-lab.fr/summit'>SuMMIT</a>.".format(
+                dataset_name))
         fig.update_layout(paper_bgcolor='rgba(0,0,0,0)',
                           plot_bgcolor='rgba(0,0,0,0)')
         fig.update_xaxes(showticklabels=True, )
diff --git a/summit/multiview_platform/result_analysis/execution.py b/summit/multiview_platform/result_analysis/execution.py
index 16ec3fe0..20ff793b 100644
--- a/summit/multiview_platform/result_analysis/execution.py
+++ b/summit/multiview_platform/result_analysis/execution.py
@@ -124,7 +124,7 @@ def analyze_all(iter_results, stats_iter, directory, data_base_name,
                                          data_base_name, stats_iter,
                                          label_names)
     publish_all_sample_errors(error_analysis, directory, stats_iter,
-                              sample_ids, labels)
+                              sample_ids, labels, data_base_name)
     publish_feature_importances(feature_importances, directory,
                                 data_base_name, feature_importances_stds)
     plot_durations(duration_means, directory, data_base_name, duration_stds)
diff --git a/summit/multiview_platform/result_analysis/metric_analysis.py b/summit/multiview_platform/result_analysis/metric_analysis.py
index 3948676b..560976d5 100644
--- a/summit/multiview_platform/result_analysis/metric_analysis.py
+++ b/summit/multiview_platform/result_analysis/metric_analysis.py
@@ -106,8 +106,8 @@ def publish_metrics_graphs(metrics_scores, directory, database_name,
                                           class_metric_scores[metric_name])
 
         plot_metric_scores(train_scores, test_scores, classifier_names,
-                           nb_results, metric_name, file_name,
-                           tag=" " + " vs ".join(labels_names))
+                           nb_results, metric_name, file_name, database_name,
+                           tag=" vs ".join(labels_names))
 
         class_file_name = file_name+"-class"
         plot_class_metric_scores(class_test_scores, class_file_name,
@@ -137,7 +137,7 @@ def publish_all_metrics_scores(iter_results, class_iter_results, directory,
         nb_results = classifier_names.shape[0]
 
         plot_metric_scores(train, test, classifier_names, nb_results,
-                           metric_name, file_name, tag=" averaged",
+                           metric_name, file_name, data_base_name, tag="Averaged",
                            train_STDs=train_std, test_STDs=test_std)
         results += [[classifier_name, metric_name, test_mean, test_std]
                     for classifier_name, test_mean, test_std
@@ -186,7 +186,7 @@ def init_plot(results, metric_name, metric_dataframe,
 
 def plot_metric_scores(train_scores, test_scores, names, nb_results,
                        metric_name,
-                       file_name,
+                       file_name, dataset_name,
                        tag="", train_STDs=None, test_STDs=None,
                        use_plotly=True):  # pragma: no cover
     r"""Used to plot and save the score barplot for a specific metric.
@@ -272,7 +272,7 @@ def plot_metric_scores(train_scores, test_scores, names, nb_results,
         ))
 
         fig.update_layout(
-            title=metric_name + "<br>" + tag + " scores for each classifier")
+            title="Dataset : {}, metric : {}, task : {} <br> Scores for each classifier <br> Generated on <a href='https://baptiste.bauvin.pages.lis-lab.fr/summit'>SuMMIT</a>.".format(dataset_name, metric_name, tag))
         fig.update_layout(paper_bgcolor='rgba(0,0,0,0)',
                           plot_bgcolor='rgba(0,0,0,0)')
         plotly.offline.plot(fig, filename=file_name + ".html", auto_open=False)
diff --git a/summit/multiview_platform/utils/configuration.py b/summit/multiview_platform/utils/configuration.py
index 75bd4b02..9c79b83b 100644
--- a/summit/multiview_platform/utils/configuration.py
+++ b/summit/multiview_platform/utils/configuration.py
@@ -3,7 +3,6 @@ import os
 import yaml
 
 package_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-print(package_path)
 
 def get_the_args(path_to_config_file=os.path.join(os.path.dirname(package_path), "config_files", "config.yml")):
     """
diff --git a/summit/multiview_platform/utils/execution.py b/summit/multiview_platform/utils/execution.py
index 459335a0..4c2e94b7 100644
--- a/summit/multiview_platform/utils/execution.py
+++ b/summit/multiview_platform/utils/execution.py
@@ -321,8 +321,6 @@ def find_dataset_names(path, type, names):
      the needed dataset names."""
     package_path = os.path.dirname(
         os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
-    print(package_path, os.path.isdir(path),
-          os.path.isdir(os.path.join(package_path, path)), )
     if os.path.isdir(path):
         pass
     elif os.path.isdir(os.path.join(package_path, path)):
diff --git a/summit/multiview_platform/utils/hyper_parameter_search.py b/summit/multiview_platform/utils/hyper_parameter_search.py
index 0fd65b93..84211acb 100644
--- a/summit/multiview_platform/utils/hyper_parameter_search.py
+++ b/summit/multiview_platform/utils/hyper_parameter_search.py
@@ -28,6 +28,25 @@ from .organization import secure_file_path
 
 class HPSearch:
 
+    def translate_param_distribs(self, param_distribs):
+        translated_params = {}
+        if param_distribs is None:
+            return translated_params
+        for param_name, value in param_distribs.items():
+            if isinstance(value, list):
+                translated_params[param_name] = value
+            elif isinstance(value, dict):
+                if "Uniform" in value.keys():
+                    distrib = self.translate_uniform(value["Uniform"])
+                elif "Randint" in value.keys():
+                    distrib = self.translate_randint(value["Randint"])
+                else:
+                    distrib = value
+                translated_params[param_name] = distrib
+            else:
+                translated_params[param_name] = value
+        return translated_params
+
     def get_scoring(self, metric):
         if isinstance(metric, dict):
             metric_module, metric_kwargs = get_metric(metric)
@@ -138,13 +157,15 @@ class Random(RandomizedSearchCV, HPSearch):
                  random_state=None, learning_indices=None, view_indices=None,
                  framework="monoview",
                  equivalent_draws=True, track_tracebacks=True):
-        if param_distributions is None:
-            param_distributions = self.get_param_distribs(estimator)
+        param_distributions = self.get_param_distribs(estimator, param_distributions)
+
+
         scoring = HPSearch.get_scoring(self, scoring)
         RandomizedSearchCV.__init__(self, estimator, n_iter=n_iter,
                                     param_distributions=param_distributions,
                                     refit=refit, n_jobs=n_jobs, scoring=scoring,
                                     cv=cv, random_state=random_state)
+
         self.framework = framework
         self.available_indices = learning_indices
         self.view_indices = view_indices
@@ -152,11 +173,22 @@ class Random(RandomizedSearchCV, HPSearch):
         self.track_tracebacks = track_tracebacks
         self.tracebacks = []
 
-    def get_param_distribs(self, estimator):
+    def translate_uniform(self, args):
+        return CustomUniform(**args)
+
+    def translate_randint(self, args):
+        return CustomRandint(**args)
+
+
+    def get_param_distribs(self, estimator, user_distribs):
+        user_distribs = self.translate_param_distribs(user_distribs)
         if isinstance(estimator, MultiClassWrapper):
-            return estimator.estimator.gen_distribs()
+            base_distribs = estimator.estimator.gen_distribs()
         else:
-            return estimator.gen_distribs()
+            base_distribs = estimator.gen_distribs()
+        for key, value in user_distribs.items():
+            base_distribs[key] = value
+        return base_distribs
 
     def fit(self, X, y=None, groups=None, **fit_params):  # pragma: no cover
         if self.framework == "monoview":
@@ -174,10 +206,6 @@ class Random(RandomizedSearchCV, HPSearch):
             ParameterSampler(self.param_distributions, self.n_iter,
                              random_state=self.random_state))
 
-    # def fit_multiview(self, X, y=None, groups=None, track_tracebacks=True,
-    #                   **fit_params):
-    #     n_splits = self.cv.get_n_splits(self.available_indices,
-    #                                     y[self.available_indices])
 
 
 class Grid(GridSearchCV, HPSearch):
@@ -208,153 +236,19 @@ class Grid(GridSearchCV, HPSearch):
         self.candidate_params = list(ParameterGrid(self.param_grid))
         self.n_iter = len(self.candidate_params)
 
+class CustomDist:
 
-# class ParameterSamplerGrid:
-#
-#     def __init__(self, param_distributions, n_iter):
-#         from math import floor
-#         n_points_per_param = int(n_iter **(1/len(param_distributions)))
-#         selected_params = dict((param_name, [])
-#                                for param_name in param_distributions.keys())
-#         for param_name, distribution in param_distributions.items():
-#             if isinstance(distribution, list):
-#                 if len(distribution)<n_points_per_param:
-#                     selected_params[param_name] = distribution
-#                 else:
-#                     index_step = floor(len(distribution)/n_points_per_param-2)
-#                     selected_params[param_name] = distribution[0]+[distribution[index*index_step+1]
-#                                                    for index
-# in range(n_points_per_param)]
-
-
-#
-# def hps_search():
-#     pass
-#
-# def grid_search(X, y, framework, random_state, output_file_name,
-#                   classifier_module,
-#                   classifier_name, folds=4, nb_cores=1,
-#                   metric=["accuracy_score", None],
-#                   n_iter=30, classifier_kwargs={}, learning_indices=None,
-#                   view_indices=None,
-#                   equivalent_draws=True, grid_search_config=None):
-#     """Used to perfom gridsearch on the classifiers"""
-#     pass
-
-
-# class RS(HPSSearch):
-#
-#     def __init__(self, X, y, framework, random_state, output_file_name,
-#                       classifier_module,
-#                       classifier_name, folds=4, nb_cores=1,
-#                       metric=["accuracy_score", None],
-#                       n_iter=30, classifier_kwargs={}, learning_indices=None,
-#                       view_indices=None,
-#                       equivalent_draws=True):
-#         HPSSearch.__init__()
-
-
-# def randomized_search(X, y, framework, random_state, output_file_name,
-#                       classifier_module,
-#                       classifier_name, folds=4, nb_cores=1,
-#                       metric=["accuracy_score", None],
-#                       n_iter=30, classifier_kwargs={}, learning_indices=None,
-#                       view_indices=None,
-#                       equivalent_draws=True):
-#     estimator = getattr(classifier_module, classifier_name)(
-#         random_state=random_state,
-#         **classifier_kwargs)
-#     params_dict = estimator.gen_distribs()
-#     estimator = get_mc_estim(estimator, random_state,
-#                              multiview=(framework == "multiview"),
-#                              y=y)
-#     if params_dict:
-#         metric_module, metric_kwargs = get_metric(metric)
-#         scorer = metric_module.get_scorer(**metric_kwargs)
-#         # nb_possible_combinations = compute_possible_combinations(params_dict)
-#         # n_iter_real = min(n_iter, nb_possible_combinations)
-#
-#         random_search = MultiviewCompatibleRandomizedSearchCV(estimator,
-#                                                               n_iter=n_iter,
-#                                                               param_distributions=params_dict,
-#                                                               refit=True,
-#                                                               n_jobs=nb_cores,
-#                                                               scoring=scorer,
-#                                                               cv=folds,
-#                                                               random_state=random_state,
-#                                                               learning_indices=learning_indices,
-#                                                               view_indices=view_indices,
-#                                                               framework=framework,
-#                                                               equivalent_draws=equivalent_draws)
-#         random_search.fit(X, y)
-#         return random_search.transform_results()
-#     else:
-#         best_estimator = estimator
-#         best_params = {}
-#         scores_array = {}
-#         params = {}
-#         test_folds_preds = np.zeros(10)#get_test_folds_preds(X, y, folds, best_estimator,
-#                                           # framework, learning_indices)
-#         return best_params, scores_array, params
-
-
-#
-# def spear_mint(dataset, classifier_name, views_indices=None, k_folds=None,
-#                n_iter=1,
-#                **kwargs):
-#     """Used to perform spearmint on the classifiers to optimize hyper parameters,
-#     longer than randomsearch (can't be parallelized)"""
-#     pass
-#
-#
-# def gen_heat_maps(params, scores_array, output_file_name):
-#     """Used to generate a heat map for each doublet of hyperparms
-#     optimized on the previous function"""
-#     nb_params = len(params)
-#     if nb_params > 2:
-#         combinations = itertools.combinations(range(nb_params), 2)
-#     elif nb_params == 2:
-#         combinations = [(0, 1)]
-#     else:
-#         combinations = [()]
-#     for combination in combinations:
-#         if combination:
-#             param_name1, param_array1 = params[combination[0]]
-#             param_name2, param_array2 = params[combination[1]]
-#         else:
-#             param_name1, param_array1 = params[0]
-#             param_name2, param_array2 = ("Control", np.array([0]))
-#
-#         param_array1_set = np.sort(np.array(list(set(param_array1))))
-#         param_array2_set = np.sort(np.array(list(set(param_array2))))
-#
-#         scores_matrix = np.zeros(
-#             (len(param_array2_set), len(param_array1_set))) - 0.1
-#         for param1, param2, score in zip(param_array1, param_array2,
-#                                          scores_array):
-#             param1_index, = np.where(param_array1_set == param1)
-#             param2_index, = np.where(param_array2_set == param2)
-#             scores_matrix[int(param2_index), int(param1_index)] = score
-#
-#         plt.figure(figsize=(8, 6))
-#         plt.subplots_adjust(left=.2, right=0.95, bottom=0.15, top=0.95)
-#         plt.imshow(scores_matrix, interpolation='nearest', cmap=plt.cm.hot,
-#                    )
-#         plt.xlabel(param_name1)
-#         plt.ylabel(param_name2)
-#         plt.colorbar()
-#         plt.xticks(np.arange(len(param_array1_set)), param_array1_set)
-#         plt.yticks(np.arange(len(param_array2_set)), param_array2_set,
-#                    rotation=45)
-#         plt.title('Validation metric')
-#         plt.savefig(
-#             output_file_name + "heat_map-" + param_name1 + "-" + param_name2 + ".png",
-#             transparent=True)
-#         plt.close()
-#
-
+    def multiply(self, random_number):
+        if self.multiplier == "e-":
+            return 10 ** -random_number
+        elif self.multiplier =="e":
+            return 10**random_number
+        elif type(self.multiplier) in [int, float]:
+            return self.multiplier*random_number
+        else:
+            return random_number
 
-class CustomRandint:
+class CustomRandint(CustomDist):
     """Used as a distribution returning a integer between low and high-1.
     It can be used with a multiplier agrument to be able to perform more complex generation
     for example 10 e -(randint)"""
@@ -366,20 +260,14 @@ class CustomRandint:
         self.multiplier = multiplier
 
     def rvs(self, random_state=None):
-        randinteger = self.randint.rvs(random_state=random_state)
-        if self.multiplier == "e-":
-            return 10 ** -randinteger
-        else:
-            return randinteger
+        rand_integer = self.randint.rvs(random_state=random_state)
+        return self.multiply(rand_integer)
 
     def get_nb_possibilities(self):
-        if self.multiplier == "e-":
-            return abs(10 ** -self.low - 10 ** -self.high)
-        else:
-            return self.high - self.low
+        return self.high - self.low
 
 
-class CustomUniform:
+class CustomUniform(CustomDist):
     """Used as a distribution returning a float between loc and loc + scale..
         It can be used with a multiplier agrument to be able to perform more complex generation
         for example 10 e -(float)"""
@@ -390,10 +278,9 @@ class CustomUniform:
 
     def rvs(self, random_state=None):
         unif = self.uniform.rvs(random_state=random_state)
-        if self.multiplier == 'e-':
-            return 10 ** -unif
-        else:
-            return unif
+        return self.multiply(unif)
+
+
 
 
 def format_params(params, pref=""):
diff --git a/summit/tests/test_utils/test_base.py b/summit/tests/test_utils/test_base.py
index dc4ccb68..981118e9 100644
--- a/summit/tests/test_utils/test_base.py
+++ b/summit/tests/test_utils/test_base.py
@@ -143,7 +143,6 @@ class Test_ResultAnalyzer(unittest.TestCase):
                                  self.nb_cores, self.duration)
         RA.get_all_metrics_scores()
         string = RA.print_metric_score()
-        print(repr(string))
         self.assertEqual(string, '\n\n\tFor Accuracy score using {}, (higher is better) : \n\t\t- Score on train : 0.25\n\t\t- Score on test : 0.2692307692307692\n\n\tFor F1 score using average: micro, {} (higher is better) : \n\t\t- Score on train : 0.25\n\t\t- Score on test : 0.2692307692307692\n\nTest set confusion matrix : \n\n╒════════╤══════════╤══════════╤══════════╕\n│        │   class1 │   class2 │   class3 │\n╞════════╪══════════╪══════════╪══════════╡\n│ class1 │        3 │        1 │        2 │\n├────────┼──────────┼──────────┼──────────┤\n│ class2 │        3 │        2 │        2 │\n├────────┼──────────┼──────────┼──────────┤\n│ class3 │        3 │        8 │        2 │\n╘════════╧══════════╧══════════╧══════════╛\n\n')
 
     def test_get_db_config_string(self):
@@ -182,7 +181,6 @@ class Test_ResultAnalyzer(unittest.TestCase):
                                 self.labels, self.database_name,
                                 self.nb_cores, self.duration)
         str_analysis, img_analysis, metric_scores, class_metric_scores, conf_mat = RA.analyze()
-        print(repr(str_analysis))
         self.assertEqual(str_analysis, 'test2Database configuration : \n\t- Database name : test_database\ntest\t- Learning Rate : 0.48\n\t- Labels used : class1, class2, class3\n\t- Number of cross validation folds : 5\n\nClassifier configuration : \n\t- FakeClassifier with test1 : 10, test2 : test\n\t- Executed on 0.5 core(s) \n\t- Got configuration using randomized search with 6  iterations \n\n\n\tFor Accuracy score using {}, (higher is better) : \n\t\t- Score on train : 0.25\n\t\t- Score on test : 0.2692307692307692\n\n\tFor F1 score using average: micro, {} (higher is better) : \n\t\t- Score on train : 0.25\n\t\t- Score on test : 0.2692307692307692\n\nTest set confusion matrix : \n\n╒════════╤══════════╤══════════╤══════════╕\n│        │   class1 │   class2 │   class3 │\n╞════════╪══════════╪══════════╪══════════╡\n│ class1 │        3 │        1 │        2 │\n├────────┼──────────┼──────────┼──────────┤\n│ class2 │        3 │        2 │        2 │\n├────────┼──────────┼──────────┼──────────┤\n│ class3 │        3 │        8 │        2 │\n╘════════╧══════════╧══════════╧══════════╛\n\n\n\n Classification took -1 day, 23:59:56\n\n Classifier Interpretation : \n')
 
 
diff --git a/summit/tests/test_utils/test_hyper_parameter_search.py b/summit/tests/test_utils/test_hyper_parameter_search.py
index e1b848c5..5b635b07 100644
--- a/summit/tests/test_utils/test_hyper_parameter_search.py
+++ b/summit/tests/test_utils/test_hyper_parameter_search.py
@@ -29,6 +29,9 @@ class FakeEstim(BaseEstimator):
     def predict(self, X):
         return np.zeros(X.shape[0])
 
+    def gen_distribs(self):
+        return {"param1":"", "param2":""}
+
 
 class FakeEstimMV(BaseEstimator):
     def __init__(self, param1=None, param2=None):
@@ -45,6 +48,9 @@ class FakeEstimMV(BaseEstimator):
         else:
             return np.zeros(sample_indices.shape[0])
 
+    def gen_distribs(self):
+        return {"param1":"", "param2":""}
+
 
 class Test_Random(unittest.TestCase):
 
-- 
GitLab