Skip to content
Snippets Groups Projects
Select Git revision
  • ce870e9bd834992bc665c8e6c19591792cbeece6
  • master default protected
2 results

ModelAdd.vue

Blame
  • TestVGG19Transformer.py 1.58 KiB
    import unittest
    
    from skluc.main.data.mldatasets.Cifar10Dataset import Cifar10Dataset
    
    from skluc.main.data.transformation.VGG19Transformer import VGG19Transformer
    
    from skluc.main.utils import logger
    
    
    class TestVGG19Transformer(unittest.TestCase):
        def setUp(self):
            self.lst_name_cut_layers = [
                "block3_pool"
            ]
            self.dict_datasets = {
                "cifar10": Cifar10Dataset,
            }
    
        def test_transform(self):
            valsize = 10000
    
            for data_name in self.dict_datasets:
                logger.info("Testing dataset {}".format(data_name))
                trans = None
                for name_cut_layer in self.lst_name_cut_layers:
                    logger.info("Testing cut layer {}".format(name_cut_layer))
                    dataset = self.dict_datasets[data_name]
                    d = dataset(validation_size=valsize)
                    d.load()
                    d.flatten()
                    d.to_image()
                    trans = VGG19Transformer(data_name=data_name, cut_layer_name=name_cut_layer)
                    d.apply_transformer(transformer=trans)
                del trans
    
        def test_init(self):
            for data_name in self.dict_datasets:
                logger.info("Testing dataset {}".format(data_name))
                trans = None
                for name_cut_layer in self.lst_name_cut_layers:
                    logger.info("Testing cut layer {}".format(name_cut_layer))
                    trans = VGG19Transformer(data_name=data_name, cut_layer_name=name_cut_layer)
                    logger.debug(trans.keras_model)
                del trans
    
    
    if __name__ == '__main__':
        unittest.main()