From 967742a64d98e4ece17316efc229253025e22907 Mon Sep 17 00:00:00 2001 From: Charly Lamothe <charly.lamothe@univ-amu.fr> Date: Sun, 1 Dec 2019 11:08:46 +0100 Subject: [PATCH] Finish to fix sklearn dataset fetchers. Keep 20newsgroups_vectorized instead of non vectorized version that is the preprocessed version of this text dataset. Remove kddcup99 dataset. --- code/bolsonaro/data/dataset_loader.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/code/bolsonaro/data/dataset_loader.py b/code/bolsonaro/data/dataset_loader.py index ee15a72..bac3844 100644 --- a/code/bolsonaro/data/dataset_loader.py +++ b/code/bolsonaro/data/dataset_loader.py @@ -39,23 +39,20 @@ class DatasetLoader(object): dataset_loading_func = change_binary_func_load(load_breast_cancer) task = Task.BINARYCLASSIFICATION elif name == 'olivetti_faces': - data = fetch_olivetti_faces(random_state=dataset_parameters.random_state, shuffle=True) - task = Task.MULTICLASSIFICATION - X, y = data.data, data.target - elif name == '20newsgroups': - data = fetch_20newsgroups(random_state=dataset_parameters.random_state, shuffle=True) - #X, y = + dataset = fetch_olivetti_faces(random_state=dataset_parameters.random_state, shuffle=True) task = Task.MULTICLASSIFICATION + X, y = dataset.data, dataset.target elif name == '20newsgroups_vectorized': - dataset_loading_func = fetch_20newsgroups_vectorized + dataset = fetch_20newsgroups_vectorized() + X, y = dataset.data, dataset.target task = Task.MULTICLASSIFICATION elif name == 'lfw_people': - data = fetch_lfw_people() - X, y = data.data, data.target + dataset = fetch_lfw_people() + X, y = dataset.data, dataset.target task = Task.MULTICLASSIFICATION elif name == 'lfw_pairs': - data = fetch_lfw_pairs() - X, y = data.data, data.target + dataset = fetch_lfw_pairs() + X, y = dataset.data, dataset.target task = Task.MULTICLASSIFICATION elif name == 'covtype': X, y = fetch_covtype(random_state=dataset_parameters.random_state, shuffle=True, return_X_y=True) @@ -63,9 +60,6 @@ class DatasetLoader(object): elif name == 'rcv1': X, y = fetch_rcv1(random_state=dataset_parameters.random_state, shuffle=True, return_X_y=True) task = Task.MULTICLASSIFICATION - elif name == 'kddcup99': - X, y = fetch_kddcup99(random_state=dataset_parameters.random_state, shuffle=True, return_X_y=True) - task = Task.MULTICLASSIFICATION elif name == 'california_housing': X, y = fetch_california_housing(return_X_y=True) task = Task.REGRESSION -- GitLab