def _stratifiedSplit(indices, labels, splitRatio, randomState):
    """Return one stratified ``(trainIndices, testIndices)`` partition of ``indices``."""
    splitter = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1,
                                                              random_state=randomState,
                                                              test_size=splitRatio)
    # n_splits=1: the generator yields exactly one (train, test) pair of positions.
    trainPositions, testPositions = next(splitter.split(indices, labels))
    return indices[trainPositions], indices[testPositions]


def genSplits(statsIter, datasetlength, DATASET, splitRatio, statsIterRandomStates):
    """Generate stratified train/test index splits for each statistical iteration.

    Callers pass the number of examples; the index array ``0..datasetlength-1``
    is built here.

    Parameters
    ----------
    statsIter : int
        Number of statistical iterations.
    datasetlength : int
        Number of examples to split.
    DATASET : HDF5-like object
        ``DATASET.get("Labels")`` must give the label dataset, one label per
        example; it is read with ``[()]``.
    splitRatio : float or int
        Forwarded to ``StratifiedShuffleSplit`` as ``test_size``.
    statsIterRandomStates : random state or iterable of random states
        One random state per iteration when ``statsIter > 1``; otherwise a
        single random state used directly.

    Returns
    -------
    list of [trainIndices, testIndices] when ``statsIter > 1``, otherwise a
    single ``(trainIndices, testIndices)`` tuple.
    """
    indices = np.arange(datasetlength)
    # Read the labels once; ``[()]`` replaces ``Dataset.value``, which recent
    # h5py releases no longer provide.
    labels = DATASET.get("Labels")[()]
    if statsIter > 1:
        return [list(_stratifiedSplit(indices, labels, splitRatio, randomState))
                for randomState in statsIterRandomStates]
    return _stratifiedSplit(indices, labels, splitRatio, statsIterRandomStates)
class _LabelsStub(object):
    """Minimal stand-in for a label dataset, readable via ``.value`` or ``[()]``."""

    def __init__(self, labels):
        self.value = labels

    def __getitem__(self, key):
        return self.value[key]


class Test_genSplits(unittest.TestCase):
    """Checks ``execution.genSplits`` when a single statistical iteration is requested."""

    def test_genSplits_no_iter_ratio(self):
        """The single split honours the test ratio and keeps both classes on each side."""
        datasetLength = 50
        labels = np.zeros(datasetLength)
        labels[:10] = 1
        # A plain dict provides the ``.get("Labels")`` lookup used by genSplits.
        dataset = {"Labels": _LabelsStub(labels)}

        # Fixed seed so the test is deterministic from run to run.
        train_indices, test_indices = execution.genSplits(1, datasetLength, dataset, 0.2,
                                                          np.random.RandomState(42))

        self.assertEqual(len(train_indices), 40)
        self.assertEqual(len(test_indices), 10)
        # Train and test together must cover every example exactly once.
        all_indices = np.sort(np.concatenate([train_indices, test_indices]))
        self.assertEqual(all_indices.tolist(), list(range(datasetLength)))
        # Stratification: each side contains examples of both classes.
        for subset in (train_indices, test_indices):
            self.assertGreater(np.count_nonzero(labels[subset] == 0), 0)
            self.assertGreater(np.count_nonzero(labels[subset] == 1), 0)