Skip to content
Snippets Groups Projects
Commit a826e7cc authored by Luc Giffon's avatar Luc Giffon
Browse files

Update dataset_loader.py

petit problème de nommage
parent 87115964
No related branches found
No related tags found
2 merge requests!3clean scripts,!1Luc new archi
from bolsonaro.data import Dataset
from bolsonaro.data import Task
from sklearn.datasets import load_boston, load_iris, load_diabetes, load_digits, load_linnerud, load_wine, load_breast_cancer
from sklearn.datasets import fetch_olivetti_faces, fetch_20newsgroups, \
fetch_20newsgroups_vectorized, fetch_lfw_people, fetch_lfw_pairs, \
fetch_covtype, fetch_rcv1, fetch_kddcup99, fetch_california_housing
from sklearn.model_selection import train_test_split
class DatasetLoader(object):
@staticmethod
def load_from_name(dataset_parameters):
name = dataset_parameters.name
if name == 'boston':
dataset_loading_func = load_boston
task = Task.REGRESSION
elif name == 'iris':
dataset_loading_func = load_iris
task = Task.CLASSIFICATION
elif name == 'diabetes':
dataset_loading_func = load_diabetes
task = Task.REGRESSION
elif name == 'digits':
dataset_loading_func = load_digits
task = Task.CLASSIFICATION
elif name == 'linnerud':
dataset_loading_func = load_linnerud
task = Task.REGRESSION
elif name == 'wine':
dataset_loading_func = load_wine
task = Task.CLASSIFICATION
elif name == 'breast_cancer':
dataset_loading_func = load_breast_cancer
task = Task.CLASSIFICATION
elif name == 'olivetti_faces':
dataset_loading_func = fetch_olivetti_faces
task = Task.CLASSIFICATION
elif name == '20newsgroups':
dataset_loading_func = fetch_20newsgroups
task = Task.CLASSIFICATION
elif name == '20newsgroups_vectorized':
dataset_loading_func = fetch_20newsgroups_vectorized
task = Task.CLASSIFICATION
elif name == 'lfw_people':
dataset_loading_func = fetch_lfw_people
task = Task.CLASSIFICATION
elif name == 'lfw_pairs':
dataset_loading_func = fetch_lfw_pairs
elif name == 'covtype':
dataset_loading_func = fetch_covtype
task = Task.CLASSIFICATION
elif name == 'rcv1':
dataset_loading_func = fetch_rcv1
task = Task.CLASSIFICATION
elif name == 'kddcup99':
dataset_loading_func = fetch_kddcup99
task = Task.CLASSIFICATION
elif name == 'california_housing':
dataset_loading_func = fetch_california_housing
task = Task.REGRESSION
else:
raise ValueError("Unsupported dataset '{}'".format(name))
X, y = dataset_loading_func(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y,
test_size=dataset_parameters.test_size,
random_state=dataset_parameters.random_state)
X_train, X_dev, y_train, y_dev = train_test_split(X_train, y_train,
test_size=dataset_parameters.dev_size,
random_state=dataset_parameters.random_state)
# TODO
if dataset_parameters.normalize:
pass
return Dataset(task, dataset_parameters, X_train,
X_dev, X_test, y_train, y_dev, y_test)
from bolsonaro.data.dataset import Dataset
from bolsonaro.data.task import Task
from sklearn.datasets import load_boston, load_iris, load_diabetes, load_digits, load_linnerud, load_wine, load_breast_cancer
from sklearn.datasets import fetch_olivetti_faces, fetch_20newsgroups, \
fetch_20newsgroups_vectorized, fetch_lfw_people, fetch_lfw_pairs, \
fetch_covtype, fetch_rcv1, fetch_kddcup99, fetch_california_housing
from sklearn.model_selection import train_test_split
class DatasetLoader(object):
@staticmethod
def load_from_name(dataset_parameters):
name = dataset_parameters.name
if name == 'boston':
dataset_loading_func = load_boston
task = Task.REGRESSION
elif name == 'iris':
dataset_loading_func = load_iris
task = Task.CLASSIFICATION
elif name == 'diabetes':
dataset_loading_func = load_diabetes
task = Task.REGRESSION
elif name == 'digits':
dataset_loading_func = load_digits
task = Task.CLASSIFICATION
elif name == 'linnerud':
dataset_loading_func = load_linnerud
task = Task.REGRESSION
elif name == 'wine':
dataset_loading_func = load_wine
task = Task.CLASSIFICATION
elif name == 'breast_cancer':
dataset_loading_func = load_breast_cancer
task = Task.CLASSIFICATION
elif name == 'olivetti_faces':
dataset_loading_func = fetch_olivetti_faces
task = Task.CLASSIFICATION
elif name == '20newsgroups':
dataset_loading_func = fetch_20newsgroups
task = Task.CLASSIFICATION
elif name == '20newsgroups_vectorized':
dataset_loading_func = fetch_20newsgroups_vectorized
task = Task.CLASSIFICATION
elif name == 'lfw_people':
dataset_loading_func = fetch_lfw_people
task = Task.CLASSIFICATION
elif name == 'lfw_pairs':
dataset_loading_func = fetch_lfw_pairs
elif name == 'covtype':
dataset_loading_func = fetch_covtype
task = Task.CLASSIFICATION
elif name == 'rcv1':
dataset_loading_func = fetch_rcv1
task = Task.CLASSIFICATION
elif name == 'kddcup99':
dataset_loading_func = fetch_kddcup99
task = Task.CLASSIFICATION
elif name == 'california_housing':
dataset_loading_func = fetch_california_housing
task = Task.REGRESSION
else:
raise ValueError("Unsupported dataset '{}'".format(name))
X, y = dataset_loading_func(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y,
test_size=dataset_parameters.test_size,
random_state=dataset_parameters.random_state)
X_train, X_dev, y_train, y_dev = train_test_split(X_train, y_train,
test_size=dataset_parameters.dev_size,
random_state=dataset_parameters.random_state)
# TODO
if dataset_parameters.normalize:
pass
return Dataset(task, dataset_parameters, X_train,
X_dev, X_test, y_train, y_dev, y_test)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment