diff --git a/code/bolsonaro/data/dataset_loader.py b/code/bolsonaro/data/dataset_loader.py index 892104711701cd430962a3832566da6e62b83a38..2295093783839221523aaba715bc3ecf78082352 100644 --- a/code/bolsonaro/data/dataset_loader.py +++ b/code/bolsonaro/data/dataset_loader.py @@ -11,6 +11,7 @@ from sklearn.datasets import fetch_olivetti_faces, fetch_20newsgroups, \ from sklearn.model_selection import train_test_split from sklearn import preprocessing import random +import pandas as pd class DatasetLoader(object): @@ -29,12 +30,13 @@ class DatasetLoader(object): dataset_names = ['boston', 'iris', 'diabetes', 'digits', 'linnerud', 'wine', 'breast_cancer', 'olivetti_faces', '20newsgroups_vectorized', 'lfw_people', - 'lfw_pairs', 'covtype', 'rcv1', 'california_housing'] + 'lfw_pairs', 'covtype', 'rcv1', 'california_housing', 'diamonds'] dataset_seed_numbers = {'boston':15, 'iris':15, 'diabetes':15, 'digits':5, 'linnerud':15, 'wine':15, 'breast_cancer':15, 'olivetti_faces':15, '20newsgroups_vectorized':3, 'lfw_people':3, - 'lfw_pairs':3, 'covtype':3, 'rcv1':3, 'california_housing':3} + 'lfw_pairs':3, 'covtype':3, 'rcv1':3, 'california_housing':3, + 'diamonds': 15} @staticmethod def load(dataset_parameters): @@ -86,6 +88,21 @@ class DatasetLoader(object): elif name == 'california_housing': X, y = fetch_california_housing(return_X_y=True) task = Task.REGRESSION + elif name == 'diamonds': + # TODO: make a proper fetcher instead of the following code + from sklearn.preprocessing import LabelEncoder + df = pd.read_csv('data/diamonds.csv') + df.drop(['Unnamed: 0'], axis=1 , inplace=True) + df = df[(df[['x','y','z']] != 0).all(axis=1)] + df.drop(['x','y','z'], axis=1, inplace= True) + label_cut = LabelEncoder() + label_color = LabelEncoder() + label_clarity = LabelEncoder() + df['cut'] = label_cut.fit_transform(df['cut']) + df['color'] = label_color.fit_transform(df['color']) + df['clarity'] = label_clarity.fit_transform(df['clarity']) + X, y = df.drop(['price'], axis=1), df['price'] + task = Task.REGRESSION else: raise ValueError("Unsupported dataset '{}'".format(name))