Skip to content
Snippets Groups Projects
Commit d1778016 authored by Luc Giffon
Browse files

WIP implementation of Caltech101 dataset interface

parent 6937f7f1
No related branches found
No related tags found
No related merge requests found
import os
import pickle
import tarfile
import time
import numpy as np
import imageio
import matplotlib.pyplot as plt
from skluc.main.data.mldatasets.ImageDataset import ImageDataset
from skluc.main.utils import LabeledData
from skluc.main.utils import LabeledData, create_directory
from skluc.main.utils import logger, check_files
from scipy.misc import imresize
import logging
logging.getLogger("PIL.Image").setLevel(logging.WARNING)
# NOTE(review): two class headers (Caltech101Dataset and Cifar100FineDataset) and
# two __init__ signatures follow each other below, and no line is indented. This
# looks like the removed and added sides of a diff rendered together -- confirm
# which lines are meant to survive before relying on this section.
class Caltech101Dataset(ImageDataset):
# Image dimensions are unset at class level; the __init__ that takes
# `images_shape` assigns them per instance.
HEIGHT = None
WIDTH = None
DEPTH = None
class Cifar100FineDataset(ImageDataset):
HEIGHT = 32
WIDTH = 32
DEPTH = 3
# Per-class cap applied when the per-class image list is truncated in
# get_full_caltech101_data.
MAXIMUM_NBR_IMG_BY_CAT = 31
def __init__(self, validation_size=0, seed=None, s_download_dir=None):
self.__s_url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
# Parameters of the signature below, as used by the statements that follow it:
#   images_shape    -- indexed as (height, width, depth); also joined with "x" into the dataset name
#   test_size       -- stored in self.__test_size
#   nb_img_by_cat   -- stored in self.__nb_img_by_cat
#   validation_size -- forwarded to super().__init__
#   seed            -- forwarded to super().__init__
#   s_download_dir  -- forwarded positionally to super().__init__ only when it is not None
def __init__(self, images_shape=(32, 32, 3), test_size=1000, nb_img_by_cat=30, validation_size=0, seed=None, s_download_dir=None):
# NOTE(review): as written this raise is unconditional, so no statement after it
# in this __init__ would run -- confirm whether it is a leftover.
raise NotImplementedError("Caltech Dataset is not implemented yet")
self.__s_url = "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz"
self.meta = None
# `name` is assigned twice in a row; only the second value reaches super().__init__.
name = "cifar100fine"
name = "caltech101"+"_"+("x".join([str(x) for x in images_shape]))
self.HEIGHT = images_shape[0]
self.WIDTH = images_shape[1]
self.DEPTH = images_shape[2]
self.__nb_img_by_cat = nb_img_by_cat
self.__test_size = test_size
if s_download_dir is not None:
super().__init__([self.__s_url], name, s_download_dir, validation_size=validation_size, seed=seed)
else:
super().__init__([self.__s_url], name, validation_size=validation_size, seed=seed)
# __extracted_dirname is assigned twice; the later "101_ObjectCategories"
# assignment overrides this one.
self.__extracted_dirname = os.path.join(self.s_download_dir, "cifar-100-python")
# NOTE(review): 'train' / 'test' / 'meta' are presumably the CIFAR-100 archive
# members -- confirm they are still relevant for the Caltech101 archive.
self.__extracted_files = [
'train',
'test',
'meta'
]
self.__extracted_dirname = os.path.join(self.s_download_dir, "101_ObjectCategories")
# One path per entry of __extracted_files, rooted at the extraction directory.
self.__extracted_file_paths = [os.path.join(self.__extracted_dirname, file) for file in self.__extracted_files]
def get_caltech101_data(self):
    """
    Split the full Caltech101 data into disjoint train and test sets.

    The samples returned by ``get_full_caltech101_data`` are shuffled with a
    permutation seeded by ``self.seed``. The last ``test_size`` shuffled
    samples form the test set and every remaining sample goes to the train
    set, so the two sets never share a sample.

    :return: ``(train, test)`` as two ``LabeledData`` objects.
    :raises ValueError: if the configured test size is negative or larger
        than the number of available samples.
    """
    full_images, labels = self.get_full_caltech101_data()
    nb_samples = len(full_images)
    test_size = self.__test_size
    if not 0 <= test_size <= nb_samples:
        raise ValueError(
            "test_size={} is not in the range [0, {}] of available samples".format(test_size, nb_samples)
        )

    # Seeding the global generator keeps the split reproducible for a given seed.
    np.random.seed(self.seed)
    full_idx = np.random.permutation(nb_samples)

    # Cut at an explicit position instead of slicing with `[:test_size]` and
    # `[-test_size:]`: the former made the train set the first `test_size`
    # samples (overlapping the test set on small datasets) and the latter
    # selects everything when test_size is 0.
    split_position = nb_samples - test_size
    idx_train = full_idx[:split_position]
    idx_test = full_idx[split_position:]

    train = LabeledData(data=full_images[idx_train], labels=labels[idx_train])
    test = LabeledData(data=full_images[idx_test], labels=labels[idx_test])
    return train, test
