From 11816dc7533d43bd1b1cf96d6a28483195ca4035 Mon Sep 17 00:00:00 2001
From: bbauvin <baptiste.bauvin@centrale-marseille.fr>
Date: Thu, 19 Oct 2017 11:19:40 -0400
Subject: [PATCH] Modified splitting code to keep the same balance on the
 classes

---
 Code/MonoMultiViewClassifiers/ExecClassif.py  |  3 +--
 .../utils/execution.py                        | 23 +++++++++++-------
 Code/Tests/Test_utils/test_execution.py       | 24 +++++++++++++++++--
 3 files changed, 38 insertions(+), 12 deletions(-)

diff --git a/Code/MonoMultiViewClassifiers/ExecClassif.py b/Code/MonoMultiViewClassifiers/ExecClassif.py
index 8758ddf7..68ff67b4 100644
--- a/Code/MonoMultiViewClassifiers/ExecClassif.py
+++ b/Code/MonoMultiViewClassifiers/ExecClassif.py
@@ -296,8 +296,7 @@ def execClassif(arguments):
                                              args.CL_classes)
 
     datasetLength = DATASET.get("Metadata").attrs["datasetLength"]
-    indices = np.arange(datasetLength)
-    classificationIndices = execution.genSplits(statsIter, indices, DATASET, args.CL_split, statsIterRandomStates)
+    classificationIndices = execution.genSplits(statsIter, datasetLength, DATASET, args.CL_split, statsIterRandomStates)
 
     kFolds = execution.genKFolds(statsIter, args.CL_nbFolds, statsIterRandomStates)
 
diff --git a/Code/MonoMultiViewClassifiers/utils/execution.py b/Code/MonoMultiViewClassifiers/utils/execution.py
index b3761ee7..d4ebb75d 100644
--- a/Code/MonoMultiViewClassifiers/utils/execution.py
+++ b/Code/MonoMultiViewClassifiers/utils/execution.py
@@ -197,6 +197,7 @@ def initRandomState(randomStateArg, directory):
         cPickle.dump(randomState, handle)
     return randomState
 
+
 def initLogFile(args):
     resultDirectory = "../Results/" + args.name + "/started_" + time.strftime("%Y_%m_%d-%H_%M") + "/"
     logFileName = time.strftime("%Y%m%d-%H%M%S") + "-" + ''.join(args.CL_type) + "-" + "_".join(
@@ -224,20 +225,26 @@ def initLogFile(args):
     return resultDirectory
 
 
-def genSplits(statsIter, indices, DATASET, splitRatio, statsIterRandomStates):
+def genSplits(statsIter, datasetlength, DATASET, splitRatio, statsIterRandomStates):
+    indices = np.arange(datasetlength)
     if statsIter > 1:
         splits = []
         for randomState in statsIterRandomStates:
-            trainIndices, testIndices, a, b = sklearn.model_selection.train_test_split(indices,
-                                                                                       DATASET.get("Labels").value,
-                                                                                       test_size=splitRatio,
-                                                                                       random_state=randomState)
+            foldsObj = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, random_state=randomState, test_size=splitRatio)
+            folds = foldsObj.split(indices, DATASET.get("Labels").value)
+            for fold in folds:
+                train_fold, test_fold = fold
+            trainIndices = indices[train_fold]
+            testIndices = indices[test_fold]
             splits.append([trainIndices, testIndices])
         return splits
     else:
-        trainIndices, testIndices, a, b = sklearn.model_selection.train_test_split(indices, DATASET.get("Labels").value,
-                                                                                   test_size=splitRatio,
-                                                                                   random_state=statsIterRandomStates)
+        foldsObj = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, random_state=statsIterRandomStates, test_size=splitRatio)
+        folds = foldsObj.split(indices, DATASET.get("Labels").value)
+        for fold in folds:
+            train_fold, test_fold = fold
+        trainIndices = indices[train_fold]
+        testIndices = indices[test_fold]
         return trainIndices, testIndices
 
 
diff --git a/Code/Tests/Test_utils/test_execution.py b/Code/Tests/Test_utils/test_execution.py
index d88d9e32..3a5d5ef0 100644
--- a/Code/Tests/Test_utils/test_execution.py
+++ b/Code/Tests/Test_utils/test_execution.py
@@ -4,6 +4,8 @@ import os
 import h5py
 import numpy as np
 
+from sklearn.model_selection import StratifiedShuffleSplit
+
 from MonoMultiViewClassifiers.utils import execution
 
 
@@ -43,5 +45,23 @@ class Test_initLogFile(unittest.TestCase):
 
 class Test_genSplits(unittest.TestCase):
 
-    def test_genSplits_no_iter(self):
-        pass
\ No newline at end of file
+    def test_genSplits_no_iter_ratio(self):
+        X_indices = np.random.randint(0,500,50)
+        labels = np.zeros(500)
+        labels[X_indices[:10]] = 1
+        foldsObj = StratifiedShuffleSplit(n_splits=1, random_state=42, test_size=0.2)
+        folds = foldsObj.split(X_indices, labels[X_indices])
+        for fold in folds:
+            train_fold, test_fold = fold
+        train_indices = X_indices[train_fold]
+        test_indices = X_indices[test_fold]
+        self.assertEqual(len(train_indices), 0.8*50)
+        self.assertEqual(len(test_indices), 0.2*50)
+        for index in test_indices:
+            self.assertIn(index, X_indices)
+        for index in train_indices:
+            self.assertIn(index, X_indices)
+        self.assertGreater(len(np.where(labels[train_indices]==0)[0]), 0)
+        self.assertGreater(len(np.where(labels[test_indices]==0)[0]), 0)
+        self.assertGreater(len(np.where(labels[train_indices]==1)[0]), 0)
+        self.assertGreater(len(np.where(labels[test_indices]==1)[0]), 0)
-- 
GitLab