diff --git a/code/bolsonaro/data/dataset_loader.py b/code/bolsonaro/data/dataset_loader.py index b22ecaa83f6e69610c82796e068e91db23e83646..ee15a72b9afd33b8d60eabd5974b8cf8b019cf5b 100644 --- a/code/bolsonaro/data/dataset_loader.py +++ b/code/bolsonaro/data/dataset_loader.py @@ -16,6 +16,7 @@ class DatasetLoader(object): @staticmethod def load(dataset_parameters): name = dataset_parameters.name + X, y = None, None if name == 'boston': dataset_loading_func = load_boston task = Task.REGRESSION @@ -37,37 +38,43 @@ class DatasetLoader(object): elif name == 'breast_cancer': dataset_loading_func = change_binary_func_load(load_breast_cancer) task = Task.BINARYCLASSIFICATION - elif name == 'olivetti_faces': # bug (no return X_y) - dataset_loading_func = fetch_olivetti_faces + elif name == 'olivetti_faces': + data = fetch_olivetti_faces(random_state=dataset_parameters.random_state, shuffle=True) task = Task.MULTICLASSIFICATION - elif name == '20newsgroups': # bug (no return X_y) - dataset_loading_func = fetch_20newsgroups + X, y = data.data, data.target + elif name == '20newsgroups': + data = fetch_20newsgroups(random_state=dataset_parameters.random_state, shuffle=True) + #X, y = task = Task.MULTICLASSIFICATION elif name == '20newsgroups_vectorized': dataset_loading_func = fetch_20newsgroups_vectorized task = Task.MULTICLASSIFICATION - elif name == 'lfw_people': # needs PIL (image dataset) - dataset_loading_func = fetch_lfw_people + elif name == 'lfw_people': + data = fetch_lfw_people() + X, y = data.data, data.target task = Task.MULTICLASSIFICATION elif name == 'lfw_pairs': - dataset_loading_func = fetch_lfw_pairs + data = fetch_lfw_pairs() + X, y = data.data, data.target task = Task.MULTICLASSIFICATION elif name == 'covtype': - dataset_loading_func = fetch_covtype + X, y = fetch_covtype(random_state=dataset_parameters.random_state, shuffle=True, return_X_y=True) task = Task.MULTICLASSIFICATION elif name == 'rcv1': - dataset_loading_func = fetch_rcv1 + X, y = fetch_rcv1(random_state=dataset_parameters.random_state, shuffle=True, return_X_y=True) task = Task.MULTICLASSIFICATION elif name == 'kddcup99': - dataset_loading_func = fetch_kddcup99 + X, y = fetch_kddcup99(random_state=dataset_parameters.random_state, shuffle=True, return_X_y=True) task = Task.MULTICLASSIFICATION elif name == 'california_housing': - dataset_loading_func = fetch_california_housing + X, y = fetch_california_housing(return_X_y=True) task = Task.REGRESSION else: raise ValueError("Unsupported dataset '{}'".format(name)) - X, y = dataset_loading_func(return_X_y=True) + if X is None: + X, y = dataset_loading_func() + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=dataset_parameters.test_size, random_state=dataset_parameters.random_state) diff --git a/code/bolsonaro/utils.py b/code/bolsonaro/utils.py index d7509ad9e85cde3cc0c649f85cfb5b60ead9a854..10ea76921ffacdd814044fc8179eb83717429330 100644 --- a/code/bolsonaro/utils.py +++ b/code/bolsonaro/utils.py @@ -66,8 +66,11 @@ def binarize_class_data(data, class_pos, inplace=True): return data def change_binary_func_load(base_load_function): - def func_load(return_X_y): - X, y = base_load_function(return_X_y=return_X_y) + def func_load(return_X_y, random_state=None): + if random_state: + X, y = base_load_function(return_X_y=return_X_y, random_state=random_state) + else: + X, y = base_load_function(return_X_y=return_X_y) possible_classes = sorted(set(y)) assert len(possible_classes) == 2, "Function change binary_func_load only work for binary classfication" y = binarize_class_data(y, possible_classes[-1])