diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/jumbo_multiview_utils.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/jumbo_fusion_utils.py similarity index 95% rename from multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/jumbo_multiview_utils.py rename to multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/jumbo_fusion_utils.py index bceee169bf991f92b5a6ab505cac157ab2e9d561..b20804bd5d7be25238c872ec1e80d760e1b4a2d9 100644 --- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/jumbo_multiview_utils.py +++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/additions/jumbo_fusion_utils.py @@ -16,8 +16,8 @@ class BaseJumboFusion(LateFusionClassifier): self.distribs += [CustomRandint(1,10)] self.nb_monoview_per_view = nb_monoview_per_view - def set_params(self, **params): - self.nb_monoview_per_view = params["nb_monoview_per_view"] + def set_params(self, nb_monoview_per_view=1, **params): + self.nb_monoview_per_view = nb_monoview_per_view super(BaseJumboFusion, self).set_params(**params) def predict(self, X, example_indices=None, view_indices=None): @@ -28,7 +28,6 @@ class BaseJumboFusion(LateFusionClassifier): def fit(self, X, y, train_indices=None, view_indices=None): train_indices, view_indices = get_examples_views_indices(X, train_indices, view_indices) self.init_classifiers(len(view_indices), nb_monoview_per_view=self.nb_monoview_per_view) - print(self.classifiers_names, self.nb_monoview_per_view) self.fit_monoview_estimators(X, y, train_indices=train_indices, view_indices=view_indices) monoview_decisions = self.predict_monoview(X, example_indices=train_indices, view_indices=view_indices) self.aggregation_estimator.fit(monoview_decisions, y[train_indices]) diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/svm_jumbo_fusion.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/svm_jumbo_fusion.py index f88c344d095fa1da6a187c4eea6e787f1fb72c69..3c0b9c95db211e836a4b54c6bc0d1e1c6f6adfca 100644 --- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/svm_jumbo_fusion.py +++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/svm_jumbo_fusion.py @@ -1,6 +1,6 @@ from sklearn.svm import SVC -from .additions.jumbo_multiview_utils import BaseJumboFusion +from .additions.jumbo_fusion_utils import BaseJumboFusion from ..monoview.monoview_utils import CustomUniform, CustomRandint classifier_class_name = "SVMJumboFusion" @@ -17,9 +17,9 @@ class SVMJumboFusion(BaseJumboFusion): self.distribs += [CustomUniform(), ["rbf", "poly", "linear"], CustomRandint(2, 5)] self.aggregation_estimator = SVC(C=C, kernel=kernel, degree=degree) - def set_params(self, **params): + def set_params(self, C=1.0, kernel="rbf", degree=1, **params): super(SVMJumboFusion, self).set_params(**params) - self.aggregation_estimator.set_params(**dict((key, value) for key, value in params.items() if key in ["C", "kernel", "degree"])) + self.aggregation_estimator.set_params(C=C, kernel=kernel, degree=degree) return self diff --git a/multiview_platform/tests/test_multiview_classifiers/test_additions/test_jumbo_fusion_utils.py b/multiview_platform/tests/test_multiview_classifiers/test_additions/test_jumbo_fusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9e242ed89bd067148b0d4caa5da39f4057d04c26 --- /dev/null +++ b/multiview_platform/tests/test_multiview_classifiers/test_additions/test_jumbo_fusion_utils.py @@ -0,0 +1,22 @@ +import unittest +import numpy as np + +import multiview_platform.mono_multi_view_classifiers.multiview_classifiers.additions.jumbo_fusion_utils as ju + + +class FakeDataset(): + + def __init__(self, views, labels): + self.nb_views = views.shape[0] + self.dataset_length = views.shape[2] + self.views = views + self.labels = labels + + def get_v(self, view_index, example_indices): + return self.views[view_index, example_indices] + + def get_nb_class(self, example_indices): + return np.unique(self.labels[example_indices]) + + +#TODO \ No newline at end of file