import unittest

from skluc.main.data.mldatasets.Cifar10Dataset import Cifar10Dataset

from skluc.main.data.transformation.VGG19Transformer import VGG19Transformer

from skluc.main.utils import logger


class TestVGG19Transformer(unittest.TestCase):
    """Smoke tests for VGG19Transformer over every dataset / cut-layer pair.

    These tests contain no assertions: they pass as long as building the
    transformer (and, in ``test_transform``, applying it to a loaded dataset)
    raises no exception. Each (dataset, cut layer) combination runs in its own
    ``subTest`` so one failing combination does not hide the others.
    """

    def setUp(self):
        # Layer names forwarded to VGG19Transformer as ``cut_layer_name``.
        self.lst_name_cut_layers = [
            "block3_pool",
        ]
        # Dataset name -> dataset class. The name is also forwarded to
        # VGG19Transformer as ``data_name``.
        self.dict_datasets = {
            "cifar10": Cifar10Dataset,
        }

    def test_transform(self):
        """Load each dataset, reshape it to images and apply the transformer."""
        valsize = 10000

        for data_name, dataset in self.dict_datasets.items():
            logger.info("Testing dataset {}".format(data_name))
            trans = None
            for name_cut_layer in self.lst_name_cut_layers:
                with self.subTest(dataset=data_name, cut_layer=name_cut_layer):
                    logger.info("Testing cut layer {}".format(name_cut_layer))
                    # A fresh dataset is built for every cut layer; presumably
                    # apply_transformer modifies ``d`` in place, so reusing it
                    # would feed already-transformed data -- TODO confirm.
                    d = dataset(validation_size=valsize)
                    d.load()
                    d.flatten()
                    d.to_image()
                    trans = VGG19Transformer(data_name=data_name, cut_layer_name=name_cut_layer)
                    d.apply_transformer(transformer=trans)
            # Release the local reference to the last transformer before
            # moving on to the next dataset.
            del trans

    def test_init(self):
        """Construct a VGG19Transformer for each pair and log its keras_model."""
        for data_name in self.dict_datasets:
            logger.info("Testing dataset {}".format(data_name))
            trans = None
            for name_cut_layer in self.lst_name_cut_layers:
                with self.subTest(dataset=data_name, cut_layer=name_cut_layer):
                    logger.info("Testing cut layer {}".format(name_cut_layer))
                    trans = VGG19Transformer(data_name=data_name, cut_layer_name=name_cut_layer)
                    logger.debug(trans.keras_model)
            # Release the local reference to the last transformer before
            # moving on to the next dataset.
            del trans


if __name__ == "__main__":
    # Run all TestCase classes in this module when it is executed directly.
    unittest.main()