From 846084bbe287002b0a12ad20a374639e2e02ccf2 Mon Sep 17 00:00:00 2001 From: Florence <florence@Stomatopoda.local> Date: Fri, 25 Mar 2022 09:48:31 -0400 Subject: [PATCH] New imbalance_bagging.py --- .../monoview_classifiers/imbalance_bagging.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 summit/multiview_platform/monoview_classifiers/imbalance_bagging.py diff --git a/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py b/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py new file mode 100644 index 00000000..c4340420 --- /dev/null +++ b/summit/multiview_platform/monoview_classifiers/imbalance_bagging.py @@ -0,0 +1,31 @@ +from imblearn.ensemble import BalancedBaggingClassifier +from sklearn.tree import DecisionTreeClassifier + +from ..monoview.monoview_utils import BaseMonoviewClassifier +from ..utils.base import base_boosting_estimators +from ..utils.hyper_parameter_search import CustomRandint, CustomUniform + +classifier_class_name = "ImbalanceBagging" + +class ImbalanceBagging(BaseMonoviewClassifier, BalancedBaggingClassifier): + + def __init__(self, random_state=None, base_estimator="DecisionTreeClassifier", + n_estimators=10, sampling_strategy="auto", + replacement=False, base_estimator_config=None): + base_estimator = self.get_base_estimator(base_estimator, + base_estimator_config) + super(ImbalanceBagging, self).__init__(random_state=random_state, base_estimator=base_estimator, + n_estimators=n_estimators, + sampling_strategy=sampling_strategy, + replacement=replacement) + + self.param_names = ["n_estimators", "base_estimator", "sampling_strategy",] + self.classed_params = ["base_estimator"] + self.distribs = [CustomRandint(low=1, high=50), + base_boosting_estimators, + ["auto"]] + self.weird_strings = {"base_estimator": "class_name"} + self.base_estimator_config = base_estimator_config + + + -- GitLab