Skip to content
Snippets Groups Projects
Commit 306e0b54 authored by bbauvin's avatar bbauvin
Browse files

added criterion for random forest

parent 7a959391
Branches
No related tags found
No related merge requests found
# --- Diff hunk (GitLab scrape): CLI argument definitions, ~line 447 of the original file.
......@@ -447,6 +447,8 @@ groupRF.add_argument('--CL_RandomForest_trees', metavar='INT', type=int, action=
default=25)
groupRF.add_argument('--CL_RandomForest_max_depth', metavar='INT', type=int, action='store', help='Max depth for the trees',
default=5)
# New argument added by this commit: split criterion for the random forest (default "entropy").
groupRF.add_argument('--CL_RandomForest_criterion', metavar='STRING', action='store', help='Criterion for the trees',
default="entropy")
groupSVMLinear = parser.add_argument_group('Linear SVM arguments')
groupSVMLinear.add_argument('--CL_SVMLinear_C', metavar='INT', type=int, action='store', help='Penalty parameter used',
......@@ -474,7 +476,7 @@ groupDT.add_argument('--CL_DecisionTree_depth', metavar='INT', type=int, action=
# NOTE(review): the help text below says 'max depth' but this option sets the *criterion* —
# copy-paste error; should read something like 'Criterion for Decision Trees'.
groupDT.add_argument('--CL_DecisionTree_criterion', metavar='STRING', action='store',
help='Determine max depth for Decision Trees', default="entropy")
# Old/new pair from the diff: the commit changes the splitter's help text from 'max depth' to
# 'criterion' — still wrong; it should describe the *splitter*.
groupDT.add_argument('--CL_DecisionTree_splitter', metavar='STRING', action='store',
help='Determine max depth for Decision Trees', default="random")
help='Determine criterion for Decision Trees', default="random")
groupSGD = parser.add_argument_group('SGD arguments')
......
......@@ -17,7 +17,8 @@ def canProbas():
def fit(DATASET, CLASS_LABELS, NB_CORES=1, **kwargs):
    """Fit a RandomForestClassifier on DATASET / CLASS_LABELS and return it.

    kwargs uses the project's positional string-key convention (see getKWARGS):
      '0' -> number of trees, '1' -> max depth, '2' -> split criterion
      (e.g. "gini" or "entropy").
    NB_CORES is forwarded to sklearn's n_jobs.
    """
    # Coerce both numeric hyper-parameters; getKWARGS stores max_depth unconverted.
    num_estimators = int(kwargs['0'])
    max_depth = int(kwargs['1'])
    criterion = kwargs["2"]
    # The flattened diff left the stale pre-commit constructor (without criterion)
    # above the committed one; only the criterion-aware version is kept here.
    classifier = RandomForestClassifier(n_estimators=num_estimators,
                                        max_depth=max_depth,
                                        criterion=criterion,
                                        n_jobs=NB_CORES)
    classifier.fit(DATASET, CLASS_LABELS)
    return classifier
# --- Diff hunk: interior of getKWARGS, which maps CLI option names onto the
# positional string keys ('0', '1', '2') consumed by fit(). Start of the
# function is outside this view.
......@@ -29,13 +30,16 @@ def getKWARGS(kwargsList):
kwargsDict['0'] = int(kwargValue)
elif kwargName == "CL_RandomForest_max_depth":
# NOTE(review): max_depth is stored unconverted here (fit() int()s it later),
# unlike the tree count above — inconsistent but harmless.
kwargsDict['1'] = kwargValue
# New in this commit: forward the criterion option as key '2'.
elif kwargName == "CL_RandomForest_criterion":
kwargsDict['2'] = kwargValue
return kwargsDict
# Hyper-parameter search for the random forest. The middle of this function
# (construction of grid_rf, presumably a RandomizedSearchCV) falls outside the
# visible hunks — lines are missing at the '@@ -46,12 +50,13' boundary below.
def randomizedSearch(X_train, y_train, nbFolds=4, nbCores=1, metric=["accuracy_score", None], nIter=30):
# NOTE(review): mutable default argument `metric=[...]` is shared across calls;
# prefer a None sentinel.
pipeline_rf = Pipeline([('classifier', RandomForestClassifier())])
# Search space; the commit adds `criterion` alongside n_estimators/max_depth.
# The two `max_depth` lines below are the old/new pair from the diff — the
# second version (with criterion) is the committed one.
param_rf = {"classifier__n_estimators": randint(1, 30),
"classifier__max_depth":randint(1, 30)}
"classifier__max_depth":randint(1, 30),
"classifier__criterion":["gini", "entropy"]}
metricModule = getattr(Metrics, metric[0])
# NOTE(review): `!= None` — idiomatic Python is `is not None`.
if metric[1]!=None:
# Build kwargs for the metric from its positional configuration list.
metricKWARGS = dict((index, metricConfig) for index, metricConfig in enumerate(metric[1]))
......@@ -46,12 +50,13 @@ def randomizedSearch(X_train, y_train, nbFolds=4, nbCores=1, metric=["accuracy_s
rf_detector = grid_rf.fit(X_train, y_train)
# Best params returned in the same positional order fit()/getKWARGS expect:
# trees, depth, criterion. The first desc_estimators pair below is the stale
# pre-commit version; the criterion-aware pair is the committed one.
desc_estimators = [rf_detector.best_params_["classifier__n_estimators"],
rf_detector.best_params_["classifier__max_depth"]]
rf_detector.best_params_["classifier__max_depth"],
rf_detector.best_params_["classifier__criterion"]]
return desc_estimators
def getConfig(config):
    """Return a human-readable description of a random-forest configuration.

    Accepts either the list form produced by randomizedSearch
    ([n_estimators, max_depth, criterion], integer-indexed) or the dict
    form produced by getKWARGS ({'0': ..., '1': ..., '2': ...}).
    ("num_esimators" typo is kept byte-for-byte: it is runtime output that
    downstream consumers may match on.)
    """
    try:
        # List/tuple form: positional integer indices.
        return "\n\t\t- Random Forest with num_esimators : "+str(config[0])+", max_depth : "+str(config[1])+ ", criterion : "+config[2]
    except (KeyError, IndexError, TypeError):
        # Dict form: string keys. The original bare `except:` swallowed every
        # exception (even KeyboardInterrupt); narrowed to the lookup errors the
        # dict fallback actually covers.
        return "\n\t\t- Random Forest with num_esimators : "+str(config["0"])+", max_depth : "+str(config["1"])+ ", criterion : "+config["2"]
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment