Commit 9804bd57 authored by bbauvin

Added multiclass test indices

parent e8f6ee40
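The change threads a third set of indices through the benchmark pipeline: besides the train and test indices of the current pair of labels, each one-vs-one run now also carries the test indices of the full multiclass split, so every biclass classifier can additionally be evaluated on the complete multiclass test set. A minimal sketch of the new layout, assuming NumPy index arrays (the variable names mirror the diff, the values are made up):

import numpy as np

# Hypothetical indices for one statistical iteration of a one-vs-one benchmark.
trainIndices = np.array([0, 2, 5, 7])              # training samples of the two current classes
testIndices = np.array([1, 4])                     # test samples of the two current classes
testIndicesMulticlass = np.array([1, 3, 4, 6, 8])  # full multiclass test split, all classes

# What the benchmark functions now receive instead of a plain (train, test) pair:
classificationIndices = [trainIndices, testIndices, testIndicesMulticlass]

# Hence the switch from two-way unpacking to explicit indexing in execOneBenchmark*:
trainIndices = classificationIndices[0]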
@@ -845,7 +845,7 @@ def execOneBenchmark(coreIndex=-1, LABELS_DICTIONARY=None, directory=None, class
     except OSError as exc:
         if exc.errno != errno.EEXIST:
             raise
-    trainIndices, testIndices = classificationIndices
+    trainIndices = classificationIndices[0]
     trainLabels = labels[trainIndices]
     np.savetxt(directory + "train_labels.csv", trainLabels, delimiter=",")
     resultsMonoview = []
@@ -883,7 +883,7 @@ def execOneBenchmark_multicore(nbCores=-1, LABELS_DICTIONARY=None, directory=Non
     except OSError as exc:
         if exc.errno != errno.EEXIST:
             raise
-    trainIndices, testIndices = classificationIndices
+    trainIndices = classificationIndices[0]
     trainLabels = labels[trainIndices]
     np.savetxt(directory + "train_labels.csv", trainLabels, delimiter=",")
     np.savetxt(directory + "train_indices.csv", classificationIndices[0], delimiter=",")
@@ -929,7 +929,7 @@ def execOneBenchmarkMonoCore(DATASET=None, LABELS_DICTIONARY=None, directory=Non
     except OSError as exc:
         if exc.errno != errno.EEXIST:
             raise
-    trainIndices, testIndices = classificationIndices
+    trainIndices = classificationIndices[0]
     trainLabels = labels[trainIndices]
     np.savetxt(directory + "train_labels.csv", trainLabels, delimiter=",")
     resultsMonoview = []
@@ -1023,7 +1023,7 @@ def execClassif(arguments):
     classificationIndices = execution.genSplits(DATASET.get("Labels").value, args.CL_split, statsIterRandomStates)
-    multiclassLabels, labelsCombinations, oldIndicesMulticlass = Multiclass.genMulticlassLabels(DATASET.get("Labels").value, multiclassMethod, classificationIndices)
+    multiclassLabels, labelsCombinations, indicesMulticlass = Multiclass.genMulticlassLabels(DATASET.get("Labels").value, multiclassMethod, classificationIndices)
     kFolds = execution.genKFolds(statsIter, args.CL_nbFolds, statsIterRandomStates)
@@ -1060,7 +1060,7 @@ def execClassif(arguments):
                                          initKWARGS)
     directories = execution.genDirecortiesNames(directory, statsIter)
     benchmarkArgumentDictionaries = execution.genArgumentDictionaries(LABELS_DICTIONARY, directories, multiclassLabels,
-                                                                      labelsCombinations, oldIndicesMulticlass,
+                                                                      labelsCombinations, indicesMulticlass,
                                                                       hyperParamSearch, args, kFolds,
                                                                       statsIterRandomStates, metrics,
                                                                       argumentDictionaries, benchmark, nbViews, views)
@@ -49,12 +49,13 @@ def initConstants(args, X, classificationIndices, labelsNames, name, directory):
 def initTrainTest(X, Y, classificationIndices):
-    trainIndices, testIndices = classificationIndices
+    trainIndices, testIndices, testIndicesMulticlass = classificationIndices
     X_train = extractSubset(X, trainIndices)
     X_test = extractSubset(X, testIndices)
+    X_test_multiclass = extractSubset(X, testIndicesMulticlass)
     y_train = Y[trainIndices]
     y_test = Y[testIndices]
-    return X_train, y_train, X_test, y_test
+    return X_train, y_train, X_test, y_test, X_test_multiclass
 def getKWARGS(classifierModule, hyperParamSearch, nIter, CL_type, X_train, y_train, randomState,
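Taken on its own, the updated initTrainTest amounts to the following; extractSubset is re-implemented here as plain NumPy row selection to keep the sketch self-contained, which may differ from the project's actual helper:

import numpy as np

def extractSubset(matrix, usedIndices):
    # Simplified stand-in for the project's extractSubset helper: row selection by index.
    return matrix[usedIndices]

def initTrainTest(X, Y, classificationIndices):
    # classificationIndices now holds three index arrays instead of two.
    trainIndices, testIndices, testIndicesMulticlass = classificationIndices
    X_train = extractSubset(X, trainIndices)
    X_test = extractSubset(X, testIndices)
    X_test_multiclass = extractSubset(X, testIndicesMulticlass)  # new: full multiclass test block
    y_train = Y[trainIndices]
    y_test = Y[testIndices]
    return X_train, y_train, X_test, y_test, X_test_multiclass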
@@ -127,7 +128,7 @@ def ExecMonoview(directory, X, Y, name, labelsNames, classificationIndices, KFol
                   + str(nbCores) + ", algorithm : " + CL_type)
     logging.debug("Start:\t Determine Train/Test split")
-    X_train, y_train, X_test, y_test = initTrainTest(X, Y, classificationIndices)
+    X_train, y_train, X_test, y_test, X_test_multiclass = initTrainTest(X, Y, classificationIndices)
     logging.debug("Info:\t Shape X_train:" + str(X_train.shape) + ", Length of y_train:" + str(len(y_train)))
     logging.debug("Info:\t Shape X_test:" + str(X_test.shape) + ", Length of y_test:" + str(len(y_test)))
     logging.debug("Done:\t Determine Train/Test split")
@@ -145,13 +146,14 @@ def ExecMonoview(directory, X, Y, name, labelsNames, classificationIndices, KFol
     logging.debug("Done:\t Training")
     logging.debug("Start:\t Predicting")
-    y_train_pred = cl_res.predict(X_train)
-    y_test_pred = cl_res.predict(X_test)
     full_labels_pred = np.zeros(Y.shape, dtype=int)-100
+    y_train_pred = cl_res.predict(X[classificationIndices[0]])
+    y_test_pred = cl_res.predict(X[classificationIndices[1]])
     for trainIndex, index in enumerate(classificationIndices[0]):
         full_labels_pred[index] = y_train_pred[trainIndex]
     for testIndex, index in enumerate(classificationIndices[1]):
         full_labels_pred[index] = y_test_pred[testIndex]
+    y_test_multiclass_pred = cl_res.predict(X_test_multiclass)
     logging.debug("Done:\t Predicting")
@@ -174,7 +176,7 @@ def ExecMonoview(directory, X, Y, name, labelsNames, classificationIndices, KFol
     logging.info("Done:\t Saving Results")
     viewIndex = args["viewIndex"]
-    return viewIndex, [CL_type, cl_desc + [feat], metricsScores, full_labels_pred, clKWARGS]
+    return viewIndex, [CL_type, cl_desc + [feat], metricsScores, full_labels_pred, clKWARGS, y_test_multiclass_pred]
 if __name__ == '__main__':
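The monoview prediction step now does three things: pre-fill a full-length prediction vector with the sentinel -100, scatter the train and test predictions back to their original sample positions, and predict the full multiclass test block separately. A compact sketch, wrapped in a hypothetical helper for readability and assuming cl_res is any fitted scikit-learn-style estimator; names otherwise follow the diff:

import numpy as np

def gatherPredictions(cl_res, X, Y, classificationIndices):
    trainIndices, testIndices, testIndicesMulticlass = classificationIndices
    # -100 marks samples that belong to neither split of this biclass problem.
    full_labels_pred = np.zeros(Y.shape, dtype=int) - 100
    y_train_pred = cl_res.predict(X[trainIndices])
    y_test_pred = cl_res.predict(X[testIndices])
    for trainIndex, index in enumerate(trainIndices):
        full_labels_pred[index] = y_train_pred[trainIndex]
    for testIndex, index in enumerate(testIndices):
        full_labels_pred[index] = y_test_pred[testIndex]
    # Extra predictions on the complete multiclass test set, returned alongside the rest.
    y_test_multiclass_pred = cl_res.predict(X[testIndicesMulticlass])
    return full_labels_pred, y_test_multiclass_pred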
@@ -95,7 +95,7 @@ def ExecMultiview(directory, DATASET, name, classificationIndices, KFolds, nbCor
     logging.info("Info:\t Extraction duration "+str(extractionTime)+"s")
     logging.debug("Start:\t Getting train/test split")
-    learningIndices, validationIndices = classificationIndices
+    learningIndices, validationIndices, testIndicesMulticlass = classificationIndices
     logging.debug("Done:\t Getting train/test split")
     logging.debug("Start:\t Getting classifiers modules")
@@ -129,6 +129,7 @@ def ExecMultiview(directory, DATASET, name, classificationIndices, KFolds, nbCor
         fullLabels[index] = trainLabels[trainIndex]
     for testIndex, index in enumerate(validationIndices):
         fullLabels[index] = testLabels[testIndex]
+    testLabelsMulticlass = classifier.predict_hdf5(DATASET, usedIndices=testIndicesMulticlass, viewsIndices=viewsIndices)
     logging.info("Done:\t Pertidcting")
     classificationTime = time.time() - t_start
@@ -157,4 +158,4 @@ def ExecMultiview(directory, DATASET, name, classificationIndices, KFolds, nbCor
                          learningRate, name, imagesAnalysis)
     logging.debug("Start:\t Saving preds")
-    return CL_type, classificationKWARGS, metricsScores, fullLabels
+    return CL_type, classificationKWARGS, metricsScores, fullLabels, testLabelsMulticlass
@@ -66,7 +66,7 @@ def execute(classifier, trainLabels,
     monoviewClassifiersConfigs = classificationKWARGS["classifiersConfigs"]
     fusionMethodConfig = classificationKWARGS["fusionMethodConfig"]
-    learningIndices, validationIndices = classificationIndices
+    learningIndices, validationIndices, testIndicesMulticlass = classificationIndices
     metricModule = getattr(Metrics, metrics[0][0])
     if metrics[0][1] is not None:
         metricKWARGS = dict((index, metricConfig) for index, metricConfig in enumerate(metrics[0][1]))
@@ -116,7 +116,7 @@ def getAlgoConfig(classifier, classificationKWARGS, nbCores, viewNames, hyperPar
 def getReport(classifier, CLASS_LABELS, classificationIndices, DATASET, trainLabels,
               testLabels, viewIndices, metric):
-    learningIndices, validationIndices = classificationIndices
+    learningIndices, validationIndices, multiviewTestIndices = classificationIndices
     nbView = len(viewIndices)
     NB_CLASS = DATASET.get("Metadata").attrs["nbClass"]
     metricModule = getattr(Metrics, metric[0])
@@ -224,7 +224,7 @@ def execute(classifier, trainLabels,
             databaseName, KFolds,
             hyperParamSearch, nIter, metrics,
             viewsIndices, randomState, labels):
-    learningIndices, validationIndices = classificationIndices
+    learningIndices, validationIndices, testIndicesMulticlass = classificationIndices
     if classifier.classifiersConfigs is None:
         metricsScores = getMetricsScores(metrics, trainLabels, testLabels,
                                          validationIndices, learningIndices, labels)
@@ -21,7 +21,8 @@ def genMulticlassLabels(labels, multiclassMethod, classificationIndices):
                         for iterIndices in classificationIndices]
         testIndices = [np.array([oldIndex for oldIndex in oldIndices if oldIndex in iterindices[1]])
                        for iterindices in classificationIndices]
-        indicesMulticlass.append([trainIndices, testIndices])
+        testIndicesMulticlass = [np.array(iterindices[1]) for iterindices in classificationIndices]
+        indicesMulticlass.append([trainIndices, testIndices, testIndicesMulticlass])
         newLabels = np.zeros(len(labels), dtype=int)-100
         for labelIndex, label in enumerate(labels):
             if label == combination[0]:
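In genMulticlassLabels, each label combination thus keeps three index sets per statistical iteration: train and test indices restricted to examples of the two classes at hand, plus the test indices left unfiltered for multiclass scoring. A sketch of that filtering, wrapped in a hypothetical helper; the restriction of the train indices is assumed to mirror the test-index filter visible in the diff:

import numpy as np

def indicesForCombination(oldIndices, classificationIndices):
    # oldIndices: positions of the examples whose label is one of the two classes of the
    # current combination; classificationIndices: one (train, test) pair of index arrays
    # per statistical iteration, expressed over the full multiclass dataset.
    trainIndices = [np.array([oldIndex for oldIndex in oldIndices if oldIndex in iterIndices[0]])
                    for iterIndices in classificationIndices]
    testIndices = [np.array([oldIndex for oldIndex in oldIndices if oldIndex in iterIndices[1]])
                   for iterIndices in classificationIndices]
    # New in this commit: keep the full test split so the biclass model can also be
    # scored on every multiclass test example.
    testIndicesMulticlass = [np.array(iterIndices[1]) for iterIndices in classificationIndices]
    return [trainIndices, testIndices, testIndicesMulticlass]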
@@ -296,7 +296,7 @@ def genDirecortiesNames(directory, statsIter):
     return directories
-def genArgumentDictionaries(labelsDictionary, directories, multiclassLabels, labelsCombinations, oldIndicesMulticlass, hyperParamSearch, args,
+def genArgumentDictionaries(labelsDictionary, directories, multiclassLabels, labelsCombinations, indicesMulticlass, hyperParamSearch, args,
                             kFolds, statsIterRandomStates, metrics, argumentDictionaries, benchmark, nbViews, views):
     benchmarkArgumentDictionaries = []
     for combinationIndex, labelsCombination in enumerate(labelsCombinations):
@@ -307,7 +307,9 @@ def genArgumentDictionaries(labelsDictionary, directories, multiclassLabels, lab
                               labelsDictionary[labelsCombination[0]]+
                               "vs"+
                               labelsDictionary[labelsCombination[1]]+"/",
-                "classificationIndices": oldIndicesMulticlass[combinationIndex][iterIndex],
+                "classificationIndices": [indicesMulticlass[combinationIndex][0][iterIndex],
+                                          indicesMulticlass[combinationIndex][1][iterIndex],
+                                          indicesMulticlass[combinationIndex][2][iterIndex]],
                 "args": args,
                 "labels": multiclassLabels[combinationIndex],
                 "kFolds": kFolds[iterIndex],
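Finally, genArgumentDictionaries re-slices those per-combination lists for each statistical iteration, so every benchmark run receives its own [train, test, multiclass-test] triple under the "classificationIndices" key. A trimmed sketch of just that slicing, as a hypothetical helper (the real dictionary carries many more entries):

def sliceClassificationIndices(indicesMulticlass, combinationIndex, iterIndex):
    # indicesMulticlass[combinationIndex] == [trainIndices, testIndices, testIndicesMulticlass],
    # each a list holding one index array per statistical iteration.
    return [indicesMulticlass[combinationIndex][0][iterIndex],
            indicesMulticlass[combinationIndex][1][iterIndex],
            indicesMulticlass[combinationIndex][2][iterIndex]]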