From 11816dc7533d43bd1b1cf96d6a28483195ca4035 Mon Sep 17 00:00:00 2001 From: bbauvin <baptiste.bauvin@centrale-marseille.fr> Date: Thu, 19 Oct 2017 11:19:40 -0400 Subject: [PATCH] Modified splitting code to keep the same balance on the classes --- Code/MonoMultiViewClassifiers/ExecClassif.py | 3 +-- .../utils/execution.py | 23 +++++++++++------- Code/Tests/Test_utils/test_execution.py | 24 +++++++++++++++++-- 3 files changed, 38 insertions(+), 12 deletions(-) diff --git a/Code/MonoMultiViewClassifiers/ExecClassif.py b/Code/MonoMultiViewClassifiers/ExecClassif.py index 8758ddf7..68ff67b4 100644 --- a/Code/MonoMultiViewClassifiers/ExecClassif.py +++ b/Code/MonoMultiViewClassifiers/ExecClassif.py @@ -296,8 +296,7 @@ def execClassif(arguments): args.CL_classes) datasetLength = DATASET.get("Metadata").attrs["datasetLength"] - indices = np.arange(datasetLength) - classificationIndices = execution.genSplits(statsIter, indices, DATASET, args.CL_split, statsIterRandomStates) + classificationIndices = execution.genSplits(statsIter, datasetLength, DATASET, args.CL_split, statsIterRandomStates) kFolds = execution.genKFolds(statsIter, args.CL_nbFolds, statsIterRandomStates) diff --git a/Code/MonoMultiViewClassifiers/utils/execution.py b/Code/MonoMultiViewClassifiers/utils/execution.py index b3761ee7..d4ebb75d 100644 --- a/Code/MonoMultiViewClassifiers/utils/execution.py +++ b/Code/MonoMultiViewClassifiers/utils/execution.py @@ -197,6 +197,7 @@ def initRandomState(randomStateArg, directory): cPickle.dump(randomState, handle) return randomState + def initLogFile(args): resultDirectory = "../Results/" + args.name + "/started_" + time.strftime("%Y_%m_%d-%H_%M") + "/" logFileName = time.strftime("%Y%m%d-%H%M%S") + "-" + ''.join(args.CL_type) + "-" + "_".join( @@ -224,20 +225,26 @@ def initLogFile(args): return resultDirectory -def genSplits(statsIter, indices, DATASET, splitRatio, statsIterRandomStates): +def genSplits(statsIter, datasetlength, DATASET, splitRatio, statsIterRandomStates): + indices = np.arange(datasetlength) if statsIter > 1: splits = [] for randomState in statsIterRandomStates: - trainIndices, testIndices, a, b = sklearn.model_selection.train_test_split(indices, - DATASET.get("Labels").value, - test_size=splitRatio, - random_state=randomState) + foldsObj = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, random_state=randomState, test_size=splitRatio) + folds = foldsObj.split(indices, DATASET.get("Labels").value) + for fold in folds: + train_fold, test_fold = fold + trainIndices = indices[train_fold] + testIndices = indices[test_fold] splits.append([trainIndices, testIndices]) return splits else: - trainIndices, testIndices, a, b = sklearn.model_selection.train_test_split(indices, DATASET.get("Labels").value, - test_size=splitRatio, - random_state=statsIterRandomStates) + foldsObj = sklearn.model_selection.StratifiedShuffleSplit(n_splits=1, random_state=statsIterRandomStates, test_size=splitRatio) + folds = foldsObj.split(indices, DATASET.get("Labels").value) + for fold in folds: + train_fold, test_fold = fold + trainIndices = indices[train_fold] + testIndices = indices[test_fold] return trainIndices, testIndices diff --git a/Code/Tests/Test_utils/test_execution.py b/Code/Tests/Test_utils/test_execution.py index d88d9e32..3a5d5ef0 100644 --- a/Code/Tests/Test_utils/test_execution.py +++ b/Code/Tests/Test_utils/test_execution.py @@ -4,6 +4,8 @@ import os import h5py import numpy as np +from sklearn.model_selection import StratifiedShuffleSplit + from MonoMultiViewClassifiers.utils import execution @@ -43,5 +45,23 @@ class Test_initLogFile(unittest.TestCase): class Test_genSplits(unittest.TestCase): - def test_genSplits_no_iter(self): - pass \ No newline at end of file + def test_genSplits_no_iter_ratio(self): + X_indices = np.random.randint(0,500,50) + labels = np.zeros(500) + labels[X_indices[:10]] = 1 + foldsObj = StratifiedShuffleSplit(n_splits=1, random_state=42, test_size=0.2) + folds = foldsObj.split(X_indices, labels[X_indices]) + for fold in folds: + train_fold, test_fold = fold + train_indices = X_indices[train_fold] + test_indices = X_indices[test_fold] + self.assertEqual(len(train_indices), 0.8*50) + self.assertEqual(len(test_indices), 0.2*50) + for index in test_indices: + self.assertIn(index, X_indices) + for index in train_indices: + self.assertIn(index, X_indices) + self.assertGreater(len(np.where(labels[train_indices]==0)[0]), 0) + self.assertGreater(len(np.where(labels[test_indices]==0)[0]), 0) + self.assertGreater(len(np.where(labels[train_indices]==1)[0]), 0) + self.assertGreater(len(np.where(labels[test_indices]==1)[0]), 0) -- GitLab