Skip to content
Snippets Groups Projects
Commit f7d0a4a0 authored by bbauvin's avatar bbauvin
Browse files

Updated Decision Trees arguments and modified gridsearch name

parent ce891071
Branches
No related tags found
No related merge requests found
...@@ -471,6 +471,11 @@ groupAdaboost.add_argument('--CL_Adaboost_b_est', metavar='STRING', action='stor ...@@ -471,6 +471,11 @@ groupAdaboost.add_argument('--CL_Adaboost_b_est', metavar='STRING', action='stor
# Command-line options controlling the Decision Tree classifier.
# Defaults mirror the fallbacks used by the DecisionTree module's fit().
groupDT = parser.add_argument_group('Decision Trees arguments')
groupDT.add_argument('--CL_DecisionTree_depth', metavar='INT', type=int, action='store',
                     help='Determine max depth for Decision Trees', default=3)
# Fixed copy-pasted help strings: these two options do not set the max depth.
groupDT.add_argument('--CL_DecisionTree_criterion', metavar='STRING', action='store',
                     help='Determine criterion for Decision Trees', default="entropy")
groupDT.add_argument('--CL_DecisionTree_splitter', metavar='STRING', action='store',
                     help='Determine splitter for Decision Trees', default="random")
groupSGD = parser.add_argument_group('SGD arguments') groupSGD = parser.add_argument_group('SGD arguments')
groupSGD.add_argument('--CL_SGD_alpha', metavar='FLOAT', type=float, action='store', groupSGD.add_argument('--CL_SGD_alpha', metavar='FLOAT', type=float, action='store',
......
...@@ -13,9 +13,12 @@ __status__ = "Prototype" # Production, Development, P ...@@ -13,9 +13,12 @@ __status__ = "Prototype" # Production, Development, P
def canProbas():
    """Report that this classifier module supports probability estimates."""
    return True
def fit(DATASET, CLASS_LABELS, NB_CORES=1, **kwargs):
    """Train a DecisionTreeClassifier on the given data.

    Args:
        DATASET: feature matrix used for training.
        CLASS_LABELS: target labels aligned with DATASET rows.
        NB_CORES: unused here (DecisionTreeClassifier is single-core);
            presumably kept for interface parity with the other
            classifier modules — confirm against callers.
        **kwargs: positional-style hyperparameters — '0': max depth (int),
            '1': split criterion, '2': splitter strategy.

    Returns:
        The fitted DecisionTreeClassifier.
    """
    # Fall back to the same defaults declared by the command-line parser
    # ('--CL_DecisionTree_*') instead of raising KeyError when a
    # hyperparameter was not forwarded.
    maxDepth = int(kwargs.get('0', 3))
    criterion = kwargs.get('1', "entropy")
    splitter = kwargs.get('2', "random")
    classifier = DecisionTreeClassifier(max_depth=maxDepth, criterion=criterion, splitter=splitter)
    classifier.fit(DATASET, CLASS_LABELS)
    return classifier
def getKWARGS(kwargsList):
    """Convert (name, value) CLI pairs into the positional kwargs dict.

    Maps each recognized '--CL_DecisionTree_*' argument onto the string
    keys '0' (max depth, cast to int), '1' (criterion) and '2' (splitter)
    expected by fit(); unrecognized names are silently ignored.

    Args:
        kwargsList: iterable of (argument_name, value) pairs.

    Returns:
        dict mapping '0'/'1'/'2' to the parsed hyperparameter values.
    """
    kwargsDict = {}
    for (kwargName, kwargValue) in kwargsList:
        # elif chain: the argument names are mutually exclusive, so there
        # is no point re-testing the remaining names after a match.
        if kwargName == "CL_DecisionTree_depth":
            kwargsDict['0'] = int(kwargValue)
        elif kwargName == "CL_DecisionTree_criterion":
            kwargsDict['1'] = kwargValue
        elif kwargName == "CL_DecisionTree_splitter":
            kwargsDict['2'] = kwargValue
    return kwargsDict
def gridSearch(X_train, y_train, nbFolds=4, nbCores=1, metric=["accuracy_score", None], nIter=30): def randomizedSearch(X_train, y_train, nbFolds=4, nbCores=1, metric=["accuracy_score", None], nIter=30):
pipeline_DT = Pipeline([('classifier', DecisionTreeClassifier())]) pipeline_DT = Pipeline([('classifier', DecisionTreeClassifier())])
param_DT = {"classifier__max_depth": randint(1, 30)} param_DT = {"classifier__max_depth": randint(1, 30),
"classifier__criterion": ["gini", "entropy"],
"classifier__splitter": ["best", "random"]}
metricModule = getattr(Metrics, metric[0]) metricModule = getattr(Metrics, metric[0])
if metric[1]!=None: if metric[1]!=None:
metricKWARGS = dict((index, metricConfig) for index, metricConfig in enumerate(metric[1])) metricKWARGS = dict((index, metricConfig) for index, metricConfig in enumerate(metric[1]))
...@@ -40,12 +49,13 @@ def gridSearch(X_train, y_train, nbFolds=4, nbCores=1, metric=["accuracy_score", ...@@ -40,12 +49,13 @@ def gridSearch(X_train, y_train, nbFolds=4, nbCores=1, metric=["accuracy_score",
grid_DT = RandomizedSearchCV(pipeline_DT, n_iter=nIter, param_distributions=param_DT, refit=True, n_jobs=nbCores, scoring=scorer, grid_DT = RandomizedSearchCV(pipeline_DT, n_iter=nIter, param_distributions=param_DT, refit=True, n_jobs=nbCores, scoring=scorer,
cv=nbFolds) cv=nbFolds)
DT_detector = grid_DT.fit(X_train, y_train) DT_detector = grid_DT.fit(X_train, y_train)
desc_params = [DT_detector.best_params_["classifier__max_depth"]] desc_params = [DT_detector.best_params_["classifier__max_depth"], DT_detector.best_params_["classifier__criterion"],
DT_detector.best_params_["classifier__splitter"]]
return desc_params return desc_params
def getConfig(config):
    """Describe a Decision Tree configuration as a display string.

    Accepts either the sequence form ([max_depth, criterion, splitter])
    returned by the randomized search, or the string-keyed dict form
    ({'0': ..., '1': ..., '2': ...}) produced by getKWARGS.
    """
    try:
        return "\n\t\t- Decision Tree with max_depth : "+str(config[0]) + ", criterion : "+config[1]+", splitter : "+config[2]
    # Narrowed from a bare `except:` so genuine errors (KeyboardInterrupt,
    # SystemExit, typos inside this function) are not silently swallowed;
    # KeyError means `config` is the string-keyed dict form.
    except (KeyError, TypeError, IndexError):
        return "\n\t\t- Decision Tree with max_depth : "+str(config["0"]) + ", criterion : "+config["1"]+", splitter : "+config["2"]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment