Skip to content
Snippets Groups Projects
Commit e17f19b0 authored by Kossi Kossivi
Browse files

Fix initialization bug (related to base_estimator attribute) in Mumbo and MuCombo wrappers

Also apply minor corrections to feature_importance and Imbalance Bagging.
parent c4a9fe53
No related branches found
No related tags found
No related merge requests found
......@@ -82,7 +82,7 @@ class ScmBagging(RandomScmClassifier, BaseMonoviewClassifier):
max_samples=max_samples,
max_features=max_features,
max_rules=max_rules,
p_options=p_options,
p=p_options,
model_type=model_type,
random_state=random_state)
self.param_names = ["n_estimators", "max_rules", "max_samples", "max_features", "model_type", "p_options", "random_state"]
......@@ -94,7 +94,7 @@ class ScmBagging(RandomScmClassifier, BaseMonoviewClassifier):
def set_params(self, p_options=[0.316], **kwargs):
if not isinstance(p_options, list):
p_options = [p_options]
kwargs["p_options"] = p_options
kwargs["p"] = p_options
for parameter, value in iteritems(kwargs):
setattr(self, parameter, value)
return self
......
......@@ -7,20 +7,20 @@ from ..utils.hyper_parameter_search import CustomRandint
from ..utils.dataset import get_samples_views_indices
from ..utils.base import base_boosting_estimators
classifier_class_name = "MuCumbo"
classifier_class_name = "MuCombo"
class MuCombo(BaseMultiviewClassifier, MuComboClassifier):
def __init__(self, estimator=None,
def __init__(self, base_estimator=None,
n_estimators=50,
random_state=None,**kwargs):
BaseMultiviewClassifier.__init__(self, random_state)
estimator = self.set_base_estim_from_dict(estimator, **kwargs)
MuComboClassifier.__init__(self, estimator=estimator,
base_estimator = self.set_base_estim_from_dict(base_estimator, **kwargs)
MuComboClassifier.__init__(self, base_estimator=base_estimator,
n_estimators=n_estimators,
random_state=random_state,)
self.param_names = ["estimator", "n_estimators", "random_state",]
self.param_names = ["base_estimator", "n_estimators", "random_state",]
self.distribs = [base_boosting_estimators,
CustomRandint(5,200), [random_state],]
......@@ -43,6 +43,12 @@ class MuCombo(BaseMultiviewClassifier, MuComboClassifier):
view_indices=view_indices)
return MuComboClassifier.predict(self, numpy_X)
def get_interpretation(self, directory, base_file_name, labels,
multiclass=False):
def get_interpretation(self, directory, base_file_name, y_test, feature_ids,
multi_class=False):
return ""
def set_base_estim_from_dict(self, dict, **kwargs):
    """Build the base estimator described by a one-entry configuration mapping.

    Parameters
    ----------
    dict : mapping, estimator or None
        Expected form is ``{"decision_tree": {<constructor kwargs>}}``; only
        the first entry is read. The parameter keeps its historical name
        (which shadows the builtin) so existing callers are unaffected.
    **kwargs
        Accepted and ignored. The constructor forwards its own ``**kwargs``
        to this method, which would otherwise raise ``TypeError``.

    Returns
    -------
    A ``DecisionTreeClassifier`` when the first key is ``"decision_tree"``;
    ``None`` when ``dict`` is ``None``, empty, or names an unsupported key;
    the input itself when it is not a mapping.
    """
    # ``base_estimator=None`` is the constructor default, so it must not
    # crash here. NOTE(review): presumably the wrapped classifier then falls
    # back to its own default estimator -- confirm.
    if dict is None:
        return None

    # Duck-typed mapping check: the builtin ``dict`` is shadowed by the
    # parameter name, so ``isinstance(dict, dict)`` is not usable here.
    # NOTE(review): non-mapping values are assumed to be ready-made
    # estimators and are handed back unchanged -- confirm against callers.
    if not hasattr(dict, "items"):
        return dict

    entries = list(dict.items())
    if not entries:
        return None

    key, args = entries[0]
    if key == "decision_tree":
        # A key with no arguments (``None``) means "use the defaults".
        return DecisionTreeClassifier(**(args or {}))

    # Unsupported estimator names yield None, as in the original code.
    return None
\ No newline at end of file
......@@ -13,19 +13,20 @@ from .. import monoview_classifiers
classifier_class_name = "Mumbo"
class Mumbo(BaseMultiviewClassifier, MumboClassifier):
def __init__(self, estimator=None,
def __init__(self, base_estimator=None,
n_estimators=50,
random_state=None,
best_view_mode="edge", **kwargs):
BaseMultiviewClassifier.__init__(self, random_state)
base_estimator = self.set_base_estim_from_dict(estimator, **kwargs)
MumboClassifier.__init__(self, base_estimator=estimator,
base_estimator = self.set_base_estim_from_dict(base_estimator)
MumboClassifier.__init__(self, base_estimator=base_estimator,
n_estimators=n_estimators,
random_state=random_state,
best_view_mode=best_view_mode)
self.param_names = ["estimator", "n_estimators", "random_state", "best_view_mode"]
self.param_names = ["base_estimator", "n_estimators", "random_state", "best_view_mode"]
self.distribs = [base_boosting_estimators,
CustomRandint(5, 200), [random_state], ["edge", "error"]]
......@@ -42,8 +43,7 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier):
self.base_estimator = self.set_base_estim_from_dict(estimator)
MumboClassifier.set_params(self, **params)
else:
MumboClassifier.set_params(self, estimator=estimator, **params)
MumboClassifier.set_params(self, base_estimator=estimator, **params)
def fit(self, X, y, train_indices=None, view_indices=None):
train_indices, view_indices = get_samples_views_indices(X,
......@@ -69,7 +69,8 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier):
view_indices=view_indices)
return MumboClassifier.predict(self, numpy_X)
def get_interpretation(self, directory, base_file_name, labels, multiclass=False):
def get_interpretation(self, directory, base_file_name, y_test, feature_ids,
multi_class=False):
self.view_importances = np.zeros(len(self.used_views))
self.feature_importances_ = [np.zeros(view_shape)
for view_shape in self.view_shapes]
......@@ -101,5 +102,11 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier):
self.view_names[view_index],
self.view_importances[view_index])
interpret_string += "\n The boosting process selected views : \n" + ", ".join(map(str, self.best_views_))
interpret_string+="\n\n With estimator weights : \n"+ "\n".join(map(str,self.estimator_weights_/np.sum(self.estimator_weights_)))
interpret_string += "\n\n With estimator weights : \n" + "\n".join(
map(str, self.estimator_weights_ / np.sum(self.estimator_weights_)))
return interpret_string
def set_base_estim_from_dict(self, dict, **kwargs):
    """Turn a one-entry estimator description into an estimator object.

    Parameters
    ----------
    dict : mapping, estimator or None
        Normally ``{"decision_tree": {<constructor kwargs>}}``. Only the
        first item of the mapping is consulted. The name shadows the builtin
        but is kept so that existing call sites stay valid.
    **kwargs
        Ignored; accepted so that a caller forwarding its own ``**kwargs``
        does not trigger a ``TypeError``.

    Returns
    -------
    ``DecisionTreeClassifier(**args)`` for the ``"decision_tree"`` key;
    ``None`` for ``None``, an empty mapping, or an unsupported key;
    the argument itself when it is not a mapping.
    """
    # The constructor declares ``base_estimator=None`` as its default, so
    # None has to be tolerated instead of failing on ``None.items()``.
    # NOTE(review): presumably the boosting classifier then picks its own
    # default estimator -- confirm.
    if dict is None:
        return None

    # ``isinstance(dict, dict)`` is impossible because the parameter hides
    # the builtin, hence the duck-typed test.
    # NOTE(review): anything without ``.items`` is assumed to be an
    # already-built estimator and is returned untouched -- confirm.
    if not hasattr(dict, "items"):
        return dict

    entries = list(dict.items())
    if not entries:
        return None

    key, args = entries[0]
    if key == "decision_tree":
        # ``None`` arguments are treated as "no arguments".
        return DecisionTreeClassifier(**(args or {}))

    # Same outcome as the original implicit fall-through.
    return None
......@@ -44,7 +44,7 @@ def get_feature_importances(result, feature_ids=None, view_names=None,):
v_feature_id]
feature_importances["mv"] = pd.DataFrame(index=feat_ids)
if hasattr(classifier_result.clf, 'feature_importances_'):
feature_importances["mv"][classifier_result.classifier_name] = classifier_result.clf.feature_importances_
feature_importances["mv"][classifier_result.classifier_name] = np.concatenate(classifier_result.clf.feature_importances_)
return feature_importances
......
......@@ -106,9 +106,9 @@ class Dataset():
return concat_views, view_limits
def select_labels(self, selected_label_names):
selected_labels = [self.get_label_names(decode=True).index(label_name.decode())
selected_labels = [self.get_label_names().index(label_name.decode())
if isinstance(label_name, bytes)
else self.get_label_names(decode=True).index(label_name)
else self.get_label_names().index(label_name)
for label_name in selected_label_names]
selected_indices = np.array([index
for index, label in
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment