Skip to content
Snippets Groups Projects
Commit 46bd637c authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

baseboosting estim in mumbo

parent a5e80f70
No related branches found
No related tags found
No related merge requests found
...@@ -151,7 +151,7 @@ class LateFusionClassifier(BaseMultiviewClassifier, BaseFusionClassifier): ...@@ -151,7 +151,7 @@ class LateFusionClassifier(BaseMultiviewClassifier, BaseFusionClassifier):
{classifier_name: self.classifier_configs[classifier_name]} for {classifier_name: self.classifier_configs[classifier_name]} for
classifier_name in self.classifiers_names] classifier_name in self.classifiers_names]
elif self.classifier_configs is None: elif self.classifier_configs is None:
self.classifier_configs = [None for _ in range(nb_clfs)] self.classifier_configs = [{} for _ in range(nb_clfs)]
# def verif_clf_views(self, classifier_names, nb_view): # def verif_clf_views(self, classifier_names, nb_view):
# if classifier_names is None: # if classifier_names is None:
......
...@@ -5,6 +5,7 @@ from multimodal.boosting.mumbo import MumboClassifier ...@@ -5,6 +5,7 @@ from multimodal.boosting.mumbo import MumboClassifier
from ..multiview.multiview_utils import BaseMultiviewClassifier from ..multiview.multiview_utils import BaseMultiviewClassifier
from ..utils.hyper_parameter_search import CustomRandint from ..utils.hyper_parameter_search import CustomRandint
from ..utils.dataset import get_examples_views_indices from ..utils.dataset import get_examples_views_indices
from ..utils.base import base_boosting_estimators
classifier_class_name = "Mumbo" classifier_class_name = "Mumbo"
...@@ -20,10 +21,7 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier): ...@@ -20,10 +21,7 @@ class Mumbo(BaseMultiviewClassifier, MumboClassifier):
random_state=random_state, random_state=random_state,
best_view_mode=best_view_mode) best_view_mode=best_view_mode)
self.param_names = ["base_estimator", "n_estimators", "random_state", "best_view_mode"] self.param_names = ["base_estimator", "n_estimators", "random_state", "best_view_mode"]
self.distribs = [[DecisionTreeClassifier(max_depth=1), self.distribs = [base_boosting_estimators,
DecisionTreeClassifier(max_depth=2),
DecisionTreeClassifier(max_depth=3),
DecisionTreeClassifier(max_depth=4)],
CustomRandint(5,200), [random_state], ["edge", "error"]] CustomRandint(5,200), [random_state], ["edge", "error"]]
def fit(self, X, y, train_indices=None, view_indices=None): def fit(self, X, y, train_indices=None, view_indices=None):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment