From 1561fb4d25f215fd4cff63bb689f72f5e4ff1b15 Mon Sep 17 00:00:00 2001 From: bbauvin <baptiste.bauvin@centrale-marseille.fr> Date: Thu, 19 Oct 2017 11:32:30 -0400 Subject: [PATCH] Modified splitting code to keep the same balance on the classes --- Code/Tests/Test_utils/test_execution.py | 44 ++++++++++++++----------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/Code/Tests/Test_utils/test_execution.py b/Code/Tests/Test_utils/test_execution.py index 3a5d5ef0..0540a0d7 100644 --- a/Code/Tests/Test_utils/test_execution.py +++ b/Code/Tests/Test_utils/test_execution.py @@ -45,23 +45,29 @@ class Test_initLogFile(unittest.TestCase): class Test_genSplits(unittest.TestCase): + def setUp(self): + self.X_indices = np.random.randint(0,500,50) + self.labels = np.zeros(500) + self.labels[self.X_indices[:10]] = 1 + self.foldsObj = StratifiedShuffleSplit(n_splits=1, random_state=42, test_size=0.2) + self.folds = self.foldsObj.split(self.X_indices, self.labels[self.X_indices]) + for fold in self.folds: + self.train_fold, self.test_fold = fold + self.train_indices = self.X_indices[self.train_fold] + self.test_indices = self.X_indices[self.test_fold] + def test_genSplits_no_iter_ratio(self): - X_indices = np.random.randint(0,500,50) - labels = np.zeros(500) - labels[X_indices[:10]] = 1 - foldsObj = StratifiedShuffleSplit(n_splits=1, random_state=42, test_size=0.2) - folds = foldsObj.split(X_indices, labels[X_indices]) - for fold in folds: - train_fold, test_fold = fold - train_indices = X_indices[train_fold] - test_indices = X_indices[test_fold] - self.assertEqual(len(train_indices), 0.8*50) - self.assertEqual(len(test_indices), 0.2*50) - for index in test_indices: - self.assertIn(index, X_indices) - for index in train_indices: - self.assertIn(index, X_indices) - self.assertGreater(len(np.where(labels[train_indices]==0)[0]), 0) - self.assertGreater(len(np.where(labels[test_indices]==0)[0]), 0) - self.assertGreater(len(np.where(labels[train_indices]==1)[0]), 0) - self.assertGreater(len(np.where(labels[test_indices]==1)[0]), 0) + self.assertEqual(len(self.train_indices), 0.8*50) + self.assertEqual(len(self.test_indices), 0.2*50) + + def test_genSplits_no_iter_presence(self): + for index in self.test_indices: + self.assertIn(index, self.X_indices) + for index in self.train_indices: + self.assertIn(index, self.X_indices) + + def test_genSplits_no_iter_balance(self): + self.assertGreater(len(np.where(self.labels[self.train_indices]==0)[0]), 0) + self.assertGreater(len(np.where(self.labels[self.test_indices]==0)[0]), 0) + self.assertGreater(len(np.where(self.labels[self.train_indices]==1)[0]), 0) + self.assertGreater(len(np.where(self.labels[self.test_indices]==1)[0]), 0) -- GitLab