Skip to content
Snippets Groups Projects
Commit 19c50e14 authored by bbauvin's avatar bbauvin
Browse files

Added multiclass test indices

parent 9804bd57
No related branches found
No related tags found
No related merge requests found
......@@ -153,8 +153,10 @@ def ExecMonoview(directory, X, Y, name, labelsNames, classificationIndices, KFol
full_labels_pred[index] = y_train_pred[trainIndex]
for testIndex, index in enumerate(classificationIndices[1]):
full_labels_pred[index] = y_test_pred[testIndex]
if X_test_multiclass:
y_test_multiclass_pred = cl_res.predict(X_test_multiclass)
else:
y_test_multiclass_pred = []
logging.debug("Done:\t Predicting")
t_end = time.time() - t_start
......
......@@ -129,7 +129,10 @@ def ExecMultiview(directory, DATASET, name, classificationIndices, KFolds, nbCor
fullLabels[index] = trainLabels[trainIndex]
for testIndex, index in enumerate(validationIndices):
fullLabels[index] = testLabels[testIndex]
if testIndicesMulticlass:
testLabelsMulticlass = classifier.predict_hdf5(DATASET, usedIndices=testIndicesMulticlass, viewsIndices=viewsIndices)
else:
testLabelsMulticlass = []
logging.info("Done:\t Pertidcting")
classificationTime = time.time() - t_start
......@@ -140,10 +143,6 @@ def ExecMultiview(directory, DATASET, name, classificationIndices, KFolds, nbCor
logging.info("Start:\t Result Analysis for " + CL_type)
times = (extractionTime, classificationTime)
if set(labels[learningIndices])!=set([0,1]) or (set(trainLabels)!=set([0,1]) and set(trainLabels)!=set([0]) and set(trainLabels)!=set([1])):
print(set(labels[learningIndices]))
print(set(trainLabels))
import pdb;pdb.set_trace()
stringAnalysis, imagesAnalysis, metricsScores = analysisModule.execute(classifier, trainLabels,
testLabels, DATASET,
classificationKWARGS, classificationIndices,
......
......@@ -6,6 +6,7 @@ def genMulticlassLabels(labels, multiclassMethod, classificationIndices):
if multiclassMethod == "oneVersusOne":
nbLabels = len(set(list(labels)))
if nbLabels == 2:
classificationIndices = [[trainIndices, testIndices, []] for trainIndices, testIndices in classificationIndices]
return [labels], [(0,1)], [classificationIndices]
else:
combinations = itertools.combinations(np.arange(nbLabels), 2)
......
......@@ -18,7 +18,7 @@ class Test_initConstants(unittest.TestCase):
cls.X = cls.datasetFile.create_dataset("View0", data=cls.X_value)
cls.X.attrs["name"] = "test_dataset"
cls.X.attrs["sparse"] = False
cls.classificationIndices = [np.array([0,2,4,6,8]),np.array([1,3,5,7,9])]
cls.classificationIndices = [np.array([0,2,4,6,8]), np.array([1,3,5,7,9]), np.array([1,3,5,7,9])]
cls.labelsNames = ["test_true", "test_false"]
cls.name = "test"
cls.directory = "Code/Tests/temp_tests/test_dir/"
......@@ -61,10 +61,10 @@ class Test_initTrainTest(unittest.TestCase):
cls.random_state = np.random.RandomState(42)
cls.X = cls.random_state.randint(0,500,(10,5))
cls.Y = cls.random_state.randint(0,2,10)
cls.classificationIndices = [np.array([0,2,4,6,8]),np.array([1,3,5,7,9])]
cls.classificationIndices = [np.array([0,2,4,6,8]),np.array([1,3,5,7,9]), np.array([1,3,5,7,9])]
def test_simple(cls):
X_train, y_train, X_test, y_test = ExecClassifMonoView.initTrainTest(cls.X, cls.Y, cls.classificationIndices)
X_train, y_train, X_test, y_test, X_test_multiclass = ExecClassifMonoView.initTrainTest(cls.X, cls.Y, cls.classificationIndices)
np.testing.assert_array_equal(X_train, np.array([np.array([102,435,348,270,106]),
np.array([466,214,330,458,87]),
np.array([149,308,257,343,491]),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment