From 3361b380a40ed491303c39644c7d51de4e8ea3b4 Mon Sep 17 00:00:00 2001 From: Dominique <dominique.benielli@univ-amu.fr> Date: Wed, 26 Feb 2025 18:10:38 +0100 Subject: [PATCH] Update adaboost.py --- summit/multiview_platform/monoview_classifiers/adaboost.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/summit/multiview_platform/monoview_classifiers/adaboost.py b/summit/multiview_platform/monoview_classifiers/adaboost.py index 6192c99c..56127797 100644 --- a/summit/multiview_platform/monoview_classifiers/adaboost.py +++ b/summit/multiview_platform/monoview_classifiers/adaboost.py @@ -22,10 +22,10 @@ class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier): """ - def __init__(self, random_state=None, n_estimators=50, - estimator=None, **kwargs): + def __init__(self, random_state=None, n_estimators=50, + estimator=None, estimator_config=None, **kwargs): base_estimator = BaseMonoviewClassifier.get_base_estimator(self, - estimator) + estimator, estimator_config) AdaBoostClassifier.__init__(self, random_state=random_state, n_estimators=n_estimators, @@ -40,6 +40,7 @@ class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier): self.plotted_metric = metrics.zero_one_loss self.plotted_metric_name = "zero_one_loss" self.step_predictions = None + self.estimator_config = estimator_config def fit(self, X, y, sample_weight=None): begin = time.time() -- GitLab