From f43706305e23d04185a53c71994f182633212b9f Mon Sep 17 00:00:00 2001 From: bbauvin <baptiste.bauvin@centrale-marseille.fr> Date: Thu, 5 Oct 2017 16:38:12 -0400 Subject: [PATCH] Corrected SCM crossval --- .../MonoviewClassifiers/SCM.py | 76 +++++++++---------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/Code/MonoMutliViewClassifiers/MonoviewClassifiers/SCM.py b/Code/MonoMutliViewClassifiers/MonoviewClassifiers/SCM.py index 62c5b170..065ee1b0 100644 --- a/Code/MonoMutliViewClassifiers/MonoviewClassifiers/SCM.py +++ b/Code/MonoMutliViewClassifiers/MonoviewClassifiers/SCM.py @@ -68,44 +68,44 @@ def randomizedSearch(X_train, y_train, randomState, KFolds=None, metric=["accura baseScore = 1000.0 isBetter = "lower" config = [] - # for iterIndex in range(nIter): - max_attributes = randomState.randint(1, 20) - p = randomState.random_sample() - model = randomState.choice(["conjunction", "disjunction"]) - classifier = pyscm.scm.SetCoveringMachine(p=p, max_attributes=max_attributes, model_type=model, verbose=False) - # if nbFolds != 1: - # kFolds = DB.getKFoldIndices(nbFolds, y_train, len(set(y_train)), range(len(y_train)), randomState) - # else: - # kFolds = [[], range(len(y_train))] - scores = [] - KFolds = KFolds.split(X_train, y_train) - for foldIdx, (trainIndices, testIndices) in enumerate(KFolds): - # if fold != range(len(y_train)): - # fold.sort() - # trainIndices = [index for index in range(len(y_train)) if (index not in fold)] - attributeClassification, binaryAttributes, dsetFile, name = transformData(X_train[trainIndices]) - try: - classifier.fit(binaryAttributes, y_train[trainIndices], X=None, - attribute_classifications=attributeClassification, iteration_callback=None) - - predictedLabels = classifier.predict(X_train[testIndices]) - score = metricModule.score(y_train[testIndices], predictedLabels) - scores.append(score) - except: - pass - dsetFile.close() - os.remove(name) - if scores==[]: - score = baseScore - else: - score = np.mean(np.array(scores)) - - if isBetter=="higher" and score>baseScore: - baseScore = score - config = [max_attributes, p, model] - if isBetter=="lower" and score<baseScore: - baseScore = score - config = [max_attributes, p, model] + for iterIndex in range(nIter): + max_attributes = randomState.randint(1, 20) + p = randomState.random_sample() + model = randomState.choice(["conjunction", "disjunction"]) + classifier = pyscm.scm.SetCoveringMachine(p=p, max_attributes=max_attributes, model_type=model, verbose=False) + # if nbFolds != 1: + # kFolds = DB.getKFoldIndices(nbFolds, y_train, len(set(y_train)), range(len(y_train)), randomState) + # else: + # kFolds = [[], range(len(y_train))] + scores = [] + KFolds = KFolds.split(X_train, y_train) + for foldIdx, (trainIndices, testIndices) in enumerate(KFolds): + # if fold != range(len(y_train)): + # fold.sort() + # trainIndices = [index for index in range(len(y_train)) if (index not in fold)] + attributeClassification, binaryAttributes, dsetFile, name = transformData(X_train[trainIndices]) + try: + classifier.fit(binaryAttributes, y_train[trainIndices], X=None, + attribute_classifications=attributeClassification, iteration_callback=None) + + predictedLabels = classifier.predict(X_train[testIndices]) + score = metricModule.score(y_train[testIndices], predictedLabels) + scores.append(score) + except: + pass + dsetFile.close() + os.remove(name) + if scores==[]: + score = baseScore + else: + score = np.mean(np.array(scores)) + + if isBetter=="higher" and score > baseScore: + baseScore = score + config = [max_attributes, p, model] + if isBetter=="lower" and score < baseScore: + baseScore = score + config = [max_attributes, p, model] assert config!=[], "No good configuration found for SCM" return config -- GitLab