From 33453392567604381631f440e02c5c53b2d88fad Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?L=C3=A9o=20Bouscarrat?= <leo.bouscarrat@euranova.eu>
Date: Thu, 5 Mar 2020 14:05:08 +0100
Subject: [PATCH] Add 3 new datasets: gamma, musk, spambase

---
 code/bolsonaro/data/dataset_loader.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/code/bolsonaro/data/dataset_loader.py b/code/bolsonaro/data/dataset_loader.py
index 7d05f20..0bebeaa 100644
--- a/code/bolsonaro/data/dataset_loader.py
+++ b/code/bolsonaro/data/dataset_loader.py
@@ -32,13 +32,14 @@ class DatasetLoader(object):
     dataset_names = ['boston', 'iris', 'diabetes', 'digits', 'linnerud', 'wine',
         'breast_cancer', 'olivetti_faces', '20newsgroups_vectorized', 'lfw_people',
         'lfw_pairs', 'covtype', 'rcv1', 'california_housing', 'diamonds', 'steel-plates',
-        'kr-vs-kp', 'kin8nm']
+        'kr-vs-kp', 'kin8nm', 'spambase', 'musk', 'gamma']
 
     dataset_seed_numbers = {'boston':15, 'iris':15, 'diabetes':15, 'digits':5,
         'linnerud':15, 'wine':15, 'breast_cancer':15, 'olivetti_faces':15,
         '20newsgroups_vectorized':3, 'lfw_people':3,
         'lfw_pairs':3, 'covtype':3, 'rcv1':3, 'california_housing':3,
-        'diamonds': 15, 'steel-plates':15, 'kr-vs-kp':15, 'kin8nm':15}
+        'diamonds': 15, 'steel-plates': 15, 'kr-vs-kp': 15, 'kin8nm': 15,
+        'spambase': 15, 'musk': 15, 'gamma': 15}
 
     @staticmethod
     def load(dataset_parameters):
@@ -114,6 +115,15 @@ class DatasetLoader(object):
         elif name == 'kin8nm':
             X, y = fetch_openml('kin8nm', return_X_y=True)
             task = Task.REGRESSION
+        elif name == 'spambase':
+            dataset_loading_func = change_binary_func_openml('spambase')
+            task = Task.BINARYCLASSIFICATION
+        elif name == 'musk':
+            dataset_loading_func = change_binary_func_openml('musk')
+            task = Task.BINARYCLASSIFICATION
+        elif name == 'gamma':
+            dataset_loading_func = change_binary_func_openml('MagicTelescope')
+            task = Task.BINARYCLASSIFICATION
         else:
             raise ValueError("Unsupported dataset '{}'".format(name))
 
-- 
GitLab