Commit aed430d4 authored by Luc Giffon
remove RPS dataset, irrelevant

parent 2c587e3b
import os
import zipfile
import numpy as np
import imageio
import matplotlib.pyplot as plt
from skluc.utils import LabeledData, create_directory
from skluc.data.mldatasets.ImageDataset import ImageDataset
from skluc.utils import logger, check_files
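
# Dataset wrapper for what appears to be a rock-paper-scissors (RPS) hand-sign
# image set: 50x50 RGB images downloaded as a single zip archive. Every image
# is loaded into one pool; the train/validation/test views are disjoint slices
# of a fixed permutation of that pool (see the properties at the bottom).
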
class RPSDataset(ImageDataset):
    data_groups_private = ["_train"]
    HEIGHT = 50
    WIDTH = 50
    DEPTH = 3
    TRAIN_SIZE = 600

    def __init__(self, validation_size=0, seed=None, s_download_dir=None):
        self.__s_url = ["https://pageperso.lif.univ-mrs.fr/~luc.giffon/datasets/rps_data_resize.zip"]
        self.meta = None
        name = "rps"
        if s_download_dir is not None:
            super().__init__(self.__s_url, name, s_download_dir, validation_size=validation_size, seed=seed)
        else:
            super().__init__(self.__s_url, name, validation_size=validation_size, seed=seed)
        # Directory created when the archive is extracted; read() checks for it
        # to decide whether extraction is still needed.
        self.__extracted_dirs = [
            os.path.join(self.s_download_dir, "rps_data_resize")
        ]

    def get_rps_data(self):
        # The extracted archive contains one sub-directory per class (one per
        # hand symbol); every image found in a sub-directory gets that
        # sub-directory's class index as its label.
        data_dirname = "rps_data_resize"
        data_dirpath = os.path.join(self.s_download_dir, data_dirname)
        class_index = 0
        list_of_images = []
        list_of_labels = []
        for symbol_name in os.listdir(data_dirpath):
            data_symbol_path = os.path.join(data_dirpath, symbol_name)
            for symbol_image_file in os.listdir(data_symbol_path):
                symbol_image_path = os.path.join(data_symbol_path, symbol_image_file)
                im = imageio.imread(symbol_image_path)
                list_of_images.append(im)
                list_of_labels.append(class_index)
            class_index += 1
        data = np.array(list_of_images)
        labels = np.array(list_of_labels)
        # Flatten each (HEIGHT, WIDTH, DEPTH) image into a single vector laid
        # out channel by channel: all pixels of channel 0, then channel 1, then
        # channel 2, with pixels kept in row-major order within each channel.
        data = data.reshape(data.shape[0], self.WIDTH * self.HEIGHT, self.DEPTH, order="C")
        data = data.reshape(data.shape[0], self.WIDTH * self.HEIGHT * self.DEPTH, order="F")
        return data, labels

    def read(self):
        npzdir_path = os.path.join(self.s_download_dir, "npzfiles")
        lst_npzfile_paths = [os.path.join(npzdir_path, kw + ".npz")
                             for kw in self.data_groups_private]
        create_directory(npzdir_path)
        if check_files(lst_npzfile_paths):
            # case npz files already exist
            logger.debug("Files {} already exist".format(lst_npzfile_paths))
            logger.info("Loading transformed data from files {}".format(lst_npzfile_paths))
            for kw in self.data_groups_private:
                npzfile_path = os.path.join(npzdir_path, kw + ".npz")
                logger.debug("Loading {}".format(npzfile_path))
                npzfile = np.load(npzfile_path)
                data = npzfile[kw + "_data"]
                logger.debug("Shape of {} set: {}".format(kw, data.shape))
                labels = npzfile[kw + "_labels"]
                setattr(self, kw, LabeledData(data=data, labels=labels))
        else:
            if not check_files(self.__extracted_dirs):
                # case the archive has not been extracted yet
                logger.debug("Extracting {} ...".format(self.l_filepaths))
                for zip_file in self.l_filepaths:
                    with zipfile.ZipFile(zip_file, 'r') as zip_ref:
                        zip_ref.extractall(self.s_download_dir)
            else:
                logger.debug("Files {} have already been extracted".format(self.l_filepaths))
            full_data, full_labels = self.get_rps_data()
            logger.debug("Get training data of dataset {}".format(self.s_name))
            self._train = LabeledData(data=full_data, labels=full_labels)
            # self._test = LabeledData(data=np.array([]), labels=np.array([]))
            #
            # logger.debug("Get testing data of dataset {}".format(self.s_name))
            # self._test = LabeledData(*self.get_omniglot_data('evaluation'))
            #
            self._check_validation_size(self._train[0].shape[0])
            self.save_npz()

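    # Note: the archive ships no separate test split. The train, validation and
    # test properties below all index into the same `_train` pool through
    # `permuted_index_train` (assumed to be provided by ImageDataset), so the
    # three views are disjoint slices of a single permutation.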
    @property
    def train(self):
        indexes = self.permuted_index_train[:self.TRAIN_SIZE - self.validation_size]
        return LabeledData(data=self._train.data[indexes],
                           labels=self._train.labels[indexes])

    @property
    def test(self):
        indexes = self.permuted_index_train[self.TRAIN_SIZE:]
        return LabeledData(data=self._train.data[indexes],
                           labels=self._train.labels[indexes])

    @property
    def validation(self):
        indexes = self.permuted_index_train[(self.TRAIN_SIZE - self.validation_size):self.TRAIN_SIZE]
        return LabeledData(data=self._train.data[indexes],
                           labels=self._train.labels[indexes])


if __name__ == "__main__":
    import time

    d = RPSDataset(validation_size=100)
    d.load()
    d.to_image()
    print(d.train.data.shape)
    for i, im in enumerate(d.train.data):
        plt.imshow(im)
        plt.show()
        print(d.train.labels[i])
        time.sleep(1)