Skip to content
Snippets Groups Projects
Commit 19c50e14 authored by bbauvin's avatar bbauvin
Browse files

Added multiclass test indices

parent 9804bd57
Branches
No related tags found
No related merge requests found
...@@ -153,8 +153,10 @@ def ExecMonoview(directory, X, Y, name, labelsNames, classificationIndices, KFol ...@@ -153,8 +153,10 @@ def ExecMonoview(directory, X, Y, name, labelsNames, classificationIndices, KFol
full_labels_pred[index] = y_train_pred[trainIndex] full_labels_pred[index] = y_train_pred[trainIndex]
for testIndex, index in enumerate(classificationIndices[1]): for testIndex, index in enumerate(classificationIndices[1]):
full_labels_pred[index] = y_test_pred[testIndex] full_labels_pred[index] = y_test_pred[testIndex]
if X_test_multiclass:
y_test_multiclass_pred = cl_res.predict(X_test_multiclass) y_test_multiclass_pred = cl_res.predict(X_test_multiclass)
else:
y_test_multiclass_pred = []
logging.debug("Done:\t Predicting") logging.debug("Done:\t Predicting")
t_end = time.time() - t_start t_end = time.time() - t_start
......
...@@ -129,7 +129,10 @@ def ExecMultiview(directory, DATASET, name, classificationIndices, KFolds, nbCor ...@@ -129,7 +129,10 @@ def ExecMultiview(directory, DATASET, name, classificationIndices, KFolds, nbCor
fullLabels[index] = trainLabels[trainIndex] fullLabels[index] = trainLabels[trainIndex]
for testIndex, index in enumerate(validationIndices): for testIndex, index in enumerate(validationIndices):
fullLabels[index] = testLabels[testIndex] fullLabels[index] = testLabels[testIndex]
if testIndicesMulticlass:
testLabelsMulticlass = classifier.predict_hdf5(DATASET, usedIndices=testIndicesMulticlass, viewsIndices=viewsIndices) testLabelsMulticlass = classifier.predict_hdf5(DATASET, usedIndices=testIndicesMulticlass, viewsIndices=viewsIndices)
else:
testLabelsMulticlass = []
logging.info("Done:\t Pertidcting") logging.info("Done:\t Pertidcting")
classificationTime = time.time() - t_start classificationTime = time.time() - t_start
...@@ -140,10 +143,6 @@ def ExecMultiview(directory, DATASET, name, classificationIndices, KFolds, nbCor ...@@ -140,10 +143,6 @@ def ExecMultiview(directory, DATASET, name, classificationIndices, KFolds, nbCor
logging.info("Start:\t Result Analysis for " + CL_type) logging.info("Start:\t Result Analysis for " + CL_type)
times = (extractionTime, classificationTime) times = (extractionTime, classificationTime)
if set(labels[learningIndices])!=set([0,1]) or (set(trainLabels)!=set([0,1]) and set(trainLabels)!=set([0]) and set(trainLabels)!=set([1])):
print(set(labels[learningIndices]))
print(set(trainLabels))
import pdb;pdb.set_trace()
stringAnalysis, imagesAnalysis, metricsScores = analysisModule.execute(classifier, trainLabels, stringAnalysis, imagesAnalysis, metricsScores = analysisModule.execute(classifier, trainLabels,
testLabels, DATASET, testLabels, DATASET,
classificationKWARGS, classificationIndices, classificationKWARGS, classificationIndices,
......
...@@ -6,6 +6,7 @@ def genMulticlassLabels(labels, multiclassMethod, classificationIndices): ...@@ -6,6 +6,7 @@ def genMulticlassLabels(labels, multiclassMethod, classificationIndices):
if multiclassMethod == "oneVersusOne": if multiclassMethod == "oneVersusOne":
nbLabels = len(set(list(labels))) nbLabels = len(set(list(labels)))
if nbLabels == 2: if nbLabels == 2:
classificationIndices = [[trainIndices, testIndices, []] for trainIndices, testIndices in classificationIndices]
return [labels], [(0,1)], [classificationIndices] return [labels], [(0,1)], [classificationIndices]
else: else:
combinations = itertools.combinations(np.arange(nbLabels), 2) combinations = itertools.combinations(np.arange(nbLabels), 2)
......
...@@ -18,7 +18,7 @@ class Test_initConstants(unittest.TestCase): ...@@ -18,7 +18,7 @@ class Test_initConstants(unittest.TestCase):
cls.X = cls.datasetFile.create_dataset("View0", data=cls.X_value) cls.X = cls.datasetFile.create_dataset("View0", data=cls.X_value)
cls.X.attrs["name"] = "test_dataset" cls.X.attrs["name"] = "test_dataset"
cls.X.attrs["sparse"] = False cls.X.attrs["sparse"] = False
cls.classificationIndices = [np.array([0,2,4,6,8]),np.array([1,3,5,7,9])] cls.classificationIndices = [np.array([0,2,4,6,8]), np.array([1,3,5,7,9]), np.array([1,3,5,7,9])]
cls.labelsNames = ["test_true", "test_false"] cls.labelsNames = ["test_true", "test_false"]
cls.name = "test" cls.name = "test"
cls.directory = "Code/Tests/temp_tests/test_dir/" cls.directory = "Code/Tests/temp_tests/test_dir/"
...@@ -61,10 +61,10 @@ class Test_initTrainTest(unittest.TestCase): ...@@ -61,10 +61,10 @@ class Test_initTrainTest(unittest.TestCase):
cls.random_state = np.random.RandomState(42) cls.random_state = np.random.RandomState(42)
cls.X = cls.random_state.randint(0,500,(10,5)) cls.X = cls.random_state.randint(0,500,(10,5))
cls.Y = cls.random_state.randint(0,2,10) cls.Y = cls.random_state.randint(0,2,10)
cls.classificationIndices = [np.array([0,2,4,6,8]),np.array([1,3,5,7,9])] cls.classificationIndices = [np.array([0,2,4,6,8]),np.array([1,3,5,7,9]), np.array([1,3,5,7,9])]
def test_simple(cls): def test_simple(cls):
X_train, y_train, X_test, y_test = ExecClassifMonoView.initTrainTest(cls.X, cls.Y, cls.classificationIndices) X_train, y_train, X_test, y_test, X_test_multiclass = ExecClassifMonoView.initTrainTest(cls.X, cls.Y, cls.classificationIndices)
np.testing.assert_array_equal(X_train, np.array([np.array([102,435,348,270,106]), np.testing.assert_array_equal(X_train, np.array([np.array([102,435,348,270,106]),
np.array([466,214,330,458,87]), np.array([466,214,330,458,87]),
np.array([149,308,257,343,491]), np.array([149,308,257,343,491]),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment