diff --git a/code/bolsonaro/data/dataset_loader.py b/code/bolsonaro/data/dataset_loader.py index 2295093783839221523aaba715bc3ecf78082352..a80d3ca5132427202a56a63ce94659b5c3d713e3 100644 --- a/code/bolsonaro/data/dataset_loader.py +++ b/code/bolsonaro/data/dataset_loader.py @@ -1,13 +1,14 @@ from bolsonaro.data.dataset import Dataset from bolsonaro.data.dataset_parameters import DatasetParameters from bolsonaro.data.task import Task -from bolsonaro.utils import change_binary_func_load +from bolsonaro.utils import change_binary_func_load, change_binary_func_openml from sklearn.datasets import load_boston, load_iris, load_diabetes, \ load_digits, load_linnerud, load_wine, load_breast_cancer from sklearn.datasets import fetch_olivetti_faces, fetch_20newsgroups, \ fetch_20newsgroups_vectorized, fetch_lfw_people, fetch_lfw_pairs, \ - fetch_covtype, fetch_rcv1, fetch_kddcup99, fetch_california_housing + fetch_covtype, fetch_rcv1, fetch_kddcup99, fetch_california_housing, \ + fetch_openml from sklearn.model_selection import train_test_split from sklearn import preprocessing import random @@ -103,6 +104,12 @@ class DatasetLoader(object): df['clarity'] = label_clarity.fit_transform(df['clarity']) X, y = df.drop(['price'], axis=1), df['price'] task = Task.REGRESSION + elif name == 'steel-plates': + dataset_loading_func = change_binary_func_openml('steel-plates-fault') + task = Task.BINARYCLASSIFICATION + elif name == 'kr-vs-kp': + dataset_loading_func = change_binary_func_openml('kr-vs-kp') + task = Task.BINARYCLASSIFICATION else: raise ValueError("Unsupported dataset '{}'".format(name)) diff --git a/code/bolsonaro/models/omp_forest.py b/code/bolsonaro/models/omp_forest.py index b5339f8b471cddbd4a653e42c3b6604757c95ed6..64167dea110759c73dfd1f6cc2989a7fdb2d027d 100644 --- a/code/bolsonaro/models/omp_forest.py +++ b/code/bolsonaro/models/omp_forest.py @@ -33,6 +33,8 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta): # sklearn baseestimator api methods def fit(self, 
def change_binary_func_openml(dataset_name):
    """
    Build a loader function for a binary-classification dataset hosted on OpenML.

    The returned callable fetches ``dataset_name`` through
    ``sklearn.datasets.fetch_openml``, checks that the target holds exactly
    two distinct labels, passes the target through ``binarize_class_data``
    with the last label (in sorted order) as reference class, and casts the
    result to ``int``.

    :param dataset_name: Name of the dataset on OpenML (e.g. 'kr-vs-kp').
    :return: A function ``func_load(return_X_y=True, random_state=None)``
        that returns the tuple ``(X, y)``.
    """
    def func_load(return_X_y=True, random_state=None):
        """
        Fetch the dataset and return ``(X, y)`` with a binarized integer target.

        :param return_X_y: Accepted for signature compatibility only; the
            result is always the ``(X, y)`` tuple because it is unpacked below.
        :param random_state: Accepted for signature compatibility only;
            ``fetch_openml`` has no ``random_state`` argument, so forwarding
            it would raise ``TypeError``.
        :raises ValueError: If the target does not contain exactly two classes.
        """
        # Always request the (X, y) form: a Bunch could not be unpacked here.
        X, y = fetch_openml(dataset_name, return_X_y=True)

        possible_classes = sorted(set(y))
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if len(possible_classes) != 2:
            raise ValueError(
                "Function change_binary_func_openml only works for binary "
                "classification (dataset '{}' has {} classes)".format(
                    dataset_name, len(possible_classes)))

        # binarize_class_data is defined elsewhere in this module; the last
        # label in sorted order is used as the reference class.
        y = binarize_class_data(y, possible_classes[-1])
        y = y.astype('int')
        return X, y
    return func_load