From f7d0a4a0663433557593948572b6a9acbd81d8d2 Mon Sep 17 00:00:00 2001 From: bbauvin <baptiste.bauvin@centrale-marseille.fr> Date: Wed, 4 Oct 2017 16:24:56 -0400 Subject: [PATCH] Updated Decision trees agruments and modified gridsearch name --- Code/MonoMutliViewClassifiers/ExecClassif.py | 5 +++++ .../MonoviewClassifiers/DecisionTree.py | 22 ++++++++++++++----- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/Code/MonoMutliViewClassifiers/ExecClassif.py b/Code/MonoMutliViewClassifiers/ExecClassif.py index aefc98d2..065b9a3d 100644 --- a/Code/MonoMutliViewClassifiers/ExecClassif.py +++ b/Code/MonoMutliViewClassifiers/ExecClassif.py @@ -471,6 +471,11 @@ groupAdaboost.add_argument('--CL_Adaboost_b_est', metavar='STRING', action='stor groupDT = parser.add_argument_group('Decision Trees arguments') groupDT.add_argument('--CL_DecisionTree_depth', metavar='INT', type=int, action='store', help='Determine max depth for Decision Trees', default=3) +groupDT.add_argument('--CL_DecisionTree_criterion', metavar='STRING', action='store', + help='Determine max depth for Decision Trees', default="entropy") +groupDT.add_argument('--CL_DecisionTree_splitter', metavar='STRING', action='store', + help='Determine max depth for Decision Trees', default="random") + groupSGD = parser.add_argument_group('SGD arguments') groupSGD.add_argument('--CL_SGD_alpha', metavar='FLOAT', type=float, action='store', diff --git a/Code/MonoMutliViewClassifiers/MonoviewClassifiers/DecisionTree.py b/Code/MonoMutliViewClassifiers/MonoviewClassifiers/DecisionTree.py index 3a13c587..e3c92498 100644 --- a/Code/MonoMutliViewClassifiers/MonoviewClassifiers/DecisionTree.py +++ b/Code/MonoMutliViewClassifiers/MonoviewClassifiers/DecisionTree.py @@ -13,9 +13,12 @@ __status__ = "Prototype" # Production, Development, P def canProbas(): return True + def fit(DATASET, CLASS_LABELS, NB_CORES=1, **kwargs): maxDepth = int(kwargs['0']) - classifier = DecisionTreeClassifier(max_depth=maxDepth) + criterion = kwargs['1'] + splitter = kwargs['2'] + classifier = DecisionTreeClassifier(max_depth=maxDepth, criterion=criterion, splitter=splitter) classifier.fit(DATASET, CLASS_LABELS) return classifier @@ -25,12 +28,18 @@ def getKWARGS(kwargsList): for (kwargName, kwargValue) in kwargsList: if kwargName == "CL_DecisionTree_depth": kwargsDict['0'] = int(kwargValue) + if kwargName == "CL_DecisionTree_criterion": + kwargsDict['1'] = kwargValue + if kwargName == "CL_DecisionTree_splitter": + kwargsDict['2'] = kwargValue return kwargsDict -def gridSearch(X_train, y_train, nbFolds=4, nbCores=1, metric=["accuracy_score", None], nIter=30): +def randomizedSearch(X_train, y_train, nbFolds=4, nbCores=1, metric=["accuracy_score", None], nIter=30): pipeline_DT = Pipeline([('classifier', DecisionTreeClassifier())]) - param_DT = {"classifier__max_depth": randint(1, 30)} + param_DT = {"classifier__max_depth": randint(1, 30), + "classifier__criterion": ["gini", "entropy"], + "classifier__splitter": ["best", "random"]} metricModule = getattr(Metrics, metric[0]) if metric[1]!=None: metricKWARGS = dict((index, metricConfig) for index, metricConfig in enumerate(metric[1])) @@ -40,12 +49,13 @@ def gridSearch(X_train, y_train, nbFolds=4, nbCores=1, metric=["accuracy_score", grid_DT = RandomizedSearchCV(pipeline_DT, n_iter=nIter, param_distributions=param_DT, refit=True, n_jobs=nbCores, scoring=scorer, cv=nbFolds) DT_detector = grid_DT.fit(X_train, y_train) - desc_params = [DT_detector.best_params_["classifier__max_depth"]] + desc_params = [DT_detector.best_params_["classifier__max_depth"], DT_detector.best_params_["classifier__criterion"], + DT_detector.best_params_["classifier__splitter"]] return desc_params def getConfig(config): try: - return "\n\t\t- Decision Tree with max_depth : "+str(config[0]) + return "\n\t\t- Decision Tree with max_depth : "+str(config[0]) + ", criterion : "+config[1]+", splitter : "+config[2] except: - return "\n\t\t- Decision Tree with max_depth : "+str(config["0"]) \ No newline at end of file + return "\n\t\t- Decision Tree with max_depth : "+str(config["0"]) + ", criterion : "+config["1"]+", splitter : "+config["2"] \ No newline at end of file -- GitLab