Skip to content
Snippets Groups Projects
Commit 1561fb4d authored by bbauvin's avatar bbauvin
Browse files

Modified splitting code to keep the same balance on the classes

parent 11816dc7
Branches
No related tags found
No related merge requests found
......@@ -45,23 +45,29 @@ class Test_initLogFile(unittest.TestCase):
class Test_genSplits(unittest.TestCase):
    """Check one stratified 80/20 shuffle split of a labelled index sample.

    ``setUp`` builds the sample and the split once per test; each test
    method then inspects a single property of the resulting train/test
    partitions (size, membership, class presence).
    """

    def setUp(self):
        """Build the sampled indices, their labels and the train/test split."""
        # Seeded generator: every run exercises the same sample instead of
        # depending on the global NumPy random state.
        rng = np.random.RandomState(42)
        # 50 draws from [0, 500); the same index may be drawn more than once.
        self.X_indices = rng.randint(0, 500, 50)
        self.labels = np.zeros(500)
        # The examples behind the first ten draws form class 1, the rest
        # stay in class 0.
        self.labels[self.X_indices[:10]] = 1
        self.foldsObj = StratifiedShuffleSplit(
            n_splits=1, random_state=42, test_size=0.2
        )
        self.folds = self.foldsObj.split(
            self.X_indices, self.labels[self.X_indices]
        )
        # n_splits=1, so the generator yields exactly one (train, test) pair
        # of positions into X_indices.
        self.train_fold, self.test_fold = next(self.folds)
        # Map positions back to the sampled example indices.
        self.train_indices = self.X_indices[self.train_fold]
        self.test_indices = self.X_indices[self.test_fold]

    def test_genSplits_no_iter_ratio(self):
        """The 50 samples are divided 40/10 for test_size=0.2."""
        self.assertEqual(len(self.train_indices), 40)
        self.assertEqual(len(self.test_indices), 10)

    def test_genSplits_no_iter_presence(self):
        """Every index in either partition comes from the original sample."""
        for index in self.test_indices:
            self.assertIn(index, self.X_indices)
        for index in self.train_indices:
            self.assertIn(index, self.X_indices)

    def test_genSplits_no_iter_balance(self):
        """Both classes are represented in the train and the test partition."""
        for label in (0, 1):
            for indices in (self.train_indices, self.test_indices):
                self.assertGreater(
                    np.count_nonzero(self.labels[indices] == label), 0
                )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment