From 8613f4326c5f4552d741feaaf71f629c1c8bba06 Mon Sep 17 00:00:00 2001 From: Luc Giffon <luc.giffon@lis-lab.fr> Date: Tue, 4 Sep 2018 11:06:55 +0200 Subject: [PATCH] Mnist dataset tested --- skluc/main/data/mldatasets/MnistDataset.py | 6 +---- .../test_mldatasets/TestMnistDataset.py | 23 +++++++++++++++++++ 2 files changed, 24 insertions(+), 5 deletions(-) create mode 100644 skluc/test/test_data/test_mldatasets/TestMnistDataset.py diff --git a/skluc/main/data/mldatasets/MnistDataset.py b/skluc/main/data/mldatasets/MnistDataset.py index f2adb8f..0959025 100644 --- a/skluc/main/data/mldatasets/MnistDataset.py +++ b/skluc/main/data/mldatasets/MnistDataset.py @@ -55,11 +55,7 @@ class MnistDataset(ImageDataset): def read(self): """ - Return a dict of data where, for each key is associated a (data, label) tuple. - - The values of the tuple are np.ndarray. - - :return: dict + set the _train and _test attribute of dataset """ # todo add possibility to provide percentage for validation set instead of size self._train = LabeledData( diff --git a/skluc/test/test_data/test_mldatasets/TestMnistDataset.py b/skluc/test/test_data/test_mldatasets/TestMnistDataset.py new file mode 100644 index 0000000..c2928ed --- /dev/null +++ b/skluc/test/test_data/test_mldatasets/TestMnistDataset.py @@ -0,0 +1,23 @@ +import os +import unittest + +from skluc.main.data.mldatasets import MnistDataset + + +class TestMnistDataset(unittest.TestCase): + + def test_mnist(self): + mnist = MnistDataset() + mnist.load() + for name in mnist.l_filepaths: + self.assertTrue(os.path.exists(name)) + + def test_to_image(self): + mnist = MnistDataset() + mnist.load() + mnist.to_image() + self.assertTrue(mnist.train.data[0].shape == (28, 28, 1)) + + +if __name__ == "__main__": + unittest.main() -- GitLab