Skip to content
Snippets Groups Projects
Commit 11816dc7 authored by bbauvin's avatar bbauvin
Browse files

Modified splitting code to keep the same balance on the classes

parent 5f9b5937
Branches
Tags
No related merge requests found
......@@ -296,8 +296,7 @@ def execClassif(arguments):
args.CL_classes)
datasetLength = DATASET.get("Metadata").attrs["datasetLength"]
indices = np.arange(datasetLength)
classificationIndices = execution.genSplits(statsIter, indices, DATASET, args.CL_split, statsIterRandomStates)
classificationIndices = execution.genSplits(statsIter, datasetLength, DATASET, args.CL_split, statsIterRandomStates)
kFolds = execution.genKFolds(statsIter, args.CL_nbFolds, statsIterRandomStates)
......
......@@ -197,6 +197,7 @@ def initRandomState(randomStateArg, directory):
cPickle.dump(randomState, handle)
return randomState
def initLogFile(args):
resultDirectory = "../Results/" + args.name + "/started_" + time.strftime("%Y_%m_%d-%H_%M") + "/"
logFileName = time.strftime("%Y%m%d-%H%M%S") + "-" + ''.join(args.CL_type) + "-" + "_".join(
......@@ -224,20 +225,26 @@ def initLogFile(args):
return resultDirectory
def genSplits(statsIter, datasetlength, DATASET, splitRatio, statsIterRandomStates):
    """Generate stratified train/test index splits for the dataset.

    Each split is drawn with ``sklearn.model_selection.StratifiedShuffleSplit``
    (``n_splits=1``, ``test_size=splitRatio``) so that the class proportions
    of the labels are kept in both the train and the test parts.

    Parameters
    ----------
    statsIter : int
        Number of statistical iterations. When greater than 1, one split is
        generated per entry of ``statsIterRandomStates``.
    datasetlength : int
        Number of examples; the splits index into ``np.arange(datasetlength)``.
    DATASET : mapping-like
        Object whose ``get("Labels")`` entry holds one label per example
        (presumably an open h5py file -- confirm against the caller).
    splitRatio : float or int
        Forwarded as ``test_size`` to ``StratifiedShuffleSplit``.
    statsIterRandomStates : random state or iterable of random states
        A single random state when ``statsIter <= 1``; otherwise an iterable
        providing one random state per split.

    Returns
    -------
    list of [trainIndices, testIndices] when ``statsIter > 1``, otherwise a
    single ``(trainIndices, testIndices)`` tuple.
    """
    indices = np.arange(datasetlength)
    # ``[()]`` reads the whole dataset into memory; it replaces the deprecated
    # ``Dataset.value`` accessor (removed in h5py 3.0) and returns the same data.
    labels = DATASET.get("Labels")[()]

    def _stratified_split(random_state):
        # Draw the single stratified train/test split for this random state.
        splitter = sklearn.model_selection.StratifiedShuffleSplit(
            n_splits=1, random_state=random_state, test_size=splitRatio)
        train_fold, test_fold = next(splitter.split(indices, labels))
        return indices[train_fold], indices[test_fold]

    if statsIter > 1:
        return [list(_stratified_split(randomState))
                for randomState in statsIterRandomStates]
    return _stratified_split(statsIterRandomStates)
......
......@@ -4,6 +4,8 @@ import os
import h5py
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
from MonoMultiViewClassifiers.utils import execution
......@@ -43,5 +45,23 @@ class Test_initLogFile(unittest.TestCase):
class Test_genSplits(unittest.TestCase):
    """Checks on the stratified train/test splitting strategy."""

    def test_genSplits_no_iter(self):
        # Placeholder: no behaviour is verified here yet.
        pass

    def test_genSplits_no_iter_ratio(self):
        """A single stratified split honours the ratio and keeps both classes."""
        nb_examples = 50
        # Seeded generator and sampling without replacement: the fixture is
        # reproducible and the indices are unique, which makes the
        # disjointness check below meaningful.
        rng = np.random.RandomState(42)
        X_indices = rng.choice(500, nb_examples, replace=False)
        labels = np.zeros(500)
        labels[X_indices[:10]] = 1

        foldsObj = StratifiedShuffleSplit(n_splits=1, random_state=42, test_size=0.2)
        # n_splits=1, so the generator yields exactly one (train, test) pair.
        train_fold, test_fold = next(foldsObj.split(X_indices, labels[X_indices]))
        train_indices = X_indices[train_fold]
        test_indices = X_indices[test_fold]

        # 80% / 20% of the 50 sampled examples.
        self.assertEqual(len(train_indices), 40)
        self.assertEqual(len(test_indices), 10)
        for index in test_indices:
            self.assertIn(index, X_indices)
        for index in train_indices:
            self.assertIn(index, X_indices)
        # No example may end up in both parts of the split.
        self.assertEqual(len(np.intersect1d(train_indices, test_indices)), 0)
        # Both classes must be present on each side of the split.
        self.assertGreater(len(np.where(labels[train_indices] == 0)[0]), 0)
        self.assertGreater(len(np.where(labels[test_indices] == 0)[0]), 0)
        self.assertGreater(len(np.where(labels[train_indices] == 1)[0]), 0)
        self.assertGreater(len(np.where(labels[test_indices] == 1)[0]), 0)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment