Skip to content
Snippets Groups Projects
Commit c91f1192 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Update mucombo

parent 89287be0
No related branches found
No related tags found
No related merge requests found
from sklearn.tree import DecisionTreeClassifier from sklearn.tree import DecisionTreeClassifier
from multimodal.boosting.cumbo import MuCumboClassifier from multimodal.boosting.combo import MuComboClassifier
from ..multiview.multiview_utils import BaseMultiviewClassifier from ..multiview.multiview_utils import BaseMultiviewClassifier
from ..utils.hyper_parameter_search import CustomRandint from ..utils.hyper_parameter_search import CustomRandint
from ..utils.dataset import get_samples_views_indices from ..utils.dataset import get_samples_views_indices
...@@ -10,14 +10,14 @@ from ..utils.base import base_boosting_estimators ...@@ -10,14 +10,14 @@ from ..utils.base import base_boosting_estimators
classifier_class_name = "MuCumbo" classifier_class_name = "MuCumbo"
class MuCumbo(BaseMultiviewClassifier, MuCumboClassifier): class MuCumbo(BaseMultiviewClassifier, MuComboClassifier):
def __init__(self, base_estimator=None, def __init__(self, base_estimator=None,
n_estimators=50, n_estimators=50,
random_state=None,**kwargs): random_state=None,**kwargs):
BaseMultiviewClassifier.__init__(self, random_state) BaseMultiviewClassifier.__init__(self, random_state)
base_estimator = self.set_base_estim_from_dict(base_estimator, **kwargs) base_estimator = self.set_base_estim_from_dict(base_estimator, **kwargs)
MuCumboClassifier.__init__(self, base_estimator=base_estimator, MuComboClassifier.__init__(self, base_estimator=base_estimator,
n_estimators=n_estimators, n_estimators=n_estimators,
random_state=random_state,) random_state=random_state,)
self.param_names = ["base_estimator", "n_estimators", "random_state",] self.param_names = ["base_estimator", "n_estimators", "random_state",]
...@@ -31,7 +31,7 @@ class MuCumbo(BaseMultiviewClassifier, MuCumboClassifier): ...@@ -31,7 +31,7 @@ class MuCumbo(BaseMultiviewClassifier, MuCumboClassifier):
self.used_views = view_indices self.used_views = view_indices
numpy_X, view_limits = X.to_numpy_array(sample_indices=train_indices, numpy_X, view_limits = X.to_numpy_array(sample_indices=train_indices,
view_indices=view_indices) view_indices=view_indices)
return MuCumboClassifier.fit(self, numpy_X, y[train_indices], return MuComboClassifier.fit(self, numpy_X, y[train_indices],
view_limits) view_limits)
def predict(self, X, sample_indices=None, view_indices=None): def predict(self, X, sample_indices=None, view_indices=None):
...@@ -41,7 +41,7 @@ class MuCumbo(BaseMultiviewClassifier, MuCumboClassifier): ...@@ -41,7 +41,7 @@ class MuCumbo(BaseMultiviewClassifier, MuCumboClassifier):
self._check_views(view_indices) self._check_views(view_indices)
numpy_X, view_limits = X.to_numpy_array(sample_indices=sample_indices, numpy_X, view_limits = X.to_numpy_array(sample_indices=sample_indices,
view_indices=view_indices) view_indices=view_indices)
return MuCumboClassifier.predict(self, numpy_X) return MuComboClassifier.predict(self, numpy_X)
def get_interpretation(self, directory, base_file_name, labels, def get_interpretation(self, directory, base_file_name, labels,
multiclass=False): multiclass=False):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment