diff --git a/skluc/main/data/mldatasets/Caltech101Dataset.py b/skluc/main/data/mldatasets/Caltech101Dataset.py
index 6bb7ff87b9bafb42098a94c4017b14fa11b10903..a3363a8155637998dd18aa9ccf9a93b760b2ed96 100644
--- a/skluc/main/data/mldatasets/Caltech101Dataset.py
+++ b/skluc/main/data/mldatasets/Caltech101Dataset.py
@@ -1,56 +1,85 @@
 import os
 import pickle
 import tarfile
-
 import numpy as np
-
+import imageio
 from skluc.main.data.mldatasets.ImageDataset import ImageDataset
-from skluc.main.utils import LabeledData
+from skluc.main.utils import LabeledData, create_directory
 from skluc.main.utils import logger, check_files
+from scipy.misc import imresize
+
+import logging
+logging.getLogger("PIL.Image").setLevel(logging.WARNING)
 
 
-class Cifar100FineDataset(ImageDataset):
-    HEIGHT = 32
-    WIDTH = 32
-    DEPTH = 3
+class Caltech101Dataset(ImageDataset):
+    HEIGHT = None
+    WIDTH = None
+    DEPTH = None
 
-    def __init__(self, validation_size=0, seed=None, s_download_dir=None):
-        self.__s_url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
+    MAXIMUM_NBR_IMG_BY_CAT = 31
+
+    def __init__(self, images_shape=(32, 32, 3), test_size=1000, nb_img_by_cat=30, validation_size=0, seed=None, s_download_dir=None):
+        raise NotImplementedError("Caltech Dataset is not implemented yet")
+        self.__s_url = "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz"
         self.meta = None
-        name = "cifar100fine"
+        name = "caltech101" + "_" + ("x".join([str(x) for x in images_shape]))
+
+        self.HEIGHT = images_shape[0]
+        self.WIDTH = images_shape[1]
+        self.DEPTH = images_shape[2]
+
+        self.__nb_img_by_cat = nb_img_by_cat
+        self.__test_size = test_size
+
         if s_download_dir is not None:
             super().__init__([self.__s_url], name, s_download_dir, validation_size=validation_size, seed=seed)
         else:
             super().__init__([self.__s_url], name, validation_size=validation_size, seed=seed)
 
-        self.__extracted_dirname = os.path.join(self.s_download_dir, "cifar-100-python")
-        self.__extracted_files = [
-            'train',
-            'test',
-            'meta'
-        ]
+        self.__extracted_dirname = os.path.join(self.s_download_dir, "101_ObjectCategories")
 
-        self.__extracted_file_paths = [os.path.join(self.__extracted_dirname, file) for file in self.__extracted_files]
+    def get_caltech101_data(self):
+        full_images, labels = self.get_full_caltech101_data()
+        np.random.seed(self.seed)
+        full_idx = np.random.permutation(len(full_images))
+        # train keeps everything except the last test_size shuffled images
+        idx_train = full_idx[:-self.__test_size]
+        idx_test = full_idx[-self.__test_size:]
+        train = LabeledData(data=full_images[idx_train], labels=labels[idx_train])
+        test = LabeledData(data=full_images[idx_test], labels=labels[idx_test])
+        return train, test
 
-    def get_cifar100_data(self, keyword):
+    def get_full_caltech101_data(self):
         """
-        Get data from the files containing the keyword in their name.
+        Read the Caltech101 images from the extracted archive: one sub-directory per class.
 
-        :param keyword:
         :return:
         """
