Skip to content
Snippets Groups Projects
Commit d1778016 authored by Luc Giffon's avatar Luc Giffon
Browse files

WIP implementation of Caltech101 dataset interface

parent 6937f7f1
No related branches found
No related tags found
No related merge requests found
import os import os
import pickle import pickle
import tarfile import tarfile
import time
import numpy as np import numpy as np
import imageio
import matplotlib.pyplot as plt
from skluc.main.data.mldatasets.ImageDataset import ImageDataset from skluc.main.data.mldatasets.ImageDataset import ImageDataset
from skluc.main.utils import LabeledData from skluc.main.utils import LabeledData, create_directory
from skluc.main.utils import logger, check_files from skluc.main.utils import logger, check_files
from scipy.misc import imresize
import logging
logging.getLogger("PIL.Image").setLevel(logging.WARNING)
class Caltech101Dataset(ImageDataset):
HEIGHT = None
WIDTH = None
DEPTH = None
class Cifar100FineDataset(ImageDataset): MAXIMUM_NBR_IMG_BY_CAT = 31
HEIGHT = 32
WIDTH = 32
DEPTH = 3
def __init__(self, validation_size=0, seed=None, s_download_dir=None): def __init__(self, images_shape=(32, 32, 3), test_size=1000, nb_img_by_cat=30, validation_size=0, seed=None, s_download_dir=None):
self.__s_url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" raise NotImplementedError("Caltech Dataset is not implemented yet")
self.__s_url = "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz"
self.meta = None self.meta = None
name = "cifar100fine" name = "caltech101"+"_"+("x".join([str(x) for x in images_shape]))
self.HEIGHT = images_shape[0]
self.WIDTH = images_shape[1]
self.DEPTH = images_shape[2]
self.__nb_img_by_cat = nb_img_by_cat
self.__test_size = test_size
if s_download_dir is not None: if s_download_dir is not None:
super().__init__([self.__s_url], name, s_download_dir, validation_size=validation_size, seed=seed) super().__init__([self.__s_url], name, s_download_dir, validation_size=validation_size, seed=seed)
else: else:
super().__init__([self.__s_url], name, validation_size=validation_size, seed=seed) super().__init__([self.__s_url], name, validation_size=validation_size, seed=seed)
self.__extracted_dirname = os.path.join(self.s_download_dir, "cifar-100-python") self.__extracted_dirname = os.path.join(self.s_download_dir, "101_ObjectCategories")
self.__extracted_files = [
'train',
'test',
'meta'
]
self.__extracted_file_paths = [os.path.join(self.__extracted_dirname, file) for file in self.__extracted_files] def get_caltech101_data(self):
full_images, labels = self.get_full_caltech101_data()
np.random.seed(self.seed)
full_idx = np.random.permutation(len(full_images))
idx_train = full_idx[:self.__test_size]
idx_test = full_idx[-self.__test_size:]
train = LabeledData(data=full_images[idx_train], labels=labels[idx_train])
test = LabeledData(data=full_images[idx_test], labels=labels[idx_test])
return train, test
def get_cifar100_data(self, keyword): def get_full_caltech101_data(self):
""" """
Get data from the files containing the keyword in their name. Get data from the files containing the keyword in their name.
:param keyword: :param keyword:
:return: :return:
""" """
full_data = [] data_dirpath = self.__extracted_dirname
full_labels = [] class_index = 0
for fpath in self.__extracted_file_paths: list_of_images = []
if keyword in fpath.split('/')[-1]: list_of_labels = []
with open(fpath, 'rb') as f: for class_name in os.listdir(data_dirpath):
pckl_data = pickle.load(f, encoding='bytes') list_of_images_for_class = []
full_data.append(pckl_data[b'data']) data_class_dirpath = os.path.join(data_dirpath, class_name)
full_labels.append(pckl_data[b'fine_labels']) for class_image_file in os.listdir(data_class_dirpath):
final_data = np.vstack(full_data) class_image_path = os.path.join(data_class_dirpath, class_image_file)
final_label = np.hstack(full_labels) im = imageio.imread(class_image_path)
im = imresize(im, (self.HEIGHT, self.WIDTH, self.DEPTH))
return final_data, final_label if len(im.shape) == 2:
# switch from grayscale to rgb
im = np.stack((im,) * 3, -1)
list_of_images_for_class.append(im.flatten())
np.random.seed(self.seed)
full_idx_class = np.random.permutation(len(list_of_images_for_class))
list_of_images.append(np.array(list_of_images_for_class)[full_idx_class][:self.MAXIMUM_NBR_IMG_BY_CAT])
list_of_labels.append(np.ones(self.MAXIMUM_NBR_IMG_BY_CAT) * class_index)
class_index += 1
return np.array(list_of_images), np.array(list_of_labels)
def get_meta(self): def get_meta(self):
""" """
...@@ -58,27 +87,35 @@ class Cifar100FineDataset(ImageDataset): ...@@ -58,27 +87,35 @@ class Cifar100FineDataset(ImageDataset):
:return: :return:
""" """
for fpath in self.__extracted_file_paths: pass
if 'meta' in fpath.split('/')[-1]:
with open(fpath, 'rb') as f:
pckl_data = pickle.load(f, encoding='bytes')
meta = pckl_data[b'fine_label_names']
return np.array(meta)
def read(self): def read(self):
npzdir_path = os.path.join(self.s_download_dir, "npzfiles")
lst_npzfile_paths = [os.path.join(npzdir_path, kw + ".npz")
for kw in self.data_groups_private]
create_directory(npzdir_path)
if check_files(lst_npzfile_paths):
# case npz files already exist
logger.debug("Files {} already exists".format(lst_npzfile_paths))
logger.info("Loading transformed data from files {}".format(lst_npzfile_paths))
self.load_npz(npzdir_path)
else:
targz_file_path = self.l_filepaths[-1] targz_file_path = self.l_filepaths[-1]
if not check_files(self.__extracted_file_paths): if not check_files([self.__extracted_dirname]):
logger.debug("Extracting {} ...".format(targz_file_path)) logger.debug("Extracting {} ...".format(targz_file_path))
tar = tarfile.open(targz_file_path, "r:gz") tar = tarfile.open(targz_file_path, "r:gz")
tar.extractall(path=self.s_download_dir) tar.extractall(path=self.s_download_dir)
else: else:
logger.debug("File {} has already been extracted".format(targz_file_path)) logger.debug("File {} has already been extracted".format(targz_file_path))
logger.debug("Get training data of dataset {}".format(self.s_name)) logger.debug("Get training and testing data of dataset {}".format(self.s_name))
self._train = LabeledData(*self.get_cifar100_data('train')) self._train, self._test = self.get_caltech101_data()
logger.debug("Get testing data of dataset {}".format(self.s_name))
self._test = LabeledData(*self.get_cifar100_data('test'))
self.meta = self.get_meta() self.meta = self.get_meta()
self._check_validation_size(self._train[0].shape[0]) self._check_validation_size(self._train[0].shape[0])
self.save_npz()
logger.debug("Number of labels in train set {}".format(len(np.unique(self._train.labels, axis=0))))
logger.debug("Number of labels in evaluation set {}".format(len(np.unique(self._test.labels, axis=0))))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment