Skip to content
Snippets Groups Projects
Commit 61a8e5a9 authored by Léo Bouscarrat's avatar Léo Bouscarrat
Browse files

Adding dataset

parent 28c3d874
Branches
No related tags found
1 merge request!15Resolve "Adding new datasets"
from bolsonaro.data.dataset import Dataset from bolsonaro.data.dataset import Dataset
from bolsonaro.data.dataset_parameters import DatasetParameters from bolsonaro.data.dataset_parameters import DatasetParameters
from bolsonaro.data.task import Task from bolsonaro.data.task import Task
from bolsonaro.utils import change_binary_func_load from bolsonaro.utils import change_binary_func_load, change_binary_func_openml
from sklearn.datasets import load_boston, load_iris, load_diabetes, \ from sklearn.datasets import load_boston, load_iris, load_diabetes, \
load_digits, load_linnerud, load_wine, load_breast_cancer load_digits, load_linnerud, load_wine, load_breast_cancer
from sklearn.datasets import fetch_olivetti_faces, fetch_20newsgroups, \ from sklearn.datasets import fetch_olivetti_faces, fetch_20newsgroups, \
fetch_20newsgroups_vectorized, fetch_lfw_people, fetch_lfw_pairs, \ fetch_20newsgroups_vectorized, fetch_lfw_people, fetch_lfw_pairs, \
fetch_covtype, fetch_rcv1, fetch_kddcup99, fetch_california_housing fetch_covtype, fetch_rcv1, fetch_kddcup99, fetch_california_housing, \
fetch_openml
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from sklearn import preprocessing from sklearn import preprocessing
import random import random
...@@ -103,6 +104,12 @@ class DatasetLoader(object): ...@@ -103,6 +104,12 @@ class DatasetLoader(object):
df['clarity'] = label_clarity.fit_transform(df['clarity']) df['clarity'] = label_clarity.fit_transform(df['clarity'])
X, y = df.drop(['price'], axis=1), df['price'] X, y = df.drop(['price'], axis=1), df['price']
task = Task.REGRESSION task = Task.REGRESSION
elif name == 'steel-plates':
dataset_loading_func = change_binary_func_openml('steel-plates-fault')
task = Task.BINARYCLASSIFICATION
elif name == 'kr-vs-kp':
dataset_loading_func = change_binary_func_openml('kr-vs-kp')
task = Task.BINARYCLASSIFICATION
else: else:
raise ValueError("Unsupported dataset '{}'".format(name)) raise ValueError("Unsupported dataset '{}'".format(name))
......
...@@ -33,6 +33,8 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta): ...@@ -33,6 +33,8 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta):
# sklearn baseestimator api methods # sklearn baseestimator api methods
def fit(self, X_forest, y_forest, X_omp, y_omp): def fit(self, X_forest, y_forest, X_omp, y_omp):
print(y_forest.shape)
print(set([type(y) for y in y_forest]))
self._base_forest_estimator.fit(X_forest, y_forest) self._base_forest_estimator.fit(X_forest, y_forest)
self._extract_subforest(X_omp, y_omp) # type: OrthogonalMatchingPursuit self._extract_subforest(X_omp, y_omp) # type: OrthogonalMatchingPursuit
return self return self
......
...@@ -5,6 +5,8 @@ from copy import deepcopy ...@@ -5,6 +5,8 @@ from copy import deepcopy
import contextlib import contextlib
import joblib import joblib
from sklearn.datasets import fetch_openml
def resolve_experiment_id(models_dir): def resolve_experiment_id(models_dir):
""" """
...@@ -78,6 +80,19 @@ def change_binary_func_load(base_load_function): ...@@ -78,6 +80,19 @@ def change_binary_func_load(base_load_function):
return X, y return X, y
return func_load return func_load
def change_binary_func_openml(dataset_name):
def func_load(return_X_y=True, random_state=None):
if random_state:
X, y = fetch_openml(dataset_name, return_X_y=return_X_y, random_state=random_state)
else:
X, y = fetch_openml(dataset_name, return_X_y=return_X_y)
possible_classes = sorted(set(y))
assert len(possible_classes) == 2, "Function change binary_func_load only work for binary classfication"
y = binarize_class_data(y, possible_classes[-1])
y = y.astype('int')
return X, y
return func_load
@contextlib.contextmanager @contextlib.contextmanager
def tqdm_joblib(tqdm_object): def tqdm_joblib(tqdm_object):
"""Context manager to patch joblib to report into tqdm progress bar given as argument""" """Context manager to patch joblib to report into tqdm progress bar given as argument"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment