From 19c50e142d0ba24c7352623f5d3aefe26fcf6f96 Mon Sep 17 00:00:00 2001
From: bbauvin <baptiste.bauvin@centrale-marseille.fr>
Date: Thu, 9 Nov 2017 11:54:12 -0500
Subject: [PATCH] Added multiclass test indices

---
 .../Monoview/ExecClassifMonoView.py                      | 6 ++++--
 Code/MonoMultiViewClassifiers/Multiview/ExecMultiview.py | 9 ++++-----
 Code/MonoMultiViewClassifiers/utils/Multiclass.py        | 1 +
 Code/Tests/Test_MonoView/test_ExecClassifMonoView.py     | 6 +++---
 4 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/Code/MonoMultiViewClassifiers/Monoview/ExecClassifMonoView.py b/Code/MonoMultiViewClassifiers/Monoview/ExecClassifMonoView.py
index 0926a3e3..376cb113 100644
--- a/Code/MonoMultiViewClassifiers/Monoview/ExecClassifMonoView.py
+++ b/Code/MonoMultiViewClassifiers/Monoview/ExecClassifMonoView.py
@@ -153,8 +153,10 @@ def ExecMonoview(directory, X, Y, name, labelsNames, classificationIndices, KFol
         full_labels_pred[index] = y_train_pred[trainIndex]
     for testIndex, index in enumerate(classificationIndices[1]):
         full_labels_pred[index] = y_test_pred[testIndex]
-    y_test_multiclass_pred = cl_res.predict(X_test_multiclass)
-
+    if X_test_multiclass:
+        y_test_multiclass_pred = cl_res.predict(X_test_multiclass)
+    else:
+        y_test_multiclass_pred = []
     logging.debug("Done:\t Predicting")
 
     t_end = time.time() - t_start
diff --git a/Code/MonoMultiViewClassifiers/Multiview/ExecMultiview.py b/Code/MonoMultiViewClassifiers/Multiview/ExecMultiview.py
index 388ebc70..018b38fa 100644
--- a/Code/MonoMultiViewClassifiers/Multiview/ExecMultiview.py
+++ b/Code/MonoMultiViewClassifiers/Multiview/ExecMultiview.py
@@ -129,7 +129,10 @@ def ExecMultiview(directory, DATASET, name, classificationIndices, KFolds, nbCor
         fullLabels[index] = trainLabels[trainIndex]
     for testIndex, index in enumerate(validationIndices):
         fullLabels[index] = testLabels[testIndex]
-    testLabelsMulticlass = classifier.predict_hdf5(DATASET, usedIndices=testIndicesMulticlass, viewsIndices=viewsIndices)
+    if testIndicesMulticlass:
+        testLabelsMulticlass = classifier.predict_hdf5(DATASET, usedIndices=testIndicesMulticlass, viewsIndices=viewsIndices)
+    else:
+        testLabelsMulticlass = []
     logging.info("Done:\t Pertidcting")
 
     classificationTime = time.time() - t_start
@@ -140,10 +143,6 @@ def ExecMultiview(directory, DATASET, name, classificationIndices, KFolds, nbCor
 
     logging.info("Start:\t Result Analysis for " + CL_type)
     times = (extractionTime, classificationTime)
-    if set(labels[learningIndices])!=set([0,1]) or (set(trainLabels)!=set([0,1]) and set(trainLabels)!=set([0]) and set(trainLabels)!=set([1])):
-        print(set(labels[learningIndices]))
-        print(set(trainLabels))
-        import pdb;pdb.set_trace()
     stringAnalysis, imagesAnalysis, metricsScores = analysisModule.execute(classifier, trainLabels,
                                                                            testLabels, DATASET,
                                                                            classificationKWARGS, classificationIndices,
diff --git a/Code/MonoMultiViewClassifiers/utils/Multiclass.py b/Code/MonoMultiViewClassifiers/utils/Multiclass.py
index 9bbd793f..76a7a6bc 100644
--- a/Code/MonoMultiViewClassifiers/utils/Multiclass.py
+++ b/Code/MonoMultiViewClassifiers/utils/Multiclass.py
@@ -6,6 +6,7 @@ def genMulticlassLabels(labels, multiclassMethod, classificationIndices):
     if multiclassMethod == "oneVersusOne":
         nbLabels = len(set(list(labels)))
         if nbLabels == 2:
+            classificationIndices = [[trainIndices, testIndices, []] for trainIndices, testIndices in classificationIndices]
             return [labels], [(0,1)], [classificationIndices]
         else:
             combinations = itertools.combinations(np.arange(nbLabels), 2)
diff --git a/Code/Tests/Test_MonoView/test_ExecClassifMonoView.py b/Code/Tests/Test_MonoView/test_ExecClassifMonoView.py
index e0335980..e9879ed7 100644
--- a/Code/Tests/Test_MonoView/test_ExecClassifMonoView.py
+++ b/Code/Tests/Test_MonoView/test_ExecClassifMonoView.py
@@ -18,7 +18,7 @@ class Test_initConstants(unittest.TestCase):
         cls.X = cls.datasetFile.create_dataset("View0", data=cls.X_value)
         cls.X.attrs["name"] = "test_dataset"
         cls.X.attrs["sparse"] = False
-        cls.classificationIndices = [np.array([0,2,4,6,8]),np.array([1,3,5,7,9])]
+        cls.classificationIndices = [np.array([0,2,4,6,8]), np.array([1,3,5,7,9]), np.array([1,3,5,7,9])]
         cls.labelsNames = ["test_true", "test_false"]
         cls.name = "test"
         cls.directory = "Code/Tests/temp_tests/test_dir/"
@@ -61,10 +61,10 @@ class Test_initTrainTest(unittest.TestCase):
         cls.random_state = np.random.RandomState(42)
         cls.X = cls.random_state.randint(0,500,(10,5))
         cls.Y = cls.random_state.randint(0,2,10)
-        cls.classificationIndices = [np.array([0,2,4,6,8]),np.array([1,3,5,7,9])]
+        cls.classificationIndices = [np.array([0,2,4,6,8]),np.array([1,3,5,7,9]), np.array([1,3,5,7,9])]
 
     def test_simple(cls):
-        X_train, y_train, X_test, y_test = ExecClassifMonoView.initTrainTest(cls.X, cls.Y, cls.classificationIndices)
+        X_train, y_train, X_test, y_test, X_test_multiclass = ExecClassifMonoView.initTrainTest(cls.X, cls.Y, cls.classificationIndices)
         np.testing.assert_array_equal(X_train, np.array([np.array([102,435,348,270,106]),
                                                          np.array([466,214,330,458,87]),
                                                          np.array([149,308,257,343,491]),
-- 
GitLab