Skip to content
Snippets Groups Projects
Commit a0194a92 authored by Luc Giffon's avatar Luc Giffon
Browse files

converting images vectors to matrices + add the data attribute to the Dataset class

parent a596c2b9
Branches
No related tags found
No related merge requests found
......@@ -3,6 +3,7 @@ This module defines the Dataset classes usefull for downloading and loading data
The currently implemented datasets are:
- mnist
- cifar10
"""
import urllib.request
......@@ -89,6 +90,7 @@ class Dataset:
self.s_name = s_name
self.s_download_dir = os.path.join(s_download_dir, self.s_name)
self.l_filepaths = [os.path.join(self.s_download_dir, fname) for fname in self.l_filenames]
self.data = {}
def download(self):
"""
......@@ -181,7 +183,22 @@ class MnistDataset(Dataset):
"test": (self.read_gziped_ubyte(os.path.join(self.s_download_dir, self.__d_leaf_url["test_data"]),
os.path.join(self.s_download_dir, self.__d_leaf_url["test_label"])))
}
return d_data
self.data = d_data
return self.data
def to_image(self):
"""
Modify data to present it like images (matrices) instead of vectors.
:return: The modified data.
"""
for kw in ["train", "test"]:
images_vec = self.data[kw][0]
labels = self.data[kw][1]
images_mat = np.reshape(images_vec, (images_vec.shape[0], 784, 1), order='F')
images = np.reshape(images_mat, (images_mat.shape[0], 28, 28, 1), order='C')
self.data[kw] = (images, labels)
return self.data
class Cifar10Dataset(Dataset):
......@@ -219,11 +236,7 @@ class Cifar10Dataset(Dataset):
full_data.append(pckl_data[b'data'])
full_labels.append(pckl_data[b'labels'])
final_data = np.vstack(full_data)
n_examples = final_data.shape[0]
# the data are stored like each line is an image where the 1024 first value are for the red, then 1024 for G
# then 1024 for Blue so the order is 'F'
final_data = np.reshape(final_data, (n_examples, 1024, 3), order='F')
final_data = np.reshape(final_data, (n_examples, 32, 32, 3), order='C')
final_label = np.hstack(full_labels)
return final_data, final_label
......@@ -245,12 +258,27 @@ class Cifar10Dataset(Dataset):
else:
logger.debug("File {} has already been extracted".format(targz_file_path))
data = {
"train": self.get_cifar10_data('data'),
"test": self.get_cifar10_data('test'),
"meta": self.get_meta()
}
return data
self.data["train"] = self.get_cifar10_data('data')
self.data["test"] = self.get_cifar10_data('test')
self.data["meta"] = self.get_meta()
return self.data
def to_image(self):
"""
Modify data to present it like images (matrices) instead of vectors.
:return: The modified data.
"""
for kw in ["train", "test"]:
images_vec = self.data[kw][0]
labels = self.data[kw][1]
# the data are stored like each line is an image where the 1024 first value are for the red, then 1024 for G
# then 1024 for Blue so the order is 'F'
images_mat = np.reshape(images_vec, (images_vec.shape[0], 1024, 3), order='F')
images = np.reshape(images_mat, (images_mat.shape[0], 32, 32, 3), order='C')
self.data[kw] = (images, labels)
return self.data
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment