Skip to content
Snippets Groups Projects
Commit dcbafbb0 authored by Charly Lamothe
Browse files

Add quick working version of diamonds dataset to test that the improvement of...

Add a quick working version of the diamonds dataset to test that the improvements of OMP are consistent in other regression tasks.
parent 1adfbf0a
No related branches found
No related tags found
1 merge request: !9 Resolve "Experiment pipeline"
...@@ -11,6 +11,7 @@ from sklearn.datasets import fetch_olivetti_faces, fetch_20newsgroups, \ ...@@ -11,6 +11,7 @@ from sklearn.datasets import fetch_olivetti_faces, fetch_20newsgroups, \
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from sklearn import preprocessing from sklearn import preprocessing
import random import random
import pandas as pd
class DatasetLoader(object): class DatasetLoader(object):
...@@ -29,12 +30,13 @@ class DatasetLoader(object): ...@@ -29,12 +30,13 @@ class DatasetLoader(object):
dataset_names = ['boston', 'iris', 'diabetes', 'digits', 'linnerud', 'wine', dataset_names = ['boston', 'iris', 'diabetes', 'digits', 'linnerud', 'wine',
'breast_cancer', 'olivetti_faces', '20newsgroups_vectorized', 'lfw_people', 'breast_cancer', 'olivetti_faces', '20newsgroups_vectorized', 'lfw_people',
'lfw_pairs', 'covtype', 'rcv1', 'california_housing'] 'lfw_pairs', 'covtype', 'rcv1', 'california_housing', 'diamonds']
dataset_seed_numbers = {'boston':15, 'iris':15, 'diabetes':15, 'digits':5, dataset_seed_numbers = {'boston':15, 'iris':15, 'diabetes':15, 'digits':5,
'linnerud':15, 'wine':15, 'breast_cancer':15, 'olivetti_faces':15, 'linnerud':15, 'wine':15, 'breast_cancer':15, 'olivetti_faces':15,
'20newsgroups_vectorized':3, 'lfw_people':3, '20newsgroups_vectorized':3, 'lfw_people':3,
'lfw_pairs':3, 'covtype':3, 'rcv1':3, 'california_housing':3} 'lfw_pairs':3, 'covtype':3, 'rcv1':3, 'california_housing':3,
'diamonds': 15}
@staticmethod @staticmethod
def load(dataset_parameters): def load(dataset_parameters):
...@@ -86,6 +88,21 @@ class DatasetLoader(object): ...@@ -86,6 +88,21 @@ class DatasetLoader(object):
elif name == 'california_housing': elif name == 'california_housing':
X, y = fetch_california_housing(return_X_y=True) X, y = fetch_california_housing(return_X_y=True)
task = Task.REGRESSION task = Task.REGRESSION
elif name == 'diamonds':
# TODO: make a proper fetcher instead of the following code
from sklearn.preprocessing import LabelEncoder
df = pd.read_csv('data/diamonds.csv')
df.drop(['Unnamed: 0'], axis=1 , inplace=True)
df = df[(df[['x','y','z']] != 0).all(axis=1)]
df.drop(['x','y','z'], axis=1, inplace= True)
label_cut = LabelEncoder()
label_color = LabelEncoder()
label_clarity = LabelEncoder()
df['cut'] = label_cut.fit_transform(df['cut'])
df['color'] = label_color.fit_transform(df['color'])
df['clarity'] = label_clarity.fit_transform(df['clarity'])
X, y = df.drop(['price'], axis=1), df['price']
task = Task.REGRESSION
else: else:
raise ValueError("Unsupported dataset '{}'".format(name)) raise ValueError("Unsupported dataset '{}'".format(name))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment