From f7d0a4a0663433557593948572b6a9acbd81d8d2 Mon Sep 17 00:00:00 2001
From: bbauvin <baptiste.bauvin@centrale-marseille.fr>
Date: Wed, 4 Oct 2017 16:24:56 -0400
Subject: [PATCH] Updated Decision Tree arguments and renamed gridSearch to
 randomizedSearch

---
 Code/MonoMutliViewClassifiers/ExecClassif.py  |  5 +++++
 .../MonoviewClassifiers/DecisionTree.py       | 22 ++++++++++++++-----
 2 files changed, 21 insertions(+), 6 deletions(-)

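A minimal sketch of the new kwargs flow (illustration only, not part of the
patch): it assumes DecisionTree is importable from MonoviewClassifiers and
that X_train / y_train are the usual numpy feature and label arrays.

    from MonoviewClassifiers import DecisionTree  # import path assumed

    kwargsList = [("CL_DecisionTree_depth", 5),
                  ("CL_DecisionTree_criterion", "gini"),
                  ("CL_DecisionTree_splitter", "best")]
    kwargs = DecisionTree.getKWARGS(kwargsList)   # -> {'0': 5, '1': 'gini', '2': 'best'}
    classifier = DecisionTree.fit(X_train, y_train, **kwargs)
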
diff --git a/Code/MonoMutliViewClassifiers/ExecClassif.py b/Code/MonoMutliViewClassifiers/ExecClassif.py
index aefc98d2..065b9a3d 100644
--- a/Code/MonoMutliViewClassifiers/ExecClassif.py
+++ b/Code/MonoMutliViewClassifiers/ExecClassif.py
@@ -471,6 +471,11 @@ groupAdaboost.add_argument('--CL_Adaboost_b_est', metavar='STRING', action='stor
 groupDT = parser.add_argument_group('Decision Trees arguments')
 groupDT.add_argument('--CL_DecisionTree_depth', metavar='INT', type=int, action='store',
                      help='Determine max depth for Decision Trees', default=3)
+groupDT.add_argument('--CL_DecisionTree_criterion', metavar='STRING', action='store',
+                     help='Determine the split criterion for Decision Trees', default="entropy")
+groupDT.add_argument('--CL_DecisionTree_splitter', metavar='STRING', action='store',
+                     help='Determine the splitter strategy for Decision Trees', default="random")
+
 
 groupSGD = parser.add_argument_group('SGD arguments')
 groupSGD.add_argument('--CL_SGD_alpha', metavar='FLOAT', type=float, action='store',
diff --git a/Code/MonoMutliViewClassifiers/MonoviewClassifiers/DecisionTree.py b/Code/MonoMutliViewClassifiers/MonoviewClassifiers/DecisionTree.py
index 3a13c587..e3c92498 100644
--- a/Code/MonoMutliViewClassifiers/MonoviewClassifiers/DecisionTree.py
+++ b/Code/MonoMutliViewClassifiers/MonoviewClassifiers/DecisionTree.py
@@ -13,9 +13,12 @@ __status__ 	= "Prototype"                           # Production, Development, P
 def canProbas():
     return True
 
+
 def fit(DATASET, CLASS_LABELS, NB_CORES=1, **kwargs):
     maxDepth = int(kwargs['0'])
-    classifier = DecisionTreeClassifier(max_depth=maxDepth)
+    criterion = kwargs['1']
+    splitter = kwargs['2']
+    classifier = DecisionTreeClassifier(max_depth=maxDepth, criterion=criterion, splitter=splitter)
     classifier.fit(DATASET, CLASS_LABELS)
     return classifier
 
@@ -25,12 +28,18 @@ def getKWARGS(kwargsList):
     for (kwargName, kwargValue) in kwargsList:
         if kwargName == "CL_DecisionTree_depth":
             kwargsDict['0'] = int(kwargValue)
+        if kwargName == "CL_DecisionTree_criterion":
+            kwargsDict['1'] = kwargValue
+        if kwargName == "CL_DecisionTree_splitter":
+            kwargsDict['2'] = kwargValue
     return kwargsDict
 
 
-def gridSearch(X_train, y_train, nbFolds=4, nbCores=1, metric=["accuracy_score", None], nIter=30):
+def randomizedSearch(X_train, y_train, nbFolds=4, nbCores=1, metric=["accuracy_score", None], nIter=30):
     pipeline_DT = Pipeline([('classifier', DecisionTreeClassifier())])
-    param_DT = {"classifier__max_depth": randint(1, 30)}
+    param_DT = {"classifier__max_depth": randint(1, 30),
+                "classifier__criterion": ["gini", "entropy"],
+                "classifier__splitter": ["best", "random"]}
     metricModule = getattr(Metrics, metric[0])
     if metric[1]!=None:
         metricKWARGS = dict((index, metricConfig) for index, metricConfig in enumerate(metric[1]))
@@ -40,12 +49,13 @@ def gridSearch(X_train, y_train, nbFolds=4, nbCores=1, metric=["accuracy_score",
     grid_DT = RandomizedSearchCV(pipeline_DT, n_iter=nIter, param_distributions=param_DT, refit=True, n_jobs=nbCores, scoring=scorer,
                            cv=nbFolds)
     DT_detector = grid_DT.fit(X_train, y_train)
-    desc_params = [DT_detector.best_params_["classifier__max_depth"]]
+    desc_params = [DT_detector.best_params_["classifier__max_depth"], DT_detector.best_params_["classifier__criterion"],
+                   DT_detector.best_params_["classifier__splitter"]]
     return desc_params
 
 
 def getConfig(config):
     try:
-        return "\n\t\t- Decision Tree with max_depth : "+str(config[0])
+        return "\n\t\t- Decision Tree with max_depth : "+str(config[0]) + ", criterion : "+config[1]+", splitter : "+config[2]
     except:
-        return "\n\t\t- Decision Tree with max_depth : "+str(config["0"])
\ No newline at end of file
+        return "\n\t\t- Decision Tree with max_depth : "+str(config["0"]) + ", criterion : "+config["1"]+", splitter : "+config["2"]
\ No newline at end of file
-- 
GitLab