import gzip
import os
import struct

import numpy as np

from skluc.main.data.mldatasets.ImageDataset import ImageDataset
from skluc.main.utils import LabeledData
from skluc.main.utils import logger


class MnistDataset(ImageDataset):
    """MNIST dataset: 28x28 single-channel images with their labels.

    The four gzipped IDX files (train/test images and labels) are listed as
    URLs to ``ImageDataset`` and parsed from ``s_download_dir`` by :meth:`read`.
    """

    HEIGHT = 28
    WIDTH = 28
    DEPTH = 1

    # Magic numbers found in the first 4 bytes of IDX files:
    # 0x00000801 for 1-D unsigned-byte arrays (labels),
    # 0x00000803 for 3-D unsigned-byte arrays (images).
    _LABEL_MAGIC = 2049
    _IMAGE_MAGIC = 2051

    def __init__(self, validation_size=0, seed=0, s_download_dir=None):
        """
        :param validation_size: size of the validation split, forwarded to ``ImageDataset``.
        :param seed: seed forwarded to ``ImageDataset``.
        :param s_download_dir: directory holding the downloaded files; when None,
            the argument is not passed so the ``ImageDataset`` default applies.
        """
        # NOTE(review): this host has been reported to refuse anonymous downloads
        # (HTTP 403); a mirror may be needed — confirm.
        self.__s_root_url = "http://yann.lecun.com/exdb/mnist/"
        self.__d_leaf_url = {
            "train_data": "train-images-idx3-ubyte.gz",
            "train_label": "train-labels-idx1-ubyte.gz",
            "test_data": "t10k-images-idx3-ubyte.gz",
            "test_label": "t10k-labels-idx1-ubyte.gz"
        }

        l_url = [self.__s_root_url + leaf_url for leaf_url in self.__d_leaf_url.values()]
        if s_download_dir is not None:
            super().__init__(l_url, "mnist", s_download_dir, validation_size=validation_size, seed=seed)
        else:
            super().__init__(l_url, "mnist", validation_size=validation_size, seed=seed)

    @staticmethod
    def read_gziped_ubyte(fname_img=None, fname_lbl=None):
        """Load a pair of gzipped MNIST IDX files (images + labels) into numpy arrays.

        Loosely based on https://gist.github.com/akesling/5358964

        :param fname_img: path to the gzipped IDX3 image file.
        :param fname_lbl: path to the gzipped IDX1 label file.
        :return: ``(img, lbl)`` where ``img`` is a ``uint8`` array of shape
            ``(n_examples, rows * cols)`` (each image flattened row by row) and
            ``lbl`` is an ``int8`` array of shape ``(n_examples,)``.
        :raises ValueError: if a file has an unexpected magic number, if its
            payload size disagrees with its header, or if the two files hold
            a different number of examples.
        """
        logger.info("Read gziped ubyte file {}".format(fname_lbl))
        with gzip.open(fname_lbl, 'rb') as flbl:
            magic, num_labels = struct.unpack(">II", flbl.read(8))
            if magic != MnistDataset._LABEL_MAGIC:
                raise ValueError("Invalid magic number {} in label file {} (expected {})"
                                 .format(magic, fname_lbl, MnistDataset._LABEL_MAGIC))
            # frombuffer replaces the deprecated binary mode of fromstring; .copy()
            # keeps the array writable, as fromstring's result was.
            # Labels are kept as int8 (as before) for backward compatibility.
            lbl = np.frombuffer(flbl.read(), dtype=np.int8).copy()
        if lbl.size != num_labels:
            raise ValueError("Label file {} announces {} labels but contains {}"
                             .format(fname_lbl, num_labels, lbl.size))

        logger.info("Read gziped ubyte file {}".format(fname_img))
        with gzip.open(fname_img, 'rb') as fimg:
            magic, num_images, rows, cols = struct.unpack(">IIII", fimg.read(16))
            if magic != MnistDataset._IMAGE_MAGIC:
                raise ValueError("Invalid magic number {} in image file {} (expected {})"
                                 .format(magic, fname_img, MnistDataset._IMAGE_MAGIC))
            img = np.frombuffer(fimg.read(), dtype=np.uint8).copy()

        if num_images != num_labels:
            raise ValueError("Image file {} holds {} images but label file {} holds {} labels"
                             .format(fname_img, num_images, fname_lbl, num_labels))
        if img.size != num_images * rows * cols:
            raise ValueError("Image file {} should contain {} pixels but contains {}"
                             .format(fname_img, num_images * rows * cols, img.size))

        # One flattened image per row.
        img = img.reshape(num_images, rows * cols)
        return img, lbl

    def read(self):
        """Set the ``_train`` and ``_test`` attributes of the dataset.

        Each is a ``LabeledData`` built from the ``(img, lbl)`` pair returned by
        :meth:`read_gziped_ubyte`; the validation size is then checked against
        the number of training examples.
        """
        # TODO: add possibility to provide percentage for validation set instead of size
        self._train = self._read_split("train")
        self._test = self._read_split("test")
        self._check_validation_size(self._train[0].shape[0])

    def _read_split(self, split):
        """Parse the image and label files of ``split`` ("train" or "test") into a LabeledData."""
        data_path = os.path.join(self.s_download_dir, self.__d_leaf_url[split + "_data"])
        label_path = os.path.join(self.s_download_dir, self.__d_leaf_url[split + "_label"])
        return LabeledData(*self.read_gziped_ubyte(data_path, label_path))