From 61a8e5a9ef14919cdb2693244bc37f65008d8921 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?L=C3=A9o=20Bouscarrat?= <leo.bouscarrat@euranova.eu>
Date: Fri, 28 Feb 2020 18:28:02 +0100
Subject: [PATCH] Adding dataset

---
 code/bolsonaro/data/dataset_loader.py | 11 +++++++++--
 code/bolsonaro/models/omp_forest.py   |  2 ++
 code/bolsonaro/utils.py               | 15 +++++++++++++++
 3 files changed, 26 insertions(+), 2 deletions(-)

diff --git a/code/bolsonaro/data/dataset_loader.py b/code/bolsonaro/data/dataset_loader.py
index 2295093..a80d3ca 100644
--- a/code/bolsonaro/data/dataset_loader.py
+++ b/code/bolsonaro/data/dataset_loader.py
@@ -1,13 +1,14 @@
 from bolsonaro.data.dataset import Dataset
 from bolsonaro.data.dataset_parameters import DatasetParameters
 from bolsonaro.data.task import Task
-from bolsonaro.utils import change_binary_func_load
+from bolsonaro.utils import change_binary_func_load, change_binary_func_openml
 
 from sklearn.datasets import load_boston, load_iris, load_diabetes, \
     load_digits, load_linnerud, load_wine, load_breast_cancer
 from sklearn.datasets import fetch_olivetti_faces, fetch_20newsgroups, \
     fetch_20newsgroups_vectorized, fetch_lfw_people, fetch_lfw_pairs, \
-    fetch_covtype, fetch_rcv1, fetch_kddcup99, fetch_california_housing
+    fetch_covtype, fetch_rcv1, fetch_kddcup99, fetch_california_housing, \
+    fetch_openml
 from sklearn.model_selection import train_test_split
 from sklearn import preprocessing
 import random
@@ -103,6 +104,12 @@ class DatasetLoader(object):
             df['clarity'] = label_clarity.fit_transform(df['clarity'])
             X, y = df.drop(['price'], axis=1), df['price']
             task = Task.REGRESSION
+        elif name == 'steel-plates':
+            dataset_loading_func = change_binary_func_openml('steel-plates-fault')
+            task = Task.BINARYCLASSIFICATION
+        elif name == 'kr-vs-kp':
+            dataset_loading_func = change_binary_func_openml('kr-vs-kp')
+            task = Task.BINARYCLASSIFICATION
         else:
             raise ValueError("Unsupported dataset '{}'".format(name))
 
diff --git a/code/bolsonaro/models/omp_forest.py b/code/bolsonaro/models/omp_forest.py
index b5339f8..64167de 100644
--- a/code/bolsonaro/models/omp_forest.py
+++ b/code/bolsonaro/models/omp_forest.py
@@ -33,6 +33,8 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta):
 
     # sklearn baseestimator api methods
     def fit(self, X_forest, y_forest, X_omp, y_omp):
+        print(y_forest.shape)
+        print(set([type(y) for y in y_forest]))
         self._base_forest_estimator.fit(X_forest, y_forest)
         self._extract_subforest(X_omp, y_omp) # type: OrthogonalMatchingPursuit
         return self
diff --git a/code/bolsonaro/utils.py b/code/bolsonaro/utils.py
index daa695d..f594f92 100644
--- a/code/bolsonaro/utils.py
+++ b/code/bolsonaro/utils.py
@@ -5,6 +5,8 @@ from copy import deepcopy
 import contextlib
 import joblib
 
+from sklearn.datasets import fetch_openml
+
 
 def resolve_experiment_id(models_dir):
     """
@@ -78,6 +80,19 @@ def change_binary_func_load(base_load_function):
         return X, y
     return func_load
 
+def change_binary_func_openml(dataset_name):
+    def func_load(return_X_y=True, random_state=None):
+        if random_state:
+            X, y = fetch_openml(dataset_name, return_X_y=return_X_y, random_state=random_state)
+        else:
+            X, y = fetch_openml(dataset_name, return_X_y=return_X_y)
+        possible_classes = sorted(set(y))
+        assert len(possible_classes) == 2, "Function change binary_func_load only work for binary classfication"
+        y = binarize_class_data(y, possible_classes[-1])
+        y = y.astype('int')
+        return X, y
+    return func_load
+
 @contextlib.contextmanager
 def tqdm_joblib(tqdm_object):
     """Context manager to patch joblib to report into tqdm progress bar given as argument"""
-- 
GitLab