-        full_data = []
-        full_labels = []
-        for fpath in self.__extracted_file_paths:
-            if keyword in fpath.split('/')[-1]:
-                with open(fpath, 'rb') as f:
-                    pckl_data = pickle.load(f, encoding='bytes')
-                full_data.append(pckl_data[b'data'])
-                full_labels.append(pckl_data[b'fine_labels'])
-        final_data = np.vstack(full_data)
-        final_label = np.hstack(full_labels)
-
-        return final_data, final_label
+        data_dirpath = self.__extracted_dirname
+        class_index = 0
+        list_of_images = []
+        list_of_labels = []
+        for class_name in os.listdir(data_dirpath):
+            list_of_images_for_class = []
+            data_class_dirpath = os.path.join(data_dirpath, class_name)
+            for class_image_file in os.listdir(data_class_dirpath):
+                class_image_path = os.path.join(data_class_dirpath, class_image_file)
+                im = imageio.imread(class_image_path)
+                im = imresize(im, (self.HEIGHT, self.WIDTH, self.DEPTH))
+                if len(im.shape) == 2:
+                    # switch from grayscale to rgb
+                    im = np.stack((im,) * 3, -1)
+                list_of_images_for_class.append(im.flatten())
+
+            np.random.seed(self.seed)
+            full_idx_class = np.random.permutation(len(list_of_images_for_class))
+            kept_images = np.array(list_of_images_for_class)[full_idx_class][:self.MAXIMUM_NBR_IMG_BY_CAT]
+            list_of_images.append(kept_images)
+            # one label per image actually kept, so images and labels stay aligned
+            list_of_labels.append(np.ones(len(kept_images)) * class_index)
+
+            class_index += 1
+        # stack to (n_samples, height * width * depth) and (n_samples,)
+        return np.vstack(list_of_images), np.hstack(list_of_labels)
 
     def get_meta(self):
         """
@@ -58,27 +87,35 @@ class Cifar100FineDataset(ImageDataset):
 
         :return:
         """
-        for fpath in self.__extracted_file_paths:
-            if 'meta' in fpath.split('/')[-1]:
-                with open(fpath, 'rb') as f:
-                    pckl_data = pickle.load(f, encoding='bytes')
-                meta = pckl_data[b'fine_label_names']
-        return np.array(meta)
+        pass
 
     def read(self):
-        targz_file_path = self.l_filepaths[-1]
-        if not check_files(self.__extracted_file_paths):
-            logger.debug("Extracting {} ...".format(targz_file_path))
-            tar = tarfile.open(targz_file_path, "r:gz")
-            tar.extractall(path=self.s_download_dir)
+        npzdir_path = os.path.join(self.s_download_dir, "npzfiles")
+        lst_npzfile_paths = [os.path.join(npzdir_path, kw + ".npz")
+                             for kw in self.data_groups_private]
+        create_directory(npzdir_path)
+        if check_files(lst_npzfile_paths):
+            # case npz files already exist
+            logger.debug("Files {} already exist".format(lst_npzfile_paths))
+            logger.info("Loading transformed data from files {}".format(lst_npzfile_paths))
+            self.load_npz(npzdir_path)
         else:
-            logger.debug("File {} has already been extracted".format(targz_file_path))
+            targz_file_path = self.l_filepaths[-1]
+            if not check_files([self.__extracted_dirname]):
+                logger.debug("Extracting {} ...".format(targz_file_path))
+                tar = tarfile.open(targz_file_path, "r:gz")
+                tar.extractall(path=self.s_download_dir)
+            else:
+                logger.debug("File {} has already been extracted".format(targz_file_path))
+
+            logger.debug("Get training and testing data of dataset {}".format(self.s_name))
+            self._train, self._test = self.get_caltech101_data()
 
-        logger.debug("Get training data of dataset {}".format(self.s_name))
-        self._train = LabeledData(*self.get_cifar100_data('train'))
+            self.meta = self.get_meta()
+            self._check_validation_size(self._train[0].shape[0])
+            self.save_npz()
 
-        logger.debug("Get testing data of dataset {}".format(self.s_name))
-        self._test = LabeledData(*self.get_cifar100_data('test'))
-        self.meta = self.get_meta()
+        logger.debug("Number of labels in train set {}".format(len(np.unique(self._train.labels, axis=0))))
+        logger.debug("Number of labels in evaluation set {}".format(len(np.unique(self._test.labels, axis=0))))
{}".format(len(np.unique(self._test.labels, axis=0)))) - self._check_validation_size(self._train[0].shape[0])