def get_cifar100_data(self, keyword):
    """
    Get data from the extracted files containing the keyword in their name.

    :param keyword: substring searched for in the file name (last path
        component) of every extracted file path.
    :return: ``(data, labels)`` where ``data`` is the vertical stack of the
        ``b'data'`` entries and ``labels`` the horizontal stack of the
        ``b'fine_labels'`` entries of all matching files.
    """
    full_data = []
    full_labels = []
    for fpath in self.__extracted_file_paths:
        if keyword not in os.path.basename(fpath):
            continue
        # NOTE: pickle can execute arbitrary code while loading; only read
        # archives that come from a trusted source.
        with open(fpath, 'rb') as f:
            pckl_data = pickle.load(f, encoding='bytes')
        full_data.append(pckl_data[b'data'])
        full_labels.append(pckl_data[b'fine_labels'])
    final_data = np.vstack(full_data)
    final_label = np.hstack(full_labels)
    return final_data, final_label

def get_full_caltech101_data(self):
    """
    Load a fixed-size random sample of resized, flattened images per class.

    Every sub-directory of the extraction directory is treated as one class.
    Each image of a class is read, resized to ``(HEIGHT, WIDTH)``, expanded
    to three channels when it is grayscale, and flattened. At most
    ``MAXIMUM_NBR_IMG_BY_CAT`` images are then drawn per class using a
    permutation seeded by ``self.seed``.

    :return: ``(images, labels)`` where ``images`` has one flattened image
        per row and ``labels`` holds the class index (as float) of each row.
    """
    data_dirpath = self.__extracted_dirname
    # os.listdir returns entries in arbitrary order; sorting keeps the
    # class-name -> class-index mapping identical from one run to the next.
    class_names = sorted(
        entry for entry in os.listdir(data_dirpath)
        if os.path.isdir(os.path.join(data_dirpath, entry))
    )

    list_of_images = []
    list_of_labels = []
    for class_index, class_name in enumerate(class_names):
        data_class_dirpath = os.path.join(data_dirpath, class_name)
        list_of_images_for_class = []
        # Sorted for the same reason as the class names: the seeded
        # permutation below only selects the same files if the order is fixed.
        for class_image_file in sorted(os.listdir(data_class_dirpath)):
            class_image_path = os.path.join(data_class_dirpath, class_image_file)
            im = imageio.imread(class_image_path)
            # imresize documents its tuple size as (height, width).
            im = imresize(im, (self.HEIGHT, self.WIDTH))
            if len(im.shape) == 2:
                # switch from grayscale to rgb
                im = np.stack((im,) * 3, -1)
            list_of_images_for_class.append(im.flatten())

        if not list_of_images_for_class:
            # Nothing to sample from an empty class directory.
            continue

        np.random.seed(self.seed)
        full_idx_class = np.random.permutation(len(list_of_images_for_class))
        selected_images = np.array(list_of_images_for_class)[full_idx_class][:self.MAXIMUM_NBR_IMG_BY_CAT]
        list_of_images.append(selected_images)
        # One label per image actually selected, so data and labels stay
        # aligned even when a class holds fewer images than the cap.
        list_of_labels.append(np.ones(len(selected_images)) * class_index)

    # Concatenate per sample (not per class) so callers can permute and
    # index individual images.
    return np.vstack(list_of_images), np.hstack(list_of_labels)
