Skip to content
Snippets Groups Projects
Commit 61be49dc authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

private_algos test_digit

parent 266a05bb
No related branches found
No related tags found
No related merge requests found
......@@ -28,4 +28,4 @@ stats_iter: 10
metrics: ["accuracy_score", "f1_score"]
metric_princ: "accuracy_score"
hps_type: "randomized_search-equiv"
hps_iter: 30
\ No newline at end of file
hps_iter: 5
\ No newline at end of file
......@@ -2,14 +2,17 @@ from imblearn.ensemble import BalancedBaggingClassifier
from sklearn.tree import DecisionTreeClassifier
from ..monoview.monoview_utils import BaseMonoviewClassifier, CustomRandint, CustomUniform
from ..utils.base import base_boosting_estimators
classifier_class_name = "ImbalanceBagging"
class ImbalanceBagging(BaseMonoviewClassifier, BalancedBaggingClassifier):
def __init__(self, random_state=None, base_estimator=DecisionTreeClassifier(max_depth=1), n_estimators=10,
sampling_strategy="auto", replacement=False,):
def __init__(self, random_state=None, base_estimator="DecisionTreeClassifier",
n_estimators=10, sampling_strategy="auto",
replacement=False, base_estimator_config=None):
base_estimator = self.get_base_estimator(base_estimator,
base_estimator_config)
super(ImbalanceBagging, self).__init__(random_state=random_state, base_estimator=base_estimator,
n_estimators=n_estimators,
sampling_strategy=sampling_strategy,
......@@ -18,7 +21,7 @@ class ImbalanceBagging(BaseMonoviewClassifier, BalancedBaggingClassifier):
self.param_names = ["n_estimators", "base_estimator", "sampling_strategy",]
self.classed_params = ["base_estimator"]
self.distribs = [CustomRandint(low=1, high=50),
[DecisionTreeClassifier(max_depth=1)],
base_boosting_estimators,
["auto"]]
self.weird_strings = {"base_estimator": "class_name"}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment