From 61a8e5a9ef14919cdb2693244bc37f65008d8921 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?L=C3=A9o=20Bouscarrat?= <leo.bouscarrat@euranova.eu>
Date: Fri, 28 Feb 2020 18:28:02 +0100
Subject: [PATCH] Adding dataset

---
 code/bolsonaro/data/dataset_loader.py | 11 +++++++++--
 code/bolsonaro/utils.py               | 23 +++++++++++++++++++++++
 2 files changed, 32 insertions(+), 2 deletions(-)

diff --git a/code/bolsonaro/data/dataset_loader.py b/code/bolsonaro/data/dataset_loader.py
index 2295093..a80d3ca 100644
--- a/code/bolsonaro/data/dataset_loader.py
+++ b/code/bolsonaro/data/dataset_loader.py
@@ -1,13 +1,14 @@
 from bolsonaro.data.dataset import Dataset
 from bolsonaro.data.dataset_parameters import DatasetParameters
 from bolsonaro.data.task import Task
-from bolsonaro.utils import change_binary_func_load
+from bolsonaro.utils import change_binary_func_load, change_binary_func_openml
 
 from sklearn.datasets import load_boston, load_iris, load_diabetes, \
     load_digits, load_linnerud, load_wine, load_breast_cancer
 from sklearn.datasets import fetch_olivetti_faces, fetch_20newsgroups, \
     fetch_20newsgroups_vectorized, fetch_lfw_people, fetch_lfw_pairs, \
-    fetch_covtype, fetch_rcv1, fetch_kddcup99, fetch_california_housing
+    fetch_covtype, fetch_rcv1, fetch_kddcup99, fetch_california_housing, \
+    fetch_openml
 from sklearn.model_selection import train_test_split
 from sklearn import preprocessing
 import random
@@ -103,6 +104,12 @@ class DatasetLoader(object):
             df['clarity'] = label_clarity.fit_transform(df['clarity'])
             X, y = df.drop(['price'], axis=1), df['price']
             task = Task.REGRESSION
+        elif name == 'steel-plates':
+            dataset_loading_func = change_binary_func_openml('steel-plates-fault')
+            task = Task.BINARYCLASSIFICATION
+        elif name == 'kr-vs-kp':
+            dataset_loading_func = change_binary_func_openml('kr-vs-kp')
+            task = Task.BINARYCLASSIFICATION
         else:
             raise ValueError("Unsupported dataset '{}'".format(name))
 
diff --git a/code/bolsonaro/utils.py b/code/bolsonaro/utils.py
index daa695d..f594f92 100644
--- a/code/bolsonaro/utils.py
+++ b/code/bolsonaro/utils.py
@@ -5,6 +5,8 @@ from copy import deepcopy
 import contextlib
 import joblib
 
+from sklearn.datasets import fetch_openml
+
 
 def resolve_experiment_id(models_dir):
     """
@@ -78,6 +80,27 @@ def change_binary_func_load(base_load_function):
         return X, y
     return func_load
 
+def change_binary_func_openml(dataset_name):
+    """
+    Build a loading function for a two-class dataset fetched from OpenML.
+
+    :param dataset_name: Name passed to sklearn.datasets.fetch_openml.
+    :return: A function func_load(return_X_y=True, random_state=None) that
+        returns (X, y), y being binarized and cast to int.
+    """
+    def func_load(return_X_y=True, random_state=None):
+        # fetch_openml has no random_state parameter, so the argument is only
+        # kept in the signature for callers that already pass it. (X, y) is
+        # always requested because the labels are re-encoded below.
+        X, y = fetch_openml(dataset_name, return_X_y=True)
+        possible_classes = sorted(set(y))
+        assert len(possible_classes) == 2, \
+            "Function change_binary_func_openml only works for binary classification"
+        y = binarize_class_data(y, possible_classes[-1])
+        y = y.astype('int')
+        return X, y
+    return func_load
+
 @contextlib.contextmanager
 def tqdm_joblib(tqdm_object):
     """Context manager to patch joblib to report into tqdm progress bar given as argument"""
-- 
GitLab