Skip to content
Snippets Groups Projects
Commit 54217ad7 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Propagated multiclass support in late fusion

parent 43456222
No related branches found
No related tags found
No related merge requests found
Pipeline #4117 passed
Showing
with 49 additions and 34 deletions
......@@ -22,7 +22,7 @@ nb_folds: 2
nb_class: 3
classes:
type: ["multiview",]
algos_monoview: ["all" ]
algos_monoview: ["decision_tree" ]
algos_multiview: ["svm_jumbo_fusion",]
stats_iter: 2
metrics: ["accuracy_score", "f1_score"]
......
......@@ -93,8 +93,8 @@ def exec_monoview(directory, X, Y, name, labels_names, classification_indices,
classifier = get_mc_estim(getattr(classifier_module,
classifier_class_name)
(random_state, **cl_kwargs),
Y,
random_state)
random_state,
y=Y)
classifier.fit(X_train, y_train) # NB_CORES=nbCores,
logging.debug("Done:\t Training")
......
......@@ -273,7 +273,8 @@ def exec_multiview(directory, dataset_var, name, classification_indices, k_folds
classifier_config=classifier_config)
classifier = get_mc_estim(getattr(classifier_module, classifier_name)(random_state=random_state,
**classifier_config),
dataset_var.get_labels(), random_state, multiview=True,)
random_state, multiview=True,
y=dataset_var.get_labels())
logging.debug("Done:\t Optimizing hyperparameters")
logging.debug("Start:\t Fitting classifier")
classifier.fit(dataset_var, dataset_var.get_labels(), train_indices=learning_indices,
......
......@@ -125,7 +125,7 @@ def get_available_monoview_classifiers(need_probas=False):
available_classifiers = proba_classifiers
return available_classifiers
def get_monoview_classifier(classifier_name):
def get_monoview_classifier(classifier_name, multiclass=False):
classifier_module = getattr(monoview_classifiers, classifier_name)
classifier_class = getattr(classifier_module, classifier_module.classifier_class_name)
return classifier_class
......
......@@ -32,6 +32,7 @@ class DiversityFusionClassifier(BaseMultiviewClassifier,
train_indices, view_indices = get_examples_views_indices(X,
train_indices,
view_indices)
# TODO : Finer analysis, may support a bit of mutliclass
if np.unique(y[train_indices]).shape[0] > 2:
raise ValueError("Multiclass not supported, classes used : {}".format(np.unique(y[train_indices])))
if self.monoview_estimators is None:
......
......@@ -2,12 +2,13 @@ import inspect
from ...multiview.multiview_utils import get_monoview_classifier
from ...utils.multiclass import get_mc_estim
class BaseFusionClassifier():
def init_monoview_estimator(self, classifier_name, classifier_config,
classifier_index=None,):
classifier_index=None, multiclass=False):
if classifier_index is not None:
if classifier_config is not None:
classifier_configs = classifier_config[classifier_name]
......@@ -31,4 +32,7 @@ class BaseFusionClassifier():
random_state=self.random_state)
else:
estimator = get_monoview_classifier(classifier_name)()
return estimator
return get_mc_estim(estimator, random_state=self.random_state,
multiview=False, multiclass=multiclass)
......@@ -35,8 +35,13 @@ class BaseJumboFusion(LateFusionClassifier):
return self
def fit_monoview_estimators(self, X, y, train_indices=None, view_indices=None):
if np.unique(y).shape[0]>2:
multiclass=True
else:
multiclass=False
self.monoview_estimators = [[self.init_monoview_estimator(classifier_name,
self.classifier_configs[classifier_index])
self.classifier_configs[classifier_index],
multiclass=multiclass)
for classifier_index, classifier_name
in enumerate(self.classifiers_names)]
for _ in view_indices]
......
......@@ -97,7 +97,11 @@ class LateFusionClassifier(BaseMultiviewClassifier, BaseFusionClassifier):
train_indices, view_indices = get_examples_views_indices(X,
train_indices,
view_indices)
self.init_params(len(view_indices))
if np.unique(y).shape[0]>2:
multiclass=True
else:
multiclass=False
self.init_params(len(view_indices), multiclass)
if np.unique(y[train_indices]).shape[0] > 2:
raise ValueError("Multiclass not supported")
self.monoview_estimators = [monoview_estimator.fit(X.get_v(view_index, train_indices),
......@@ -107,7 +111,7 @@ class LateFusionClassifier(BaseMultiviewClassifier, BaseFusionClassifier):
self.monoview_estimators)]
return self
def init_params(self, nb_view):
def init_params(self, nb_view, mutliclass=False):
if self.weights is None:
self.weights = np.ones(nb_view) / nb_view
elif isinstance(self.weights, WeightDistribution):
......@@ -120,7 +124,8 @@ class LateFusionClassifier(BaseMultiviewClassifier, BaseFusionClassifier):
self.monoview_estimators = [
self.init_monoview_estimator(classifier_name,
self.classifier_configs[classifier_index],
classifier_index=classifier_index)
classifier_index=classifier_index,
multiclass=mutliclass)
for classifier_index, classifier_name
in enumerate(self.classifiers_names)]
......
......@@ -50,15 +50,12 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier, BaseFusionClassifier):
self.classed_params = []
self.weird_strings={}
def set_params(self, monoview_classifier_name=None, monoview_classifier_config=None, **params):
def set_params(self, monoview_classifier_name=None,
monoview_classifier_config=None, **params):
self.monoview_classifier_name = monoview_classifier_name
monoview_classifier_module = getattr(monoview_classifiers,
self.monoview_classifier_name)
monoview_classifier_class = getattr(monoview_classifier_module,
monoview_classifier_module.classifier_class_name)
self.monoview_classifier = monoview_classifier_class()
self.init_monoview_estimator(monoview_classifier_name,
self.monoview_classifier = self.init_monoview_estimator(monoview_classifier_name,
monoview_classifier_config)
self.monoview_classifier_config = self.monoview_classifier.get_params()
self.short_name = "early fusion " + self.monoview_classifier_name
return self
......@@ -74,9 +71,9 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier, BaseFusionClassifier):
if np.unique(y[train_indices]).shape[0] > 2 and \
not(isinstance(self.monoview_classifier, MultiClassWrapper)):
self.monoview_classifier = get_mc_estim(self.monoview_classifier,
y[train_indices],
self.random_state,
multiview=False)
multiview=False,
y=y[train_indices])
self.monoview_classifier.fit(X, y[train_indices])
return self
......@@ -100,10 +97,10 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier, BaseFusionClassifier):
X = self.hdf5_to_monoview(dataset, example_indices)
return example_indices, X
def hdf5_to_monoview(self, dataset, exmaples):
def hdf5_to_monoview(self, dataset, examples):
"""Here, we concatenate the views for the asked examples """
monoview_data = np.concatenate(
[dataset.get_v(view_idx, exmaples)
[dataset.get_v(view_idx, examples)
for view_weight, (index, view_idx)
in zip(self.view_weights, enumerate(self.view_indices))]
, axis=1)
......
......@@ -114,8 +114,9 @@ def randomized_search(X, y, framework, random_state, output_file_name, classifie
estimator = getattr(classifier_module, classifier_name)(random_state=random_state,
**classifier_kwargs)
params_dict = estimator.gen_distribs()
estimator = get_mc_estim(estimator, y, random_state,
multiview=(framework=="multiview"))
estimator = get_mc_estim(estimator, random_state,
multiview=(framework=="multiview"),
y=y)
if params_dict:
metric_module = getattr(metrics, metric[0])
if metric[1] is not None:
......
......@@ -110,7 +110,8 @@ from .dataset import get_examples_views_indices
# return False
def get_mc_estim(estimator, y, random_state, multiview=False):
def get_mc_estim(estimator, random_state, y=None, multiview=False,
multiclass=False):
r"""Used to get a multiclass-compatible estimator if the one in param does not natively support multiclass.
If perdict_proba is available in the asked estimator, a One Versus Rest wrapper is returned,
else, a One Versus One wrapper is returned.
......@@ -133,7 +134,7 @@ def get_mc_estim(estimator, y, random_state, multiview=False):
estimator : sklearn-like estimator
Either the aksed estimator, or a multiclass-compatible wrapper over the asked estimator
"""
if np.unique(y).shape[0]>2:
if (y is not None and np.unique(y).shape[0]>2) or multiclass :
if not clone(estimator).accepts_multi_class(random_state):
if hasattr(estimator, "predict_proba"):
if multiview:
......
......@@ -40,34 +40,34 @@ class Test_get_mc_estim(unittest.TestCase):
def test_biclass(self):
y = self.random_state.randint(0,2,10)
estimator="Test"
returned_estimator = get_mc_estim(estimator, y, self.random_state,)
returned_estimator = get_mc_estim(estimator, self.random_state, y=y)
self.assertEqual(returned_estimator, estimator)
def test_multiclass_native(self):
estimator = FakeEstimNative()
returned_estimator = get_mc_estim(estimator, self.y, self.random_state)
returned_estimator = get_mc_estim(estimator, self.random_state, y=self.y)
self.assertIsInstance(returned_estimator, FakeEstimNative)
def test_multiclass_ovo(self):
estimator = FakeNonProbaEstim()
returned_estimator = get_mc_estim(estimator, self.y, self.random_state)
returned_estimator = get_mc_estim(estimator, self.random_state, y=self.y)
self.assertIsInstance(returned_estimator, OVOWrapper)
def test_multiclass_ovr(self):
estimator = FakeProbaEstim()
returned_estimator = get_mc_estim(estimator, self.y, self.random_state)
returned_estimator = get_mc_estim(estimator, self.random_state, y=self.y)
self.assertIsInstance(returned_estimator, OVRWrapper)
def test_multiclass_ovo_multiview(self):
estimator = FakeNonProbaEstim()
returned_estimator = get_mc_estim(estimator, self.y, self.random_state,
multiview=True)
returned_estimator = get_mc_estim(estimator, self.random_state,
multiview=True, y=self.y, )
self.assertIsInstance(returned_estimator, MultiviewOVOWrapper)
def test_multiclass_ovr_multiview(self):
estimator = FakeProbaEstim()
returned_estimator = get_mc_estim(estimator, self.y, self.random_state,
multiview=True)
returned_estimator = get_mc_estim(estimator, self.random_state,
multiview=True, y=self.y,)
self.assertIsInstance(returned_estimator, MultiviewOVRWrapper)
class FakeMVClassifier(BaseEstimator):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment