diff --git a/code/bolsonaro/data/dataset_loader.py b/code/bolsonaro/data/dataset_loader.py index 7d05f200728c275b5781e61fd85516d3361c46ef..0bebeaaf6c9f0dac2258a384a8b2bcff50c4c5b3 100644 --- a/code/bolsonaro/data/dataset_loader.py +++ b/code/bolsonaro/data/dataset_loader.py @@ -32,13 +32,14 @@ class DatasetLoader(object): dataset_names = ['boston', 'iris', 'diabetes', 'digits', 'linnerud', 'wine', 'breast_cancer', 'olivetti_faces', '20newsgroups_vectorized', 'lfw_people', 'lfw_pairs', 'covtype', 'rcv1', 'california_housing', 'diamonds', 'steel-plates', - 'kr-vs-kp', 'kin8nm'] + 'kr-vs-kp', 'kin8nm', 'spambase', 'musk', 'gamma'] dataset_seed_numbers = {'boston':15, 'iris':15, 'diabetes':15, 'digits':5, 'linnerud':15, 'wine':15, 'breast_cancer':15, 'olivetti_faces':15, '20newsgroups_vectorized':3, 'lfw_people':3, 'lfw_pairs':3, 'covtype':3, 'rcv1':3, 'california_housing':3, - 'diamonds': 15, 'steel-plates':15, 'kr-vs-kp':15, 'kin8nm':15} + 'diamonds': 15, 'steel-plates': 15, 'kr-vs-kp': 15, 'kin8nm': 15, + 'spambase': 15, 'musk': 15, 'gamma': 15} @staticmethod def load(dataset_parameters): @@ -114,6 +115,15 @@ class DatasetLoader(object): elif name == 'kin8nm': X, y = fetch_openml('kin8nm', return_X_y=True) task = Task.REGRESSION + elif name == 'spambase': + dataset_loading_func = change_binary_func_openml('spambase') + task = Task.BINARYCLASSIFICATION + elif name == 'musk': + dataset_loading_func = change_binary_func_openml('musk') + task = Task.BINARYCLASSIFICATION + elif name == 'gamma': + dataset_loading_func = change_binary_func_openml('MagicTelescope') + task = Task.BINARYCLASSIFICATION else: raise ValueError("Unsupported dataset '{}'".format(name))