From f8ca7c7c7fe403c9141f9a3c42178c26f6a6b6c9 Mon Sep 17 00:00:00 2001 From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr> Date: Thu, 23 Jan 2020 14:25:33 +0100 Subject: [PATCH] Sklearn compatible now ? --- .../monoview/monoview_utils.py | 3 +++ .../mono_multi_view_classifiers/result_analysis.py | 2 +- .../mono_multi_view_classifiers/utils/execution.py | 8 ++++++-- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/multiview_platform/mono_multi_view_classifiers/monoview/monoview_utils.py b/multiview_platform/mono_multi_view_classifiers/monoview/monoview_utils.py index eb721f25..f2ce0112 100644 --- a/multiview_platform/mono_multi_view_classifiers/monoview/monoview_utils.py +++ b/multiview_platform/mono_multi_view_classifiers/monoview/monoview_utils.py @@ -193,6 +193,9 @@ class BaseMonoviewClassifier(BaseEstimator, ):#ClassifierMixin): def get_name_for_fusion(self): return self.__class__.__name__[:4] + def getInterpret(self, directory, y_test): + return "" + def get_names(classed_list): return np.array([object_.__class__.__name__ for object_ in classed_list]) diff --git a/multiview_platform/mono_multi_view_classifiers/result_analysis.py b/multiview_platform/mono_multi_view_classifiers/result_analysis.py index 2cbdd02a..6ff0f76b 100644 --- a/multiview_platform/mono_multi_view_classifiers/result_analysis.py +++ b/multiview_platform/mono_multi_view_classifiers/result_analysis.py @@ -218,7 +218,7 @@ def plot_2d(data, classifiers_names, nbClassifiers, nbExamples, ### The following part is used to generate an interactive graph. if use_plotly: label_index_list = np.concatenate([np.where(labels==i)[0] for i in np.unique(labels)]) #[np.where(labels==i)[0] for i in np.unique(labels)] - hover_text = [[example_ids[example_index] + " failed "+ str(stats_iter-data[example_index,classifier_index])+" time(s), labelled "+str(example_index) + hover_text = [[example_ids[example_index] + " failed "+ str(stats_iter-data[example_index,classifier_index])+" time(s), labelled "+str(labels[example_index]) for classifier_index in range(data.shape[1])] for example_index in range(data.shape[0]) ] fig = plotly.graph_objs.Figure() diff --git a/multiview_platform/mono_multi_view_classifiers/utils/execution.py b/multiview_platform/mono_multi_view_classifiers/utils/execution.py index 649dfd29..45ef0a69 100644 --- a/multiview_platform/mono_multi_view_classifiers/utils/execution.py +++ b/multiview_platform/mono_multi_view_classifiers/utils/execution.py @@ -230,10 +230,14 @@ def gen_k_folds(stats_iter, nb_folds, stats_iter_random_states): for random_state in stats_iter_random_states: folds_list.append( sklearn.model_selection.StratifiedKFold(n_splits=nb_folds, - random_state=random_state)) + random_state=random_state, + shuffle=True)) else: + if isinstance(stats_iter_random_states, list): + stats_iter_random_states = stats_iter_random_states[0] folds_list = [sklearn.model_selection.StratifiedKFold(n_splits=nb_folds, - random_state=stats_iter_random_states)] + random_state=stats_iter_random_states, + shuffle=True)] return folds_list -- GitLab