Skip to content
Snippets Groups Projects
Commit 61a8e5a9 authored by Léo Bouscarrat's avatar Léo Bouscarrat
Browse files

Adding dataset

parent 28c3d874
Branches
No related tags found
1 merge request!15Resolve "Adding new datasets"
This commit is part of merge request !15. Comments created here will be created in the context of that merge request.
from bolsonaro.data.dataset import Dataset
from bolsonaro.data.dataset_parameters import DatasetParameters
from bolsonaro.data.task import Task
from bolsonaro.utils import change_binary_func_load
from bolsonaro.utils import change_binary_func_load, change_binary_func_openml
from sklearn.datasets import load_boston, load_iris, load_diabetes, \
load_digits, load_linnerud, load_wine, load_breast_cancer
from sklearn.datasets import fetch_olivetti_faces, fetch_20newsgroups, \
fetch_20newsgroups_vectorized, fetch_lfw_people, fetch_lfw_pairs, \
fetch_covtype, fetch_rcv1, fetch_kddcup99, fetch_california_housing
fetch_covtype, fetch_rcv1, fetch_kddcup99, fetch_california_housing, \
fetch_openml
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
import random
......@@ -103,6 +104,12 @@ class DatasetLoader(object):
df['clarity'] = label_clarity.fit_transform(df['clarity'])
X, y = df.drop(['price'], axis=1), df['price']
task = Task.REGRESSION
elif name == 'steel-plates':
dataset_loading_func = change_binary_func_openml('steel-plates-fault')
task = Task.BINARYCLASSIFICATION
elif name == 'kr-vs-kp':
dataset_loading_func = change_binary_func_openml('kr-vs-kp')
task = Task.BINARYCLASSIFICATION
else:
raise ValueError("Unsupported dataset '{}'".format(name))
......
......@@ -33,6 +33,8 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta):
# sklearn baseestimator api methods
def fit(self, X_forest, y_forest, X_omp, y_omp):
print(y_forest.shape)
print(set([type(y) for y in y_forest]))
self._base_forest_estimator.fit(X_forest, y_forest)
self._extract_subforest(X_omp, y_omp) # type: OrthogonalMatchingPursuit
return self
......
......@@ -5,6 +5,8 @@ from copy import deepcopy
import contextlib
import joblib
from sklearn.datasets import fetch_openml
def resolve_experiment_id(models_dir):
"""
......@@ -78,6 +80,19 @@ def change_binary_func_load(base_load_function):
return X, y
return func_load
def change_binary_func_openml(dataset_name):
def func_load(return_X_y=True, random_state=None):
if random_state:
X, y = fetch_openml(dataset_name, return_X_y=return_X_y, random_state=random_state)
else:
X, y = fetch_openml(dataset_name, return_X_y=return_X_y)
possible_classes = sorted(set(y))
assert len(possible_classes) == 2, "Function change binary_func_load only work for binary classfication"
y = binarize_class_data(y, possible_classes[-1])
y = y.astype('int')
return X, y
return func_load
@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
"""Context manager to patch joblib to report into tqdm progress bar given as argument"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment