Skip to content
Snippets Groups Projects
Commit 19b17556 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Mumbo cbase classifier adaptation

parent 09700327
No related branches found
No related tags found
No related merge requests found
......@@ -6,6 +6,7 @@ from ..multiview.multiview_utils import BaseMultiviewClassifier
from ..utils.hyper_parameter_search import CustomRandint
from ..utils.dataset import get_examples_views_indices
from ..utils.base import base_boosting_estimators
from .. import monoview_classifiers
classifier_class_name = "Mumbo"
......@@ -24,6 +25,26 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier):
self.distribs = [base_boosting_estimators,
CustomRandint(5,200), [random_state], ["edge", "error"]]
def set_params(self, base_estimator=None, **params):
"""
Sets the base estimator from a dict.
:param base_estimator:
:param params:
:return:
"""
if base_estimator is None:
self.base_estimator = DecisionTreeClassifier()
elif isinstance(base_estimator, dict):
estim_name = next(iter(base_estimator))
estim_module = getattr(monoview_classifiers, estim_name)
estim_class = getattr(estim_module,
estim_module.classifier_class_name)
self.base_estimator = estim_class(**base_estimator[estim_name])
MumboClassifier.set_params(self, **params)
else:
MumboClassifier.set_params(self, base_estimator=base_estimator, **params)
def fit(self, X, y, train_indices=None, view_indices=None):
train_indices, view_indices = get_examples_views_indices(X,
train_indices,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment