From 324236d88db554f278eefb8454eace915134a7b5 Mon Sep 17 00:00:00 2001 From: Dominique <dominique.benielli@univ-amu.fr> Date: Wed, 26 Feb 2025 16:11:33 +0100 Subject: [PATCH] Update imbalance_bagging.py --- .../monoview_classifiers/imbalance_bagging.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py b/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py index c4340420..d33f8a80 100644 --- a/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py +++ b/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py @@ -9,18 +9,18 @@ classifier_class_name = "ImbalanceBagging" class ImbalanceBagging(BaseMonoviewClassifier, BalancedBaggingClassifier): - def __init__(self, random_state=None, base_estimator="DecisionTreeClassifier", + def __init__(self, random_state=None, estimator="DecisionTreeClassifier", n_estimators=10, sampling_strategy="auto", replacement=False, base_estimator_config=None): - base_estimator = self.get_base_estimator(base_estimator, + estimator = self.get_base_estimator(estimator, base_estimator_config) - super(ImbalanceBagging, self).__init__(random_state=random_state, base_estimator=base_estimator, + super(ImbalanceBagging, self).__init__(random_state=random_state, estimator=estimator, n_estimators=n_estimators, sampling_strategy=sampling_strategy, replacement=replacement) - self.param_names = ["n_estimators", "base_estimator", "sampling_strategy",] - self.classed_params = ["base_estimator"] + self.param_names = ["n_estimators", "estimator", "sampling_strategy",] + self.classed_params = ["estimator"] self.distribs = [CustomRandint(low=1, high=50), base_boosting_estimators, ["auto"]] -- GitLab