Skip to content
Snippets Groups Projects
Commit 0714d5d5 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Didi some debug

parent bf72ef35
Branches
Tags
No related merge requests found
...@@ -16,8 +16,8 @@ class BaseJumboFusion(LateFusionClassifier): ...@@ -16,8 +16,8 @@ class BaseJumboFusion(LateFusionClassifier):
self.distribs += [CustomRandint(1,10)] self.distribs += [CustomRandint(1,10)]
self.nb_monoview_per_view = nb_monoview_per_view self.nb_monoview_per_view = nb_monoview_per_view
def set_params(self, **params): def set_params(self, nb_monoview_per_view=1, **params):
self.nb_monoview_per_view = params["nb_monoview_per_view"] self.nb_monoview_per_view = nb_monoview_per_view
super(BaseJumboFusion, self).set_params(**params) super(BaseJumboFusion, self).set_params(**params)
def predict(self, X, example_indices=None, view_indices=None): def predict(self, X, example_indices=None, view_indices=None):
...@@ -28,7 +28,6 @@ class BaseJumboFusion(LateFusionClassifier): ...@@ -28,7 +28,6 @@ class BaseJumboFusion(LateFusionClassifier):
def fit(self, X, y, train_indices=None, view_indices=None): def fit(self, X, y, train_indices=None, view_indices=None):
train_indices, view_indices = get_examples_views_indices(X, train_indices, view_indices) train_indices, view_indices = get_examples_views_indices(X, train_indices, view_indices)
self.init_classifiers(len(view_indices), nb_monoview_per_view=self.nb_monoview_per_view) self.init_classifiers(len(view_indices), nb_monoview_per_view=self.nb_monoview_per_view)
print(self.classifiers_names, self.nb_monoview_per_view)
self.fit_monoview_estimators(X, y, train_indices=train_indices, view_indices=view_indices) self.fit_monoview_estimators(X, y, train_indices=train_indices, view_indices=view_indices)
monoview_decisions = self.predict_monoview(X, example_indices=train_indices, view_indices=view_indices) monoview_decisions = self.predict_monoview(X, example_indices=train_indices, view_indices=view_indices)
self.aggregation_estimator.fit(monoview_decisions, y[train_indices]) self.aggregation_estimator.fit(monoview_decisions, y[train_indices])
......
from sklearn.svm import SVC from sklearn.svm import SVC
from .additions.jumbo_multiview_utils import BaseJumboFusion from .additions.jumbo_fusion_utils import BaseJumboFusion
from ..monoview.monoview_utils import CustomUniform, CustomRandint from ..monoview.monoview_utils import CustomUniform, CustomRandint
classifier_class_name = "SVMJumboFusion" classifier_class_name = "SVMJumboFusion"
...@@ -17,9 +17,9 @@ class SVMJumboFusion(BaseJumboFusion): ...@@ -17,9 +17,9 @@ class SVMJumboFusion(BaseJumboFusion):
self.distribs += [CustomUniform(), ["rbf", "poly", "linear"], CustomRandint(2, 5)] self.distribs += [CustomUniform(), ["rbf", "poly", "linear"], CustomRandint(2, 5)]
self.aggregation_estimator = SVC(C=C, kernel=kernel, degree=degree) self.aggregation_estimator = SVC(C=C, kernel=kernel, degree=degree)
def set_params(self, **params): def set_params(self, C=1.0, kernel="rbf", degree=1, **params):
super(SVMJumboFusion, self).set_params(**params) super(SVMJumboFusion, self).set_params(**params)
self.aggregation_estimator.set_params(**dict((key, value) for key, value in params.items() if key in ["C", "kernel", "degree"])) self.aggregation_estimator.set_params(C=C, kernel=kernel, degree=degree)
return self return self
......
import unittest
import numpy as np
import multiview_platform.mono_multi_view_classifiers.multiview_classifiers.additions.jumbo_fusion_utils as ju
class FakeDataset():
def __init__(self, views, labels):
self.nb_views = views.shape[0]
self.dataset_length = views.shape[2]
self.views = views
self.labels = labels
def get_v(self, view_index, example_indices):
return self.views[view_index, example_indices]
def get_nb_class(self, example_indices):
return np.unique(self.labels[example_indices])
#TODO
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment