From 1561fb4d25f215fd4cff63bb689f72f5e4ff1b15 Mon Sep 17 00:00:00 2001
From: bbauvin <baptiste.bauvin@centrale-marseille.fr>
Date: Thu, 19 Oct 2017 11:32:30 -0400
Subject: [PATCH] Modified splitting code to keep the same balance on the
 classes

---
 Code/Tests/Test_utils/test_execution.py | 44 ++++++++++++++-----------
 1 file changed, 25 insertions(+), 19 deletions(-)

diff --git a/Code/Tests/Test_utils/test_execution.py b/Code/Tests/Test_utils/test_execution.py
index 3a5d5ef0..0540a0d7 100644
--- a/Code/Tests/Test_utils/test_execution.py
+++ b/Code/Tests/Test_utils/test_execution.py
@@ -45,23 +45,29 @@ class Test_initLogFile(unittest.TestCase):
 
 class Test_genSplits(unittest.TestCase):
 
+    def setUp(self):
+        self.X_indices = np.random.randint(0,500,50)
+        self.labels = np.zeros(500)
+        self.labels[self.X_indices[:10]] = 1
+        self.foldsObj = StratifiedShuffleSplit(n_splits=1, random_state=42, test_size=0.2)
+        self.folds = self.foldsObj.split(self.X_indices, self.labels[self.X_indices])
+        for fold in self.folds:
+            self.train_fold, self.test_fold = fold
+        self.train_indices = self.X_indices[self.train_fold]
+        self.test_indices = self.X_indices[self.test_fold]
+
     def test_genSplits_no_iter_ratio(self):
-        X_indices = np.random.randint(0,500,50)
-        labels = np.zeros(500)
-        labels[X_indices[:10]] = 1
-        foldsObj = StratifiedShuffleSplit(n_splits=1, random_state=42, test_size=0.2)
-        folds = foldsObj.split(X_indices, labels[X_indices])
-        for fold in folds:
-            train_fold, test_fold = fold
-        train_indices = X_indices[train_fold]
-        test_indices = X_indices[test_fold]
-        self.assertEqual(len(train_indices), 0.8*50)
-        self.assertEqual(len(test_indices), 0.2*50)
-        for index in test_indices:
-            self.assertIn(index, X_indices)
-        for index in train_indices:
-            self.assertIn(index, X_indices)
-        self.assertGreater(len(np.where(labels[train_indices]==0)[0]), 0)
-        self.assertGreater(len(np.where(labels[test_indices]==0)[0]), 0)
-        self.assertGreater(len(np.where(labels[train_indices]==1)[0]), 0)
-        self.assertGreater(len(np.where(labels[test_indices]==1)[0]), 0)
+        self.assertEqual(len(self.train_indices), 0.8*50)
+        self.assertEqual(len(self.test_indices), 0.2*50)
+
+    def test_genSplits_no_iter_presence(self):
+        for index in self.test_indices:
+            self.assertIn(index, self.X_indices)
+        for index in self.train_indices:
+            self.assertIn(index, self.X_indices)
+
+    def test_genSplits_no_iter_balance(self):
+        self.assertGreater(len(np.where(self.labels[self.train_indices]==0)[0]), 0)
+        self.assertGreater(len(np.where(self.labels[self.test_indices]==0)[0]), 0)
+        self.assertGreater(len(np.where(self.labels[self.train_indices]==1)[0]), 0)
+        self.assertGreater(len(np.where(self.labels[self.test_indices]==1)[0]), 0)
-- 
GitLab