Skip to content
Snippets Groups Projects
Commit 91053904 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Debugging gridSearch for SCM

parent 916037bc
No related branches found
No related tags found
No related merge requests found
......@@ -25,9 +25,13 @@ def fit(DATASET, CLASS_LABELS, NB_CORES=1,**kwargs):
attributeClassification = kwargs["attributeClassification"]
binaryAttributes = kwargs["binaryAttributes"]
except:
attributeClassification, binaryAttributes = transformData(DATASET)
attributeClassification, binaryAttributes, dsetFile = transformData(DATASET)
classifier = pyscm.scm.SetCoveringMachine(p=p, max_attributes=max_attrtibutes, model_type=model_type, verbose=False)
classifier.fit(binaryAttributes, CLASS_LABELS, X=None, attribute_classifications=attributeClassification, iteration_callback=None)
try:
dsetFile.close()
except:
pass
return classifier
......@@ -60,12 +64,18 @@ def gridSearch(X_train, y_train, nbFolds=4, metric=["accuracy_score", None], nIt
fold.sort()
trainIndices = [index for index in range(len(y_train)) if (index not in fold)]
attributeClassification, binaryAttributes, dsetFile = transformData(X_train[trainIndices])
try:
classifier.fit(binaryAttributes, y_train[trainIndices], X=None, attribute_classifications=attributeClassification, iteration_callback=None)
predictedLabels = classifier.predict(X_train[fold])
score = metricModule.score(y_train[fold], predictedLabels)
scores.append(score)
except:
pass
dsetFile.close()
if scores==[]:
score = baseScore
else:
score = np.mean(np.array(scores))
if isBetter=="higher" and score>baseScore:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment