diff --git a/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py b/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py index c434042044e96cd9cb5947b70aee3c4ace77647f..d33f8a809b508be6c5a13c5b3819c0ff80deb9f8 100644 --- a/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py +++ b/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py @@ -9,18 +9,18 @@ classifier_class_name = "ImbalanceBagging" class ImbalanceBagging(BaseMonoviewClassifier, BalancedBaggingClassifier): - def __init__(self, random_state=None, base_estimator="DecisionTreeClassifier", + def __init__(self, random_state=None, estimator="DecisionTreeClassifier", n_estimators=10, sampling_strategy="auto", replacement=False, base_estimator_config=None): - base_estimator = self.get_base_estimator(base_estimator, + estimator = self.get_base_estimator(estimator, base_estimator_config) - super(ImbalanceBagging, self).__init__(random_state=random_state, base_estimator=base_estimator, + super(ImbalanceBagging, self).__init__(random_state=random_state, estimator=estimator, n_estimators=n_estimators, sampling_strategy=sampling_strategy, replacement=replacement) - self.param_names = ["n_estimators", "base_estimator", "sampling_strategy",] - self.classed_params = ["base_estimator"] + self.param_names = ["n_estimators", "estimator", "sampling_strategy",] + self.classed_params = ["estimator"] self.distribs = [CustomRandint(low=1, high=50), base_boosting_estimators, ["auto"]]