diff --git a/summit/multiview_platform/monoview_classifiers/adaboost.py b/summit/multiview_platform/monoview_classifiers/adaboost.py index 6192c99c51a170dd9def5a21cad25c797bc0aae3..561277970060a8d2430031039ddede7e2c200673 100644 --- a/summit/multiview_platform/monoview_classifiers/adaboost.py +++ b/summit/multiview_platform/monoview_classifiers/adaboost.py @@ -22,10 +22,10 @@ class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier): """ - def __init__(self, random_state=None, n_estimators=50, - estimator=None, **kwargs): + def __init__(self, random_state=None, n_estimators=50, + estimator=None, estimator_config=None, **kwargs): base_estimator = BaseMonoviewClassifier.get_base_estimator(self, - estimator) + estimator, estimator_config) AdaBoostClassifier.__init__(self, random_state=random_state, n_estimators=n_estimators, @@ -40,6 +40,7 @@ class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier): self.plotted_metric = metrics.zero_one_loss self.plotted_metric_name = "zero_one_loss" self.step_predictions = None + self.estimator_config = estimator_config def fit(self, X, y, sample_weight=None): begin = time.time()