diff --git a/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/Adaboost.py b/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/Adaboost.py index 9980f721d6226110f532525e6fe78a258f4ac107..05f6545b493f79484e05d43d58c9b58bad793943 100644 --- a/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/Adaboost.py +++ b/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/Adaboost.py @@ -10,8 +10,8 @@ __status__ = "Prototype" # Production, Development, Prototype class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier): - def __init__(self, random_state, n_estimators=10, - base_estimator=DecisionTreeClassifier(), **kwargs): + def __init__(self, random_state=None, n_estimators=10, + base_estimator=None, **kwargs): super(Adaboost, self).__init__( random_state=random_state, n_estimators=n_estimators, @@ -19,7 +19,7 @@ class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier): ) self.param_names = ["n_estimators", "base_estimator"] self.classed_params = ["base_estimator"] - self.distribs = [CustomRandint(low=1, high=500), [DecisionTreeClassifier()]] + self.distribs = [CustomRandint(low=1, high=500), [None]] self.weird_strings = {"base_estimator":"class_name"} def canProbas(self): @@ -31,7 +31,7 @@ class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier): paramsSet = [] for _ in range(nIter): paramsSet.append({"n_estimators": self.random_state.randint(1, 150), - "base_estimator": DecisionTreeClassifier()}) + "base_estimator": None}) return paramsSet def getInterpret(self, directory):