diff --git a/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/Adaboost.py b/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/Adaboost.py index 7caafade8b8576591a1ef20f970c491c3e4af7b7..50e61763e0cc3c7c5174509a23b2336eb498a07f 100644 --- a/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/Adaboost.py +++ b/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/Adaboost.py @@ -17,6 +17,62 @@ from ..utils.Interpret import getFeatureImportance __author__ = "Baptiste Bauvin" __status__ = "Prototype" # Production, Development, Prototype +class Adaboost(AdaBoostClassifier): + + def __init__(self, random_state, **kwargs): + super(AdaBoostClassifier, self).__init__( + n_estimators=kwargs['n_estimators'], + base_estimator=kwargs['base_estimator'], + random_state=random_state) + + def canProbas(self): + """Used to know if the classifier can return label probabilities""" + return True + + def paramsToSrt(self, nIter=1): + """Used for weighted linear early fusion to generate random search sets""" + paramsSet = [] + for _ in range(nIter): + paramsSet.append({"n_estimators": self.random_state.randint(1, 150), + "base_estimator": DecisionTreeClassifier()}) + return paramsSet + + def getKWARGS(self, args): + """Used to format kwargs for the parsed args""" + kwargsDict = {} + kwargsDict['n_estimators'] = args.Ada_n_est + kwargsDict['base_estimator'] = DecisionTreeClassifier() # args.Ada_b_est + return kwargsDict + + def genPipeline(self): + return Pipeline([('classifier', AdaBoostClassifier())]) + + def genParamsDict(self, randomState): + return {"classifier__n_estimators": np.arange(150) + 1, + "classifier__base_estimator": [DecisionTreeClassifier()]} + + def genBestParams(self, detector): + return {"n_estimators": detector.best_params_["classifier__n_estimators"], + "base_estimator": detector.best_params_["classifier__base_estimator"]} + + def genParamsFromDetector(self, detector): + nIter = len(detector.cv_results_['param_classifier__n_estimators']) + return [("baseEstimators", np.array(["DecisionTree" for _ in range(nIter)])), + ("nEstimators", np.array(detector.cv_results_['param_classifier__n_estimators']))] + + def getConfig(self, config): + if type(config) is not dict: # Used in late fusion when config is a classifier + return "\n\t\t- Adaboost with num_esimators : " + str(config.n_estimators) + ", base_estimators : " + str( + config.base_estimator) + else: + return "\n\t\t- Adaboost with n_estimators : " + str(config["n_estimators"]) + ", base_estimator : " + str( + config["base_estimator"]) + + + def getInterpret(self, classifier, directory): + interpretString = getFeatureImportance(classifier, directory) + return interpretString + def canProbas(): return True