def get_meta(self):
    """
    Read the fine label names from the extracted 'meta' file.

    The extracted file paths are scanned in order; the first one whose file
    name contains 'meta' is unpickled and its ``b'fine_label_names'`` entry
    is returned.

    :return: numpy array of the label names, or None when no extracted file
        name contains 'meta'.
    """
    for fpath in self.__extracted_file_paths:
        # basename instead of split('/') so the check also works with
        # platform-specific path separators.
        if 'meta' not in os.path.basename(fpath):
            continue
        # NOTE: pickle can execute arbitrary code while loading; only read
        # archives that come from a trusted source.
        with open(fpath, 'rb') as f:
            pckl_data = pickle.load(f, encoding='bytes')
        # NOTE(review): b'fine_label_names' is presumably the CIFAR-100 meta
        # key -- confirm it applies to the dataset handled by this class.
        return np.array(pckl_data[b'fine_label_names'])
    return None
# Load the dataset: reuse the cached .npz files when they all exist, otherwise
# extract the downloaded tar.gz archive, build the train/test sets, and save
# them as .npz files.
# NOTE(review): no line is indented and several statements below appear in two
# variants (old and new side of a diff) -- confirm which ones are live.
def read(self):
npzdir_path = os.path.join(self.s_download_dir, "npzfiles")
# One expected .npz file per name in self.data_groups_private.
lst_npzfile_paths = [os.path.join(npzdir_path, kw + ".npz")
for kw in self.data_groups_private]
create_directory(npzdir_path)
if check_files(lst_npzfile_paths):
# case npz files already exist
logger.debug("Files {} already exists".format(lst_npzfile_paths))
logger.info("Loading transformed data from files {}".format(lst_npzfile_paths))
self.load_npz(npzdir_path)
else:
# The archive to extract is the last entry of self.l_filepaths.
targz_file_path = self.l_filepaths[-1]
# NOTE(review): two extraction guards follow each other; the first checks the
# individual extracted file paths, the second only the extraction directory.
if not check_files(self.__extracted_file_paths):
if not check_files([self.__extracted_dirname]):
logger.debug("Extracting {} ...".format(targz_file_path))
# NOTE(review): the tar handle is never closed in this block, and extractall
# is applied to a downloaded archive without checking member paths.
tar = tarfile.open(targz_file_path, "r:gz")
tar.extractall(path=self.s_download_dir)
else:
logger.debug("File {} has already been extracted".format(targz_file_path))
# NOTE(review): self._train and self._test are each assigned twice below (via
# get_cifar100_data and via get_caltech101_data); as written, the later
# assignment of each attribute wins.
logger.debug("Get training data of dataset {}".format(self.s_name))
self._train = LabeledData(*self.get_cifar100_data('train'))
logger.debug("Get training and testing data of dataset {}".format(self.s_name))
self._train, self._test = self.get_caltech101_data()
logger.debug("Get testing data of dataset {}".format(self.s_name))
self._test = LabeledData(*self.get_cifar100_data('test'))
self.meta = self.get_meta()
# The validation size is checked against the number of rows of the train data.
self._check_validation_size(self._train[0].shape[0])
self.save_npz()
logger.debug("Number of labels in train set {}".format(len(np.unique(self._train.labels, axis=0))))
logger.debug("Number of labels in evaluation set {}".format(len(np.unique(self._test.labels, axis=0))))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment