Skip to content
Snippets Groups Projects
Commit 11816dc7 authored by bbauvin
Browse files

Modified splitting code to keep the same balance on the classes

parent 5f9b5937
No related branches found
No related tags found
No related merge requests found
...@@ -296,8 +296,7 @@ def execClassif(arguments): ...@@ -296,8 +296,7 @@ def execClassif(arguments):
args.CL_classes) args.CL_classes)
datasetLength = DATASET.get("Metadata").attrs["datasetLength"] datasetLength = DATASET.get("Metadata").attrs["datasetLength"]
indices = np.arange(datasetLength) classificationIndices = execution.genSplits(statsIter, datasetLength, DATASET, args.CL_split, statsIterRandomStates)
classificationIndices = execution.genSplits(statsIter, indices, DATASET, args.CL_split, statsIterRandomStates)
kFolds = execution.genKFolds(statsIter, args.CL_nbFolds, statsIterRandomStates) kFolds = execution.genKFolds(statsIter, args.CL_nbFolds, statsIterRandomStates)
......
...@@ -197,6 +197,7 @@ def initRandomState(randomStateArg, directory): ...@@ -197,6 +197,7 @@ def initRandomState(randomStateArg, directory):
cPickle.dump(randomState, handle) cPickle.dump(randomState, handle)
return randomState return randomState
def initLogFile(args): def initLogFile(args):
resultDirectory = "../Results/" + args.name + "/started_" + time.strftime("%Y_%m_%d-%H_%M") + "/" resultDirectory = "../Results/" + args.name + "/started_" + time.strftime("%Y_%m_%d-%H_%M") + "/"
logFileName = time.strftime("%Y%m%d-%H%M%S") + "-" + ''.join(args.CL_type) + "-" + "_".join( logFileName = time.strftime("%Y%m%d-%H%M%S") + "-" + ''.join(args.CL_type) + "-" + "_".join(
...@@ -224,20 +225,26 @@ def initLogFile(args): ...@@ -224,20 +225,26 @@ def initLogFile(args):
return resultDirectory return resultDirectory
def _stratifiedSplit(indices, labels, splitRatio, randomState):
    """Return ``(trainIndices, testIndices)`` from one stratified shuffle split."""
    splitter = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1,
                                                              random_state=randomState,
                                                              test_size=splitRatio)
    # n_splits=1, so the generator yields exactly one (train, test) pair of
    # positions into ``indices``.
    trainFold, testFold = next(splitter.split(indices, labels))
    return indices[trainFold], indices[testFold]


def genSplits(statsIter, datasetlength, DATASET, splitRatio, statsIterRandomStates):
    """Generate stratified train/test index splits over the whole dataset.

    The split is stratified on the ``"Labels"`` dataset of ``DATASET`` so the
    train and test parts keep the class balance of the full dataset.

    Parameters
    ----------
    statsIter : int
        Number of statistical iterations. When greater than 1, one split is
        produced per entry of ``statsIterRandomStates``.
    datasetlength : int
        Number of examples; indices ``0 .. datasetlength - 1`` are split.
    DATASET : mapping-like (presumably an open ``h5py.File`` -- confirm)
        Must expose ``get("Labels")`` returning the label dataset.
    splitRatio : float or int
        Forwarded as ``test_size`` to ``StratifiedShuffleSplit``.
    statsIterRandomStates : iterable of random states, or a single random state
        Iterated over when ``statsIter > 1``; otherwise passed directly as
        ``random_state``.

    Returns
    -------
    list of ``[trainIndices, testIndices]`` when ``statsIter > 1``, otherwise
    a single ``(trainIndices, testIndices)`` tuple.
    """
    indices = np.arange(datasetlength)
    # ``Dataset.value`` was deprecated and then removed from h5py; ``[()]`` is
    # the equivalent full read. The labels are read once for all iterations.
    labels = DATASET.get("Labels")[()]

    if statsIter > 1:
        return [list(_stratifiedSplit(indices, labels, splitRatio, randomState))
                for randomState in statsIterRandomStates]
    return _stratifiedSplit(indices, labels, splitRatio, statsIterRandomStates)
......
...@@ -4,6 +4,8 @@ import os ...@@ -4,6 +4,8 @@ import os
import h5py import h5py
import numpy as np import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
from MonoMultiViewClassifiers.utils import execution from MonoMultiViewClassifiers.utils import execution
...@@ -43,5 +45,23 @@ class Test_initLogFile(unittest.TestCase): ...@@ -43,5 +45,23 @@ class Test_initLogFile(unittest.TestCase):
class Test_genSplits(unittest.TestCase):
    """Checks on the stratified train/test splitting strategy."""

    def test_genSplits_no_iter_ratio(self):
        """An 80/20 stratified split keeps sizes, membership and both classes."""
        # NOTE(review): this exercises StratifiedShuffleSplit directly and never
        # calls execution.genSplits -- consider feeding genSplits a small HDF5
        # fixture so the project code itself is covered.
        nbExamples = 50
        # Seeded generator and sampling without replacement: the fixture is
        # reproducible and the sampled indices are guaranteed distinct.
        rng = np.random.RandomState(42)
        X_indices = rng.choice(500, nbExamples, replace=False)
        labels = np.zeros(500)
        labels[X_indices[:10]] = 1

        foldsObj = StratifiedShuffleSplit(n_splits=1, random_state=42, test_size=0.2)
        # n_splits=1, so the generator yields a single (train, test) pair.
        train_fold, test_fold = next(foldsObj.split(X_indices, labels[X_indices]))
        train_indices = X_indices[train_fold]
        test_indices = X_indices[test_fold]

        # Compare against integer sizes rather than float expressions.
        self.assertEqual(len(train_indices), 40)
        self.assertEqual(len(test_indices), 10)

        for index in test_indices:
            self.assertIn(index, X_indices)
        for index in train_indices:
            self.assertIn(index, X_indices)

        # Both classes must be represented on each side of the split.
        for subset in (train_indices, test_indices):
            self.assertGreater(np.count_nonzero(labels[subset] == 0), 0)
            self.assertGreater(np.count_nonzero(labels[subset] == 1), 0)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment