diff --git a/code/bolsonaro/data/dataset_loader.py b/code/bolsonaro/data/dataset_loader.py index ee15a72b9afd33b8d60eabd5974b8cf8b019cf5b..bac38444f24d5ab67158789ed52d475bec0b985e 100644 --- a/code/bolsonaro/data/dataset_loader.py +++ b/code/bolsonaro/data/dataset_loader.py @@ -39,23 +39,20 @@ class DatasetLoader(object): dataset_loading_func = change_binary_func_load(load_breast_cancer) task = Task.BINARYCLASSIFICATION elif name == 'olivetti_faces': - data = fetch_olivetti_faces(random_state=dataset_parameters.random_state, shuffle=True) - task = Task.MULTICLASSIFICATION - X, y = data.data, data.target - elif name == '20newsgroups': - data = fetch_20newsgroups(random_state=dataset_parameters.random_state, shuffle=True) - #X, y = + dataset = fetch_olivetti_faces(random_state=dataset_parameters.random_state, shuffle=True) task = Task.MULTICLASSIFICATION + X, y = dataset.data, dataset.target elif name == '20newsgroups_vectorized': - dataset_loading_func = fetch_20newsgroups_vectorized + dataset = fetch_20newsgroups_vectorized() + X, y = dataset.data, dataset.target task = Task.MULTICLASSIFICATION elif name == 'lfw_people': - data = fetch_lfw_people() - X, y = data.data, data.target + dataset = fetch_lfw_people() + X, y = dataset.data, dataset.target task = Task.MULTICLASSIFICATION elif name == 'lfw_pairs': - data = fetch_lfw_pairs() - X, y = data.data, data.target + dataset = fetch_lfw_pairs() + X, y = dataset.data, dataset.target task = Task.MULTICLASSIFICATION elif name == 'covtype': X, y = fetch_covtype(random_state=dataset_parameters.random_state, shuffle=True, return_X_y=True) @@ -63,9 +60,6 @@ class DatasetLoader(object): elif name == 'rcv1': X, y = fetch_rcv1(random_state=dataset_parameters.random_state, shuffle=True, return_X_y=True) task = Task.MULTICLASSIFICATION - elif name == 'kddcup99': - X, y = fetch_kddcup99(random_state=dataset_parameters.random_state, shuffle=True, return_X_y=True) - task = Task.MULTICLASSIFICATION elif name == 'california_housing': X, y = fetch_california_housing(return_X_y=True) task = Task.REGRESSION