diff --git a/skluc/main/data/mldatasets/MnistDataset.py b/skluc/main/data/mldatasets/MnistDataset.py index f2adb8f291706f86c97d0300ab572497ae64bb0d..0959025212893ce1eb4f641d5d1c79e312ce509c 100644 --- a/skluc/main/data/mldatasets/MnistDataset.py +++ b/skluc/main/data/mldatasets/MnistDataset.py @@ -55,11 +55,7 @@ class MnistDataset(ImageDataset): def read(self): """ - Return a dict of data where, for each key is associated a (data, label) tuple. - - The values of the tuple are np.ndarray. - - :return: dict + set the _train and _test attribute of dataset """ # todo add possibility to provide percentage for validation set instead of size self._train = LabeledData( diff --git a/skluc/test/test_data/test_mldatasets/TestMnistDataset.py b/skluc/test/test_data/test_mldatasets/TestMnistDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c2928edf8598fa6cfc205973bd01b7ea59ad4a09 --- /dev/null +++ b/skluc/test/test_data/test_mldatasets/TestMnistDataset.py @@ -0,0 +1,23 @@ +import os +import unittest + +from skluc.main.data.mldatasets import MnistDataset + + +class TestMnistDataset(unittest.TestCase): + + def test_mnist(self): + mnist = MnistDataset() + mnist.load() + for name in mnist.l_filepaths: + self.assertTrue(os.path.exists(name)) + + def test_to_image(self): + mnist = MnistDataset() + mnist.load() + mnist.to_image() + self.assertTrue(mnist.train.data[0].shape == (28, 28, 1)) + + +if __name__ == "__main__": + unittest.main()