import gzip
import os
import struct

import numpy as np

from skluc.main.data.mldatasets.ImageDataset import ImageDataset
from skluc.main.utils import LabeledData
from skluc.main.utils import logger


class MnistDataset(ImageDataset):
    """MNIST handwritten-digit dataset read from the four gzipped IDX ("ubyte") archives.

    The archive URLs are passed to ``ImageDataset`` under the name ``"mnist"``; ``read``
    then parses the files found in ``s_download_dir`` into ``_train`` and ``_test``.
    """

    HEIGHT = 28
    WIDTH = 28
    DEPTH = 1

    # Magic numbers opening an IDX file (big-endian uint32): 0x00000801 marks a 1-D
    # unsigned-byte array (label file), 0x00000803 a 3-D unsigned-byte array (image file).
    _LABEL_MAGIC = 2049
    _IMAGE_MAGIC = 2051

    def __init__(self, validation_size=0, seed=0, s_download_dir=None):
        """
        :param validation_size: size of the validation set, forwarded to ``ImageDataset``.
        :param seed: seed forwarded to ``ImageDataset``.
        :param s_download_dir: directory holding the downloaded archives; when None it is
            not forwarded, so ``ImageDataset`` falls back to its own default.
        """
        self.__s_root_url = "http://yann.lecun.com/exdb/mnist/"
        self.__d_leaf_url = {
            "train_data": "train-images-idx3-ubyte.gz",
            "train_label": "train-labels-idx1-ubyte.gz",
            "test_data": "t10k-images-idx3-ubyte.gz",
            "test_label": "t10k-labels-idx1-ubyte.gz"
        }

        l_url = [self.__s_root_url + leaf_url for leaf_url in self.__d_leaf_url.values()]
        # Only pass the download directory positionally when the caller supplied one.
        if s_download_dir is not None:
            super().__init__(l_url, "mnist", s_download_dir, validation_size=validation_size, seed=seed)
        else:
            super().__init__(l_url, "mnist", validation_size=validation_size, seed=seed)

    @staticmethod
    def read_gziped_ubyte(fname_img=None, fname_lbl=None):
        """
        Load an image/label pair from gzipped IDX ("ubyte") files.

        Loosely adapted from https://gist.github.com/akesling/5358964

        :param fname_img: path to the gzipped IDX3 image file.
        :param fname_lbl: path to the gzipped IDX1 label file.
        :return: a tuple ``(img, lbl)``: ``img`` is a writable ``numpy.uint8`` array of
            shape ``(n, rows * cols)`` (one flattened image per row) and ``lbl`` is a
            ``numpy.int8`` array of shape ``(n,)``.
        :raises ValueError: if a file has an unexpected magic number, its payload length
            disagrees with its header, or the image and label counts differ.
        """
        logger.info("Read gziped ubyte file {}".format(fname_lbl))
        with gzip.open(fname_lbl, 'rb') as flbl:
            magic, num_lbl = struct.unpack(">II", flbl.read(8))
            if magic != MnistDataset._LABEL_MAGIC:
                raise ValueError("Unexpected magic number {} in label file {}".format(magic, fname_lbl))
            # frombuffer replaces the deprecated binary mode of np.fromstring; copy() keeps
            # the array writable, as fromstring's result was.
            lbl = np.frombuffer(flbl.read(), dtype=np.int8).copy()
        if lbl.size != num_lbl:
            raise ValueError("Label file {} announces {} labels but contains {}".format(fname_lbl, num_lbl, lbl.size))

        logger.info("Read gziped ubyte file {}".format(fname_img))
        with gzip.open(fname_img, 'rb') as fimg:
            magic, num_img, rows, cols = struct.unpack(">IIII", fimg.read(16))
            if magic != MnistDataset._IMAGE_MAGIC:
                raise ValueError("Unexpected magic number {} in image file {}".format(magic, fname_img))
            img = np.frombuffer(fimg.read(), dtype=np.uint8)
        if num_img != lbl.size:
            raise ValueError("Image file {} holds {} images but label file {} holds {} labels".format(
                fname_img, num_img, fname_lbl, lbl.size))
        if img.size != num_img * rows * cols:
            raise ValueError("Image file {} announces {}x{}x{} pixels but contains {}".format(
                fname_img, num_img, rows, cols, img.size))
        # Use the header dimensions rather than inferring them; copy() makes the result writable.
        img = img.reshape(num_img, rows * cols).copy()

        return img, lbl

    def _read_mnist_split(self, s_split):
        """Build the ``LabeledData`` for one split ("train" or "test") from ``s_download_dir``."""
        return LabeledData(
            *self.read_gziped_ubyte(
                os.path.join(self.s_download_dir, self.__d_leaf_url[s_split + "_data"]),
                os.path.join(self.s_download_dir, self.__d_leaf_url[s_split + "_label"]),
            )
        )

    def read(self):
        """
        Set the ``_train`` and ``_test`` attributes of the dataset.

        Each is a ``LabeledData`` built from the ``(images, labels)`` pair returned by
        ``read_gziped_ubyte``; the number of training examples is then passed to
        ``_check_validation_size``.
        """
        # todo add possibility to provide percentage for validation set instead of size
        self._train = self._read_mnist_split("train")
        self._test = self._read_mnist_split("test")

        self._check_validation_size(self._train[0].shape[0])