From f8ca7c7c7fe403c9141f9a3c42178c26f6a6b6c9 Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Thu, 23 Jan 2020 14:25:33 +0100
Subject: [PATCH] Sklearn compatible now ?

---
 .../monoview/monoview_utils.py                            | 3 +++
 .../mono_multi_view_classifiers/result_analysis.py        | 2 +-
 .../mono_multi_view_classifiers/utils/execution.py        | 8 ++++++--
 3 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/multiview_platform/mono_multi_view_classifiers/monoview/monoview_utils.py b/multiview_platform/mono_multi_view_classifiers/monoview/monoview_utils.py
index eb721f25..f2ce0112 100644
--- a/multiview_platform/mono_multi_view_classifiers/monoview/monoview_utils.py
+++ b/multiview_platform/mono_multi_view_classifiers/monoview/monoview_utils.py
@@ -193,6 +193,9 @@ class BaseMonoviewClassifier(BaseEstimator, ):#ClassifierMixin):
     def get_name_for_fusion(self):
         return self.__class__.__name__[:4]
 
+    def getInterpret(self, directory, y_test):
+        return ""
+
 
 def get_names(classed_list):
     return np.array([object_.__class__.__name__ for object_ in classed_list])
diff --git a/multiview_platform/mono_multi_view_classifiers/result_analysis.py b/multiview_platform/mono_multi_view_classifiers/result_analysis.py
index 2cbdd02a..6ff0f76b 100644
--- a/multiview_platform/mono_multi_view_classifiers/result_analysis.py
+++ b/multiview_platform/mono_multi_view_classifiers/result_analysis.py
@@ -218,7 +218,7 @@ def plot_2d(data, classifiers_names, nbClassifiers, nbExamples,
     ### The following part is used to generate an interactive graph.
     if use_plotly:
         label_index_list = np.concatenate([np.where(labels==i)[0] for i in np.unique(labels)]) #[np.where(labels==i)[0] for i in np.unique(labels)]
-        hover_text = [[example_ids[example_index] + " failed "+ str(stats_iter-data[example_index,classifier_index])+" time(s), labelled "+str(example_index)
+        hover_text = [[example_ids[example_index] + " failed "+ str(stats_iter-data[example_index,classifier_index])+" time(s), labelled "+str(labels[example_index])
                        for classifier_index in range(data.shape[1])]
                       for example_index in range(data.shape[0]) ]
         fig = plotly.graph_objs.Figure()
diff --git a/multiview_platform/mono_multi_view_classifiers/utils/execution.py b/multiview_platform/mono_multi_view_classifiers/utils/execution.py
index 649dfd29..45ef0a69 100644
--- a/multiview_platform/mono_multi_view_classifiers/utils/execution.py
+++ b/multiview_platform/mono_multi_view_classifiers/utils/execution.py
@@ -230,10 +230,14 @@ def gen_k_folds(stats_iter, nb_folds, stats_iter_random_states):
         for random_state in stats_iter_random_states:
             folds_list.append(
                 sklearn.model_selection.StratifiedKFold(n_splits=nb_folds,
-                                                        random_state=random_state))
+                                                        random_state=random_state,
+                                                        shuffle=True))
     else:
+        if isinstance(stats_iter_random_states, list):
+            stats_iter_random_states = stats_iter_random_states[0]
         folds_list = [sklearn.model_selection.StratifiedKFold(n_splits=nb_folds,
-                                                             random_state=stats_iter_random_states)]
+                                                             random_state=stats_iter_random_states,
+                                                              shuffle=True)]
     return folds_list
 
 
-- 
GitLab