From 0714d5d520b4b6950338d41fcc6b9d2f5f65a6f6 Mon Sep 17 00:00:00 2001 From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr> Date: Wed, 15 Jan 2020 14:01:46 +0100 Subject: [PATCH] Didi some debug --- ...ltiview_utils.py => jumbo_fusion_utils.py} | 5 ++--- .../multiview_classifiers/svm_jumbo_fusion.py | 6 ++--- .../test_additions/test_jumbo_fusion_utils.py | 22 +++++++++++++++++++ 3 files changed, 27 insertions(+), 6 deletions(-) rename multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/{jumbo_multiview_utils.py => jumbo_fusion_utils.py} (95%) create mode 100644 multiview_platform/tests/test_multiview_classifiers/test_additions/test_jumbo_fusion_utils.py diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/jumbo_multiview_utils.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/jumbo_fusion_utils.py similarity index 95% rename from multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/jumbo_multiview_utils.py rename to multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/jumbo_fusion_utils.py index bceee169..b20804bd 100644 --- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/jumbo_multiview_utils.py +++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/jumbo_fusion_utils.py @@ -16,8 +16,8 @@ class BaseJumboFusion(LateFusionClassifier): self.distribs += [CustomRandint(1,10)] self.nb_monoview_per_view = nb_monoview_per_view - def set_params(self, **params): - self.nb_monoview_per_view = params["nb_monoview_per_view"] + def set_params(self, nb_monoview_per_view=1, **params): + self.nb_monoview_per_view = nb_monoview_per_view super(BaseJumboFusion, self).set_params(**params) def predict(self, X, example_indices=None, view_indices=None): @@ -28,7 +28,6 @@ class BaseJumboFusion(LateFusionClassifier): def fit(self, X, y, train_indices=None, view_indices=None): train_indices, view_indices = get_examples_views_indices(X, train_indices, view_indices) self.init_classifiers(len(view_indices), nb_monoview_per_view=self.nb_monoview_per_view) - print(self.classifiers_names, self.nb_monoview_per_view) self.fit_monoview_estimators(X, y, train_indices=train_indices, view_indices=view_indices) monoview_decisions = self.predict_monoview(X, example_indices=train_indices, view_indices=view_indices) self.aggregation_estimator.fit(monoview_decisions, y[train_indices]) diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/svm_jumbo_fusion.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/svm_jumbo_fusion.py index f88c344d..3c0b9c95 100644 --- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/svm_jumbo_fusion.py +++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/svm_jumbo_fusion.py @@ -1,6 +1,6 @@ from sklearn.svm import SVC -from .additions.jumbo_multiview_utils import BaseJumboFusion +from .additions.jumbo_fusion_utils import BaseJumboFusion from ..monoview.monoview_utils import CustomUniform, CustomRandint classifier_class_name = "SVMJumboFusion" @@ -17,9 +17,9 @@ class SVMJumboFusion(BaseJumboFusion): self.distribs += [CustomUniform(), ["rbf", "poly", "linear"], CustomRandint(2, 5)] self.aggregation_estimator = SVC(C=C, kernel=kernel, degree=degree) - def set_params(self, **params): + def set_params(self, C=1.0, kernel="rbf", degree=1, **params): super(SVMJumboFusion, self).set_params(**params) - self.aggregation_estimator.set_params(**dict((key, value) for key, value in params.items() if key in ["C", "kernel", "degree"])) + self.aggregation_estimator.set_params(C=C, kernel=kernel, degree=degree) return self diff --git a/multiview_platform/tests/test_multiview_classifiers/test_additions/test_jumbo_fusion_utils.py b/multiview_platform/tests/test_multiview_classifiers/test_additions/test_jumbo_fusion_utils.py new file mode 100644 index 00000000..9e242ed8 --- /dev/null +++ b/multiview_platform/tests/test_multiview_classifiers/test_additions/test_jumbo_fusion_utils.py @@ -0,0 +1,22 @@ +import unittest +import numpy as np + +import multiview_platform.mono_multi_view_classifiers.multiview_classifiers.additions.jumbo_fusion_utils as ju + + +class FakeDataset(): + + def __init__(self, views, labels): + self.nb_views = views.shape[0] + self.dataset_length = views.shape[2] + self.views = views + self.labels = labels + + def get_v(self, view_index, example_indices): + return self.views[view_index, example_indices] + + def get_nb_class(self, example_indices): + return np.unique(self.labels[example_indices]) + + +#TODO \ No newline at end of file -- GitLab