From d70fad15eb55f9559072a4dcad6bb370e58bfb52 Mon Sep 17 00:00:00 2001
From: Charly Lamothe <charly.lamothe@univ-amu.fr>
Date: Fri, 29 Nov 2019 22:56:05 +0100
Subject: [PATCH] Fix most of the missing dataset fetchers. Remaining:
 20newsgroups, 20newsgroups_vectorized, kddcup99

---
 code/bolsonaro/data/dataset_loader.py | 31 ++++++++++++++++-----------
 code/bolsonaro/utils.py               |  7 ++++--
 2 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/code/bolsonaro/data/dataset_loader.py b/code/bolsonaro/data/dataset_loader.py
index b22ecaa..ee15a72 100644
--- a/code/bolsonaro/data/dataset_loader.py
+++ b/code/bolsonaro/data/dataset_loader.py
@@ -16,6 +16,7 @@ class DatasetLoader(object):
     @staticmethod
     def load(dataset_parameters):
         name = dataset_parameters.name
+        X, y = None, None
         if name == 'boston':
             dataset_loading_func = load_boston
             task = Task.REGRESSION
@@ -37,37 +38,43 @@ class DatasetLoader(object):
         elif name == 'breast_cancer':
             dataset_loading_func = change_binary_func_load(load_breast_cancer)
             task = Task.BINARYCLASSIFICATION
-        elif name == 'olivetti_faces':  # bug (no return X_y)
-            dataset_loading_func = fetch_olivetti_faces
+        elif name == 'olivetti_faces':
+            data = fetch_olivetti_faces(random_state=dataset_parameters.random_state, shuffle=True)
             task = Task.MULTICLASSIFICATION
-        elif name == '20newsgroups':  # bug (no return X_y)
-            dataset_loading_func = fetch_20newsgroups
+            X, y = data.data, data.target
+        elif name == '20newsgroups':
+            data = fetch_20newsgroups(random_state=dataset_parameters.random_state, shuffle=True)
+            #X, y = 
             task = Task.MULTICLASSIFICATION
         elif name == '20newsgroups_vectorized':
             dataset_loading_func = fetch_20newsgroups_vectorized
             task = Task.MULTICLASSIFICATION
-        elif name == 'lfw_people':  # needs PIL (image dataset)
-            dataset_loading_func = fetch_lfw_people
+        elif name == 'lfw_people':
+            data = fetch_lfw_people()
+            X, y = data.data, data.target
             task = Task.MULTICLASSIFICATION
         elif name == 'lfw_pairs':
-            dataset_loading_func = fetch_lfw_pairs
+            data = fetch_lfw_pairs()
+            X, y = data.data, data.target
             task = Task.MULTICLASSIFICATION
         elif name == 'covtype':
-            dataset_loading_func = fetch_covtype
+            X, y = fetch_covtype(random_state=dataset_parameters.random_state, shuffle=True, return_X_y=True)
             task = Task.MULTICLASSIFICATION
         elif name == 'rcv1':
-            dataset_loading_func = fetch_rcv1
+            X, y = fetch_rcv1(random_state=dataset_parameters.random_state, shuffle=True, return_X_y=True)
             task = Task.MULTICLASSIFICATION
         elif name == 'kddcup99':
-            dataset_loading_func = fetch_kddcup99
+            X, y = fetch_kddcup99(random_state=dataset_parameters.random_state, shuffle=True, return_X_y=True)
             task = Task.MULTICLASSIFICATION
         elif name == 'california_housing':
-            dataset_loading_func = fetch_california_housing
+            X, y = fetch_california_housing(return_X_y=True)
             task = Task.REGRESSION
         else:
             raise ValueError("Unsupported dataset '{}'".format(name))
 
-        X, y = dataset_loading_func(return_X_y=True)
+        if X is None:
+            X, y = dataset_loading_func()
+
         X_train, X_test, y_train, y_test = train_test_split(X, y,
             test_size=dataset_parameters.test_size,
             random_state=dataset_parameters.random_state)
diff --git a/code/bolsonaro/utils.py b/code/bolsonaro/utils.py
index d7509ad..10ea769 100644
--- a/code/bolsonaro/utils.py
+++ b/code/bolsonaro/utils.py
@@ -66,8 +66,11 @@ def binarize_class_data(data, class_pos, inplace=True):
     return data
 
 def change_binary_func_load(base_load_function):
-    def func_load(return_X_y):
-        X, y = base_load_function(return_X_y=return_X_y)
+    def func_load(return_X_y, random_state=None):
+        if random_state:
+            X, y = base_load_function(return_X_y=return_X_y, random_state=random_state)
+        else:
+            X, y = base_load_function(return_X_y=return_X_y)
         possible_classes = sorted(set(y))
         assert len(possible_classes) == 2, "Function change binary_func_load only work for binary classfication"
         y = binarize_class_data(y, possible_classes[-1])
-- 
GitLab