diff --git a/code/bolsonaro/data/dataset_loader.py b/code/bolsonaro/data/dataset_loader.py index dc6382c364d9ef4518b877a5035124f09f3e9bfb..6ad4b1f769d35b67b9ebcec6dae6b03ed68607e7 100644 --- a/code/bolsonaro/data/dataset_loader.py +++ b/code/bolsonaro/data/dataset_loader.py @@ -1,79 +1,79 @@ -from bolsonaro.data import Dataset -from bolsonaro.data import Task - -from sklearn.datasets import load_boston, load_iris, load_diabetes, load_digits, load_linnerud, load_wine, load_breast_cancer -from sklearn.datasets import fetch_olivetti_faces, fetch_20newsgroups, \ - fetch_20newsgroups_vectorized, fetch_lfw_people, fetch_lfw_pairs, \ - fetch_covtype, fetch_rcv1, fetch_kddcup99, fetch_california_housing -from sklearn.model_selection import train_test_split - - -class DatasetLoader(object): - - @staticmethod - def load_from_name(dataset_parameters): - name = dataset_parameters.name - if name == 'boston': - dataset_loading_func = load_boston - task = Task.REGRESSION - elif name == 'iris': - dataset_loading_func = load_iris - task = Task.CLASSIFICATION - elif name == 'diabetes': - dataset_loading_func = load_diabetes - task = Task.REGRESSION - elif name == 'digits': - dataset_loading_func = load_digits - task = Task.CLASSIFICATION - elif name == 'linnerud': - dataset_loading_func = load_linnerud - task = Task.REGRESSION - elif name == 'wine': - dataset_loading_func = load_wine - task = Task.CLASSIFICATION - elif name == 'breast_cancer': - dataset_loading_func = load_breast_cancer - task = Task.CLASSIFICATION - elif name == 'olivetti_faces': - dataset_loading_func = fetch_olivetti_faces - task = Task.CLASSIFICATION - elif name == '20newsgroups': - dataset_loading_func = fetch_20newsgroups - task = Task.CLASSIFICATION - elif name == '20newsgroups_vectorized': - dataset_loading_func = fetch_20newsgroups_vectorized - task = Task.CLASSIFICATION - elif name == 'lfw_people': - dataset_loading_func = fetch_lfw_people - task = Task.CLASSIFICATION - elif name == 'lfw_pairs': - dataset_loading_func = fetch_lfw_pairs - elif name == 'covtype': - dataset_loading_func = fetch_covtype - task = Task.CLASSIFICATION - elif name == 'rcv1': - dataset_loading_func = fetch_rcv1 - task = Task.CLASSIFICATION - elif name == 'kddcup99': - dataset_loading_func = fetch_kddcup99 - task = Task.CLASSIFICATION - elif name == 'california_housing': - dataset_loading_func = fetch_california_housing - task = Task.REGRESSION - else: - raise ValueError("Unsupported dataset '{}'".format(name)) - - X, y = dataset_loading_func(return_X_y=True) - X_train, X_test, y_train, y_test = train_test_split(X, y, - test_size=dataset_parameters.test_size, - random_state=dataset_parameters.random_state) - X_train, X_dev, y_train, y_dev = train_test_split(X_train, y_train, - test_size=dataset_parameters.dev_size, - random_state=dataset_parameters.random_state) - - # TODO - if dataset_parameters.normalize: - pass - - return Dataset(task, dataset_parameters, X_train, - X_dev, X_test, y_train, y_dev, y_test) +from bolsonaro.data.dataset import Dataset +from bolsonaro.data.task import Task + +from sklearn.datasets import load_boston, load_iris, load_diabetes, load_digits, load_linnerud, load_wine, load_breast_cancer +from sklearn.datasets import fetch_olivetti_faces, fetch_20newsgroups, \ + fetch_20newsgroups_vectorized, fetch_lfw_people, fetch_lfw_pairs, \ + fetch_covtype, fetch_rcv1, fetch_kddcup99, fetch_california_housing +from sklearn.model_selection import train_test_split + + +class DatasetLoader(object): + + @staticmethod + def load_from_name(dataset_parameters): + name = dataset_parameters.name + if name == 'boston': + dataset_loading_func = load_boston + task = Task.REGRESSION + elif name == 'iris': + dataset_loading_func = load_iris + task = Task.CLASSIFICATION + elif name == 'diabetes': + dataset_loading_func = load_diabetes + task = Task.REGRESSION + elif name == 'digits': + dataset_loading_func = load_digits + task = Task.CLASSIFICATION + elif name == 'linnerud': + dataset_loading_func = load_linnerud + task = Task.REGRESSION + elif name == 'wine': + dataset_loading_func = load_wine + task = Task.CLASSIFICATION + elif name == 'breast_cancer': + dataset_loading_func = load_breast_cancer + task = Task.CLASSIFICATION + elif name == 'olivetti_faces': + dataset_loading_func = fetch_olivetti_faces + task = Task.CLASSIFICATION + elif name == '20newsgroups': + dataset_loading_func = fetch_20newsgroups + task = Task.CLASSIFICATION + elif name == '20newsgroups_vectorized': + dataset_loading_func = fetch_20newsgroups_vectorized + task = Task.CLASSIFICATION + elif name == 'lfw_people': + dataset_loading_func = fetch_lfw_people + task = Task.CLASSIFICATION + elif name == 'lfw_pairs': + dataset_loading_func = fetch_lfw_pairs + elif name == 'covtype': + dataset_loading_func = fetch_covtype + task = Task.CLASSIFICATION + elif name == 'rcv1': + dataset_loading_func = fetch_rcv1 + task = Task.CLASSIFICATION + elif name == 'kddcup99': + dataset_loading_func = fetch_kddcup99 + task = Task.CLASSIFICATION + elif name == 'california_housing': + dataset_loading_func = fetch_california_housing + task = Task.REGRESSION + else: + raise ValueError("Unsupported dataset '{}'".format(name)) + + X, y = dataset_loading_func(return_X_y=True) + X_train, X_test, y_train, y_test = train_test_split(X, y, + test_size=dataset_parameters.test_size, + random_state=dataset_parameters.random_state) + X_train, X_dev, y_train, y_dev = train_test_split(X_train, y_train, + test_size=dataset_parameters.dev_size, + random_state=dataset_parameters.random_state) + + # TODO + if dataset_parameters.normalize: + pass + + return Dataset(task, dataset_parameters, X_train, + X_dev, X_test, y_train, y_dev, y_test)