From 455bdae9c42d9875a2618187a2713f0ac8c9203a Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Fri, 20 Mar 2020 10:04:01 -0400
Subject: [PATCH] Metric arguments are now easier to set

---
 config_files/config_test.yml                  |  10 +-
 .../exec_classif.py                           | 199 +++++++++---------
 .../monoview/exec_classif_mono_view.py        |   4 +-
 .../monoview/monoview_utils.py                |   4 +-
 .../multiview/exec_multiview.py               |   2 +-
 .../multiview/multiview_utils.py              |   4 +-
 .../result_analysis/metric_analysis.py        |  26 +--
 .../mono_multi_view_classifiers/utils/base.py |  35 +--
 .../utils/configuration.py                    |   2 +-
 .../utils/execution.py                        |   7 +-
 .../utils/hyper_parameter_search.py           |   2 +-
 11 files changed, 147 insertions(+), 148 deletions(-)

diff --git a/config_files/config_test.yml b/config_files/config_test.yml
index 6cadcb60..d500c860 100644
--- a/config_files/config_test.yml
+++ b/config_files/config_test.yml
@@ -1,10 +1,10 @@
 # The base configuration of the benchmark
 log: True
-name: ["generated_dset",]
+name: ["digits",]
 label: "_"
 file_type: ".hdf5"
 views:
-pathf: "/home/baptiste/Documents/Gitwork/multiview_generator/demo/"
+pathf: "/home/baptiste/Documents/Datasets/Digits/"
 nice: 0
 random_state: 42
 nb_cores: 1
@@ -24,8 +24,10 @@ classes:
 type: ["monoview", "multiview"]
 algos_monoview: ["decision_tree" ]
 algos_multiview: ["weighted_linear_early_fusion","weighted_linear_late_fusion"]
-stats_iter: 15
-metrics: ["accuracy_score", "f1_score"]
+stats_iter: 3
+metrics:
+  accuracy_score: {}
+  f1_score: {}
 metric_princ: "accuracy_score"
 hps_type: "None"
 hps_args:
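
The metrics entry is no longer a flat list of names: each metric now maps to the keyword arguments it should be run with, so per-metric configuration lives directly in the YAML file. A minimal sketch of what the loaded value looks like on the Python side (the "average" argument is only an illustration, assuming the platform's f1_score module forwards extra keyword arguments):

    # What the metrics section deserializes to: a plain dict mapping each
    # metric name to the keyword arguments it will be called with.
    # The "average" value is hypothetical, only here to show a non-empty config.
    metrics = {
        "accuracy_score": {},
        "f1_score": {"average": "micro"},
    }
    for metric_name, metric_kwargs in metrics.items():
        print(metric_name, metric_kwargs)
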
diff --git a/multiview_platform/mono_multi_view_classifiers/exec_classif.py b/multiview_platform/mono_multi_view_classifiers/exec_classif.py
index 2860d588..693f4955 100644
--- a/multiview_platform/mono_multi_view_classifiers/exec_classif.py
+++ b/multiview_platform/mono_multi_view_classifiers/exec_classif.py
@@ -501,7 +501,7 @@ def arange_metrics(metrics, metric_princ):
 
     Parameters
     ----------
-    metrics : list of lists
+    metrics : dict
         The metrics that will be used in the benchmark
 
     metric_princ : str
@@ -512,13 +512,11 @@ def arange_metrics(metrics, metric_princ):
     -------
     metrics : list of lists
         The metrics list, but arranged  so the first one is the principal one."""
-    if [metric_princ] in metrics:
-        metric_index = metrics.index([metric_princ])
-        first_metric = metrics[0]
-        metrics[0] = [metric_princ]
-        metrics[metric_index] = first_metric
+    if metric_princ in metrics:
+        metrics = dict((key + "*", value) if key == metric_princ else (key, value) for key, value in metrics.items())
     else:
-        raise AttributeError(metric_princ + " not in metric pool")
+        raise AttributeError("{} not in metric pool ({})".format(metric_princ,
+                                                                 metrics))
     return metrics
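
Instead of reordering a list so that the principal metric comes first, arange_metrics now leaves the dict in place and tags the principal metric by appending "*" to its key; downstream code strips the suffix when it needs the module name. A standalone sketch of the new behaviour:

    def arange_metrics(metrics, metric_princ):
        # Mirror of the patched helper: keep the dict as-is but rename the
        # principal metric's key with a trailing "*" so it can be spotted later.
        if metric_princ in metrics:
            return dict((key + "*", value) if key == metric_princ else (key, value)
                        for key, value in metrics.items())
        raise AttributeError("{} not in metric pool ({})".format(metric_princ, metrics))

    print(arange_metrics({"accuracy_score": {}, "f1_score": {}}, "accuracy_score"))
    # {'accuracy_score*': {}, 'f1_score': {}}
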
 
 
@@ -874,100 +872,95 @@ def exec_classif(arguments):
     dataset_list = execution.find_dataset_names(args["pathf"],
                                                 args["file_type"],
                                                 args["name"])
-    if not args["add_noise"]:
-        args["noise_std"] = [0.0]
+    # if not args["add_noise"]:
+        # args["noise_std"] = [0.0]
     for dataset_name in dataset_list:
-        noise_results = []
-        for noise_std in args["noise_std"]:
-
-            directory = execution.init_log_file(dataset_name, args["views"],
-                                                args["file_type"],
-                                                args["log"], args["debug"],
-                                                args["label"],
-                                                args["res_dir"],
-                                                args["add_noise"], noise_std,
-                                                args)
-
-            random_state = execution.init_random_state(args["random_state"],
-                                                       directory)
-            stats_iter_random_states = execution.init_stats_iter_random_states(
-                stats_iter,
-                random_state)
-
-            get_database = execution.get_database_function(dataset_name,
-                                                           args["file_type"])
-
-            dataset_var, labels_dictionary, datasetname = get_database(
-                args["views"],
-                args["pathf"], dataset_name,
-                args["nb_class"],
-                args["classes"],
-                random_state,
-                args["full"],
-                args["add_noise"],
-                noise_std)
-            args["name"] = datasetname
-            splits = execution.gen_splits(dataset_var.get_labels(),
-                                          args["split"],
-                                          stats_iter_random_states)
-
-            # multiclass_labels, labels_combinations, indices_multiclass = multiclass.gen_multiclass_labels(
-            #     dataset_var.get_labels(), multiclass_method, splits)
-
-            k_folds = execution.gen_k_folds(stats_iter, args["nb_folds"],
-                                            stats_iter_random_states)
-
-            dataset_files = dataset.init_multiple_datasets(args["pathf"],
-                                                           args["name"],
-                                                           nb_cores)
-
-            views, views_indices, all_views = execution.init_views(dataset_var,
-                                                                   args[
-                                                                       "views"])
-            views_dictionary = dataset_var.get_view_dict()
-            nb_views = len(views)
-            nb_class = dataset_var.get_nb_class()
-
-            metrics = [metric.split(":") for metric in args["metrics"]]
-            if metrics == [["all"]]:
-                metrics_names = [name for _, name, isPackage
-                                 in pkgutil.iter_modules(
-                        [os.path.join(os.path.dirname(
-                            os.path.dirname(os.path.realpath(__file__))),
-                                      'metrics')]) if
-                                 not isPackage and name not in ["framework",
-                                                                "log_loss",
-                                                                "matthews_corrcoef",
-                                                                "roc_auc_score"]]
-                metrics = [[metricName, {}] for metricName in metrics_names]
-            metrics = arange_metrics(metrics, args["metric_princ"])
-            # TODO : Metric args
-            for metricIndex, metric in enumerate(metrics):
-                if len(metric) == 1:
-                    metrics[metricIndex] = [metric[0], {}]
-
-            benchmark = init_benchmark(cl_type, monoview_algos, multiview_algos,
-                                       args)
-            init_kwargs = init_kwargs_func(args, benchmark)
-            data_base_time = time.time() - start
-            argument_dictionaries = init_argument_dictionaries(
-                benchmark, views_dictionary,
-                nb_class, init_kwargs, hps_method, hps_kwargs)
-            # argument_dictionaries = initMonoviewExps(benchmark, viewsDictionary,
-            #                                         NB_CLASS, initKWARGS)
-            directories = execution.gen_direcorties_names(directory, stats_iter)
-            benchmark_argument_dictionaries = execution.gen_argument_dictionaries(
-                labels_dictionary, directories,
-                splits,
-                hps_method, args, k_folds,
-                stats_iter_random_states, metrics,
-                argument_dictionaries, benchmark,
-                views, views_indices)
-            results_mean_stds = exec_benchmark(
-                nb_cores, stats_iter,
-                benchmark_argument_dictionaries, directory, metrics,
-                dataset_var,
-                args["track_tracebacks"])
-            noise_results.append([noise_std, results_mean_stds])
-            plot_results_noise(directory, noise_results, metrics[0][0],
-                               dataset_name)
+        # noise_results = []
+        # for noise_std in args["noise_std"]:
+
+        directory = execution.init_log_file(dataset_name, args["views"],
+                                            args["file_type"],
+                                            args["log"], args["debug"],
+                                            args["label"],
+                                            args["res_dir"],
+                                            args)
+
+        random_state = execution.init_random_state(args["random_state"],
+                                                   directory)
+        stats_iter_random_states = execution.init_stats_iter_random_states(
+            stats_iter,
+            random_state)
+
+        get_database = execution.get_database_function(dataset_name,
+                                                       args["file_type"])
+
+        dataset_var, labels_dictionary, datasetname = get_database(
+            args["views"],
+            args["pathf"], dataset_name,
+            args["nb_class"],
+            args["classes"],
+            random_state,
+            args["full"],
+            )
+        args["name"] = datasetname
+        splits = execution.gen_splits(dataset_var.get_labels(),
+                                      args["split"],
+                                      stats_iter_random_states)
+
+        # multiclass_labels, labels_combinations, indices_multiclass = multiclass.gen_multiclass_labels(
+        #     dataset_var.get_labels(), multiclass_method, splits)
+
+        k_folds = execution.gen_k_folds(stats_iter, args["nb_folds"],
+                                        stats_iter_random_states)
+
+        dataset_files = dataset.init_multiple_datasets(args["pathf"],
+                                                       args["name"],
+                                                       nb_cores)
+
+        views, views_indices, all_views = execution.init_views(dataset_var,
+                                                               args[
+                                                                   "views"])
+        views_dictionary = dataset_var.get_view_dict()
+        nb_views = len(views)
+        nb_class = dataset_var.get_nb_class()
+
+        metrics = args["metrics"]
+        if metrics == "all":
+            metrics_names = [name for _, name, isPackage
+                             in pkgutil.iter_modules(
+                    [os.path.join(os.path.dirname(
+                        os.path.dirname(os.path.realpath(__file__))),
+                                  'metrics')]) if
+                             not isPackage and name not in ["framework",
+                                                            "log_loss",
+                                                            "matthews_corrcoef",
+                                                            "roc_auc_score"]]
+            metrics = dict((metric_name, {})
+                           for metric_name in metrics_names)
+        metrics = arange_metrics(metrics, args["metric_princ"])
+
+        benchmark = init_benchmark(cl_type, monoview_algos, multiview_algos,
+                                   args)
+        init_kwargs = init_kwargs_func(args, benchmark)
+        data_base_time = time.time() - start
+        argument_dictionaries = init_argument_dictionaries(
+            benchmark, views_dictionary,
+            nb_class, init_kwargs, hps_method, hps_kwargs)
+        # argument_dictionaries = initMonoviewExps(benchmark, viewsDictionary,
+        #                                         NB_CLASS, initKWARGS)
+        directories = execution.gen_direcorties_names(directory, stats_iter)
+        benchmark_argument_dictionaries = execution.gen_argument_dictionaries(
+            labels_dictionary, directories,
+            splits,
+            hps_method, args, k_folds,
+            stats_iter_random_states, metrics,
+            argument_dictionaries, benchmark,
+            views, views_indices)
+        results_mean_stds = exec_benchmark(
+            nb_cores, stats_iter,
+            benchmark_argument_dictionaries, directory, metrics,
+            dataset_var,
+            args["track_tracebacks"])
+        # noise_results.append([noise_std, results_mean_stds])
+        # plot_results_noise(directory, noise_results, metrics[0][0],
+        #                    dataset_name)
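
When the configuration asks for "all" metrics, exec_classif now builds the dict by listing the modules of the metrics package with empty keyword arguments, then hands it to arange_metrics. A hedged sketch of that expansion (the package path is a placeholder):

    import pkgutil

    # "path/to/metrics" stands in for the metrics package directory computed
    # above; every non-package module found there (minus the excluded ones)
    # becomes an entry with empty keyword arguments before arange_metrics
    # tags the principal metric.
    excluded = ["framework", "log_loss", "matthews_corrcoef", "roc_auc_score"]
    metrics = dict((name, {})
                   for _, name, is_package in pkgutil.iter_modules(["path/to/metrics"])
                   if not is_package and name not in excluded)
    # e.g. {"accuracy_score": {}, "f1_score": {}, ...}
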
diff --git a/multiview_platform/mono_multi_view_classifiers/monoview/exec_classif_mono_view.py b/multiview_platform/mono_multi_view_classifiers/monoview/exec_classif_mono_view.py
index 0b7597d8..41835650 100644
--- a/multiview_platform/mono_multi_view_classifiers/monoview/exec_classif_mono_view.py
+++ b/multiview_platform/mono_multi_view_classifiers/monoview/exec_classif_mono_view.py
@@ -53,7 +53,7 @@ def exec_monoview_multicore(directory, name, labels_names,
 def exec_monoview(directory, X, Y, database_name, labels_names, classification_indices,
                   k_folds, nb_cores, databaseType, path,
                   random_state, hyper_param_search="randomized_search",
-                  metrics=[["accuracy_score", None]], n_iter=30, view_name="",
+                  metrics={"accuracy_score":{}}, n_iter=30, view_name="",
                   hps_kwargs={}, **args):
     logging.debug("Start:\t Loading data")
     kwargs, \
@@ -140,7 +140,7 @@ def exec_monoview(directory, X, Y, database_name, labels_names, classification_i
                                              classification_indices=classification_indices,
                                              k_folds=k_folds,
                                              hps_method=hyper_param_search,
-                                             metrics_list=metrics,
+                                             metrics_dict=metrics,
                                              n_iter=n_iter,
                                              class_label_names=labels_names,
                                              pred=full_pred,
diff --git a/multiview_platform/mono_multi_view_classifiers/monoview/monoview_utils.py b/multiview_platform/mono_multi_view_classifiers/monoview/monoview_utils.py
index a84fe0ef..076044b8 100644
--- a/multiview_platform/mono_multi_view_classifiers/monoview/monoview_utils.py
+++ b/multiview_platform/mono_multi_view_classifiers/monoview/monoview_utils.py
@@ -208,11 +208,11 @@ def get_accuracy_graph(plotted_data, classifier_name, file_name,
 class MonoviewResultAnalyzer(ResultAnalyser):
 
     def __init__(self, view_name, classifier_name, shape, classifier,
-                 classification_indices, k_folds, hps_method, metrics_list,
+                 classification_indices, k_folds, hps_method, metrics_dict,
                  n_iter, class_label_names, pred,
                  directory, base_file_name, labels, database_name, nb_cores, duration):
         ResultAnalyser.__init__(self, classifier, classification_indices,
-                                k_folds, hps_method, metrics_list, n_iter,
+                                k_folds, hps_method, metrics_dict, n_iter,
                                 class_label_names, pred,
                                 directory, base_file_name, labels,
                                 database_name, nb_cores, duration)
diff --git a/multiview_platform/mono_multi_view_classifiers/multiview/exec_multiview.py b/multiview_platform/mono_multi_view_classifiers/multiview/exec_multiview.py
index f3a9a04b..f300d63a 100644
--- a/multiview_platform/mono_multi_view_classifiers/multiview/exec_multiview.py
+++ b/multiview_platform/mono_multi_view_classifiers/multiview/exec_multiview.py
@@ -332,7 +332,7 @@ def exec_multiview(directory, dataset_var, name, classification_indices,
                                               classification_indices=classification_indices,
                                               k_folds=k_folds,
                                               hps_method=hps_method,
-                                              metrics_list=metrics,
+                                              metrics_dict=metrics,
                                               n_iter=n_iter,
                                               class_label_names=list(labels_dictionary.values()),
                                               pred=full_pred,
diff --git a/multiview_platform/mono_multi_view_classifiers/multiview/multiview_utils.py b/multiview_platform/mono_multi_view_classifiers/multiview/multiview_utils.py
index 2392f3bc..faae5749 100644
--- a/multiview_platform/mono_multi_view_classifiers/multiview/multiview_utils.py
+++ b/multiview_platform/mono_multi_view_classifiers/multiview/multiview_utils.py
@@ -177,13 +177,13 @@ class MultiviewResult(object):
 class MultiviewResultAnalyzer(ResultAnalyser):
 
     def __init__(self, view_names, classifier, classification_indices, k_folds,
-                 hps_method, metrics_list, n_iter, class_label_names,
+                 hps_method, metrics_dict, n_iter, class_label_names,
                  pred, directory, base_file_name, labels,
                  database_name, nb_cores, duration):
         if hps_method.endswith("equiv"):
             n_iter = n_iter*len(view_names)
         ResultAnalyser.__init__(self, classifier, classification_indices, k_folds,
-                                hps_method, metrics_list, n_iter, class_label_names,
+                                hps_method, metrics_dict, n_iter, class_label_names,
                                 pred, directory,
                                 base_file_name, labels, database_name,
                                 nb_cores, duration)
diff --git a/multiview_platform/mono_multi_view_classifiers/result_analysis/metric_analysis.py b/multiview_platform/mono_multi_view_classifiers/result_analysis/metric_analysis.py
index a0876f9c..32ac4830 100644
--- a/multiview_platform/mono_multi_view_classifiers/result_analysis/metric_analysis.py
+++ b/multiview_platform/mono_multi_view_classifiers/result_analysis/metric_analysis.py
@@ -13,7 +13,7 @@ def get_metrics_scores(metrics, results, label_names):
 
     Parameters
     ----------
-    metrics : list of lists
+    metrics : dict
         The metrics names with configuration metrics[i][0] = name of metric i
     results : list of MonoviewResult and MultiviewResults objects
         A list containing all the results for all the monoview experimentations.
@@ -36,32 +36,32 @@ def get_metrics_scores(metrics, results, label_names):
                         for classifier_result in results
                         if classifier_result.get_classifier_name()
                         not in classifier_names]
-    metrics_scores = dict((metric[0], pd.DataFrame(data=np.zeros((2,
+    metrics_scores = dict((metric, pd.DataFrame(data=np.zeros((2,
                                                                   len(
                                                                       classifier_names))),
                                                    index=["train", "test"],
                                                    columns=classifier_names))
-                          for metric in metrics)
-    class_metric_scores = dict((metric[0], pd.DataFrame(
+                          for metric in metrics.keys())
+    class_metric_scores = dict((metric, pd.DataFrame(
         index=pd.MultiIndex.from_product([["train", "test"], label_names]),
         columns=classifier_names, dtype=float))
                                for metric in metrics)
 
-    for metric in metrics:
+    for metric in metrics.keys():
         for classifier_result in results:
-            metrics_scores[metric[0]].loc[
+            metrics_scores[metric].loc[
                 "train", classifier_result.get_classifier_name()] = \
-            classifier_result.metrics_scores[metric[0]][0]
-            metrics_scores[metric[0]].loc[
+            classifier_result.metrics_scores[metric][0]
+            metrics_scores[metric].loc[
                 "test", classifier_result.get_classifier_name()] = \
-                classifier_result.metrics_scores[metric[0]][1]
+                classifier_result.metrics_scores[metric][1]
             for label_index, label_name in enumerate(label_names):
-                class_metric_scores[metric[0]].loc[(
+                class_metric_scores[metric].loc[(
                     "train", label_name),classifier_result.get_classifier_name()] = \
-                classifier_result.class_metric_scores[metric[0]][0][label_index]
-                class_metric_scores[metric[0]].loc[(
+                classifier_result.class_metric_scores[metric][0][label_index]
+                class_metric_scores[metric].loc[(
                     "test", label_name), classifier_result.get_classifier_name()] = \
-                    classifier_result.class_metric_scores[metric[0]][1][label_index]
+                    classifier_result.class_metric_scores[metric][1][label_index]
 
     return metrics_scores, class_metric_scores
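
metrics_scores and class_metric_scores are now keyed directly by the metric name (with the principal metric carrying its "*" suffix); each entry is a 2 x n_classifiers DataFrame indexed by "train"/"test", or a (split, label) MultiIndex frame for the per-class scores. A sketch of the empty frames built above, with placeholder classifier and label names:

    import numpy as np
    import pandas as pd

    classifier_names = ["decision_tree-ViewNumber0", "weighted_linear_late_fusion"]  # placeholders
    label_names = ["label_0", "label_1"]  # placeholders
    metrics = {"accuracy_score*": {}, "f1_score": {}}

    metrics_scores = dict((metric, pd.DataFrame(np.zeros((2, len(classifier_names))),
                                                index=["train", "test"],
                                                columns=classifier_names))
                          for metric in metrics)
    class_metric_scores = dict((metric, pd.DataFrame(
        index=pd.MultiIndex.from_product([["train", "test"], label_names]),
        columns=classifier_names, dtype=float))
                               for metric in metrics)
    print(metrics_scores["accuracy_score*"])
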
 
diff --git a/multiview_platform/mono_multi_view_classifiers/utils/base.py b/multiview_platform/mono_multi_view_classifiers/utils/base.py
index 1d5e3e6d..65151021 100644
--- a/multiview_platform/mono_multi_view_classifiers/utils/base.py
+++ b/multiview_platform/mono_multi_view_classifiers/utils/base.py
@@ -140,17 +140,16 @@ def get_names(classed_list):
     return np.array([object_.__class__.__name__ for object_ in classed_list])
 
 
-def get_metric(metric_list):
+def get_metric(metrics_dict):
     """
     Fetches the metric module in the metrics package
     """
-    metric_module = getattr(metrics, metric_list[0][0])
-    if metric_list[0][1] is not None:
-        metric_kwargs = dict((index, metricConfig) for index, metricConfig in
-                             enumerate(metric_list[0][1]))
-    else:
-        metric_kwargs = {}
-    return metric_module, metric_kwargs
+    for metric_name, metric_kwargs in metrics_dict.items():
+        if metric_name.endswith("*"):
+            princ_metric_name = metric_name[:-1]
+            princ_metric_kwargs = metric_kwargs
+    metric_module = getattr(metrics, princ_metric_name)
+    return metric_module, princ_metric_kwargs
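
get_metric no longer assumes the principal metric sits at index 0 of a list: it scans the dict for the key tagged with "*" and returns the matching metrics module together with its keyword arguments. A standalone sketch using a stand-in namespace instead of the real metrics package (the sketch adds an explicit error for the untagged case):

    from types import SimpleNamespace

    # Stand-in for the platform's metrics package, for illustration only.
    metrics = SimpleNamespace(accuracy_score="<accuracy_score module>",
                              f1_score="<f1_score module>")

    def get_metric(metrics_dict):
        # The "*"-suffixed key marks the principal metric.
        for metric_name, metric_kwargs in metrics_dict.items():
            if metric_name.endswith("*"):
                return getattr(metrics, metric_name[:-1]), metric_kwargs
        raise ValueError("no principal metric tagged with '*'")

    print(get_metric({"accuracy_score*": {}, "f1_score": {"average": "micro"}}))
    # ('<accuracy_score module>', {})
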
 
 
 class ResultAnalyser():
@@ -161,7 +160,7 @@ class ResultAnalyser():
     """
 
     def __init__(self, classifier, classification_indices, k_folds,
-                 hps_method, metrics_list, n_iter, class_label_names,
+                 hps_method, metrics_dict, n_iter, class_label_names,
                  pred, directory, base_file_name, labels,
                  database_name, nb_cores, duration):
         """
@@ -176,7 +175,7 @@ class ResultAnalyser():
 
         hps_method: string naming the hyper-parameter search method
 
-        metrics_list: list of the metrics to compute on the results
+    metrics_dict: dict of the metrics to compute on the results and their configurations
 
         n_iter: number of HPS iterations
 
@@ -200,7 +199,7 @@ class ResultAnalyser():
         self.train_indices, self.test_indices = classification_indices
         self.k_folds = k_folds
         self.hps_method = hps_method
-        self.metrics_list = metrics_list
+        self.metrics_dict = metrics_dict
         self.n_iter = n_iter
         self.class_label_names = class_label_names
         self.pred = pred
@@ -220,7 +219,7 @@ class ResultAnalyser():
         Returns
         -------
         """
-        for metric, metric_args in self.metrics_list:
+        for metric, metric_args in self.metrics_dict.items():
             class_train_scores, class_test_scores, train_score, test_score\
                 = self.get_metric_score(metric, metric_args)
             self.class_metric_scores[metric] = (class_train_scores,
@@ -242,7 +241,10 @@ class ResultAnalyser():
         -------
         train_score, test_score
         """
-        metric_module = getattr(metrics, metric)
+        if not metric.endswith("*"):
+            metric_module = getattr(metrics, metric)
+        else:
+            metric_module = getattr(metrics, metric[:-1])
         class_train_scores = []
         class_test_scores = []
         for label_value in np.unique(self.labels):
@@ -277,8 +279,11 @@ class ResultAnalyser():
         metric_score_string string formatting all metric results
         """
         metric_score_string = "\n\n"
-        for metric, metric_kwargs in self.metrics_list:
-            metric_module = getattr(metrics, metric)
+        for metric, metric_kwargs in self.metrics_dict.items():
+            if metric.endswith("*"):
+                metric_module = getattr(metrics, metric[:-1])
+            else:
+                metric_module = getattr(metrics, metric)
             metric_score_string += "\tFor {} : ".format(metric_module.get_config(
                 **metric_kwargs))
             metric_score_string += "\n\t\t- Score on train : {}".format(self.metric_scores[metric][0])
diff --git a/multiview_platform/mono_multi_view_classifiers/utils/configuration.py b/multiview_platform/mono_multi_view_classifiers/utils/configuration.py
index 9f349084..fcd62c6d 100644
--- a/multiview_platform/mono_multi_view_classifiers/utils/configuration.py
+++ b/multiview_platform/mono_multi_view_classifiers/utils/configuration.py
@@ -45,7 +45,7 @@ def pass_default_config(log=True,
                         algos_monoview=["all"],
                         algos_multiview=["svm_jumbo_fusion", ],
                         stats_iter=2,
-                        metrics=["accuracy_score", "f1_score"],
+                        metrics={"accuracy_score":{}, "f1_score":{}},
                         metric_princ="accuracy_score",
                         hps_type="Random",
                         hps_iter=1,
diff --git a/multiview_platform/mono_multi_view_classifiers/utils/execution.py b/multiview_platform/mono_multi_view_classifiers/utils/execution.py
index 9bf1b72d..99534ae2 100644
--- a/multiview_platform/mono_multi_view_classifiers/utils/execution.py
+++ b/multiview_platform/mono_multi_view_classifiers/utils/execution.py
@@ -119,7 +119,7 @@ def get_database_function(name, type_var):
 
 
 def init_log_file(name, views, cl_type, log, debug, label,
-                  result_directory, add_noise, noise_std, args):
+                  result_directory, args):
     r"""Used to init the directory where the preds will be stored and the log file.
 
     First this function will check if the result directory already exists (only one per minute is allowed).
@@ -153,16 +153,15 @@ def init_log_file(name, views, cl_type, log, debug, label,
     """
     if views is None:
         views = []
-    noise_string = "n_" + str(int(noise_std * 100))
     result_directory = os.path.join(os.path.dirname(
         os.path.dirname(os.path.dirname(os.path.realpath(__file__)))),
                                     result_directory)
     if debug:
-        result_directory = os.path.join(result_directory, name, noise_string,
+        result_directory = os.path.join(result_directory, name,
                                         "debug_started_" + time.strftime(
                                             "%Y_%m_%d-%H_%M_%S") + "_" + label)
     else:
-        result_directory = os.path.join(result_directory, name, noise_string,
+        result_directory = os.path.join(result_directory, name,
                                         "started_" + time.strftime(
                                             "%Y_%m_%d-%H_%M") + "_" + label)
     log_file_name = time.strftime("%Y_%m_%d-%H_%M") + "-" + ''.join(
diff --git a/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py b/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py
index d65af38d..8d126bbb 100644
--- a/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py
+++ b/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py
@@ -47,7 +47,7 @@ from .. import metrics
 class HPSearch:
 
     def get_scoring(self, metric):
-        if isinstance(metric, list):
+        if isinstance(metric, dict):
             metric_module, metric_kwargs = get_metric(metric)
             return metric_module.get_scorer(**metric_kwargs)
         else:
-- 
GitLab