Skip to content
Snippets Groups Projects
Commit 81ea004a authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Mumbo dict base param

parent b88f3845
No related branches found
No related tags found
No related merge requests found
from sklearn.tree import DecisionTreeClassifier
from sklearn.base import BaseEstimator
import numpy as np
from multimodal.boosting.mumbo import MumboClassifier
......@@ -17,6 +18,7 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier):
random_state=None,
best_view_mode="edge"):
BaseMultiviewClassifier.__init__(self, random_state)
base_estimator = self.set_base_estim_from_dict(base_estimator)
MumboClassifier.__init__(self, base_estimator=base_estimator,
n_estimators=n_estimators,
random_state=random_state,
......@@ -25,6 +27,23 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier):
self.distribs = [base_boosting_estimators,
CustomRandint(5,200), [random_state], ["edge", "error"]]
def set_base_estim_from_dict(self, base_estim_dict):
if base_estim_dict is None:
base_estimator = DecisionTreeClassifier()
elif isinstance(base_estim_dict, dict):
estim_name = next(iter(base_estim_dict))
estim_module = getattr(monoview_classifiers, estim_name)
estim_class = getattr(estim_module,
estim_module.classifier_class_name)
base_estimator = estim_class(**base_estim_dict[estim_name])
elif isinstance(base_estim_dict, BaseEstimator):
base_estimator = base_estim_dict
else:
raise ValueError("base_estimator should be either None, a dictionary"
" or a BaseEstimator child object, "
"here it is {}".format(type(base_estim_dict)))
return base_estimator
def set_params(self, base_estimator=None, **params):
"""
Sets the base estimator from a dict.
......@@ -35,11 +54,7 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier):
if base_estimator is None:
self.base_estimator = DecisionTreeClassifier()
elif isinstance(base_estimator, dict):
estim_name = next(iter(base_estimator))
estim_module = getattr(monoview_classifiers, estim_name)
estim_class = getattr(estim_module,
estim_module.classifier_class_name)
self.base_estimator = estim_class(**base_estimator[estim_name])
self.base_estimator = self.set_base_estim_from_dict(base_estimator)
MumboClassifier.set_params(self, **params)
else:
MumboClassifier.set_params(self, base_estimator=base_estimator, **params)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment