diff --git a/skluc/main/data/transformation/VGG19Transformer.py b/skluc/main/data/transformation/VGG19Transformer.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..cfbbdde0cea5d83c0b65aac458c84cfe5d2b03d8 100644 --- a/skluc/main/data/transformation/VGG19Transformer.py +++ b/skluc/main/data/transformation/VGG19Transformer.py @@ -0,0 +1,93 @@ +from keras import Model +from keras.models import load_model + +from skluc.main.data.mldatasets.Cifar10Dataset import Cifar10Dataset +from skluc.main.data.transformation.KerasModelTransformer import KerasModelTransformer +from skluc.main.utils import logger, create_directory, download_data, check_file_md5, Singleton, DownloadableModel + + +class VGG19Transformer(KerasModelTransformer, metaclass=Singleton): + """ + Uses the vgg19 convolution network to transform data. + """ + + MAP_DATA_MODEL = { + "svhn": DownloadableModel( + url="https://pageperso.lis-lab.fr/~luc.giffon/models/1529968150.5454917_vgg19_svhn.h5", + checksum="563a9ec2aad37459bd1ed0e329441b05" + ), + "cifar100": DownloadableModel( + url="https://pageperso.lis-lab.fr/~luc.giffon/models/1530965727.781668_vgg19_cifar100fine.h5", + checksum="edf43e263fec05e2c013dd5a2128fc38" + ), + "cifar10": DownloadableModel( + url="https://pageperso.lis-lab.fr/~luc.giffon/models/1522967518.1916964_vgg19_cifar10.h5", + checksum="0dbb4f02ceb1f4acb6e24831758106e5" + ), + "siamese_omniglot_28x28": DownloadableModel( + url="https://pageperso.lis-lab.fr/~luc.giffon/models/1536244775.6502118_siamese_vgg19_omniglot_28x28_conv.h5", + checksum="90aec06e688ec3248ba89544a10c9f1f" + ), + "omniglot_28x28": DownloadableModel( + url="https://pageperso.lis-lab.fr/~luc.giffon/models/1536764034.66037_vgg19_omniglot.h5", + checksum="ef1272e9c7ce070e8f70889ec58d1c33" + ) + } + + def __init__(self, data_name, cut_layer_name=None, cut_layer_index=None): + if data_name not in self.MAP_DATA_MODEL.keys(): + raise ValueError("Unknown data name. Can't load weights") + + if cut_layer_name is None and cut_layer_index is None: + logger.warning( + "Cut layer chosen automatically but it eventually will lead to an error in future: block5_pool should be specified explicitly") + cut_layer_name = "block5_pool" + if cut_layer_name is not None: + transformation_name = str(data_name) + "_" + self.__class__.__name__ + "_" + str(cut_layer_name) + elif cut_layer_index is not None: + transformation_name = str(data_name) + "_" + self.__class__.__name__ \ + + "_" + str(cut_layer_index) + # todo sauvegarder index / nom dans le meme dossier si c'est les meme + else: + raise AttributeError("Cut layer name or cut_layer index must be given to init VGG19Transformer.") + self.__cut_layer_name = cut_layer_name + self.__cut_layer_index = cut_layer_index + + self.keras_model = None + + super().__init__(data_name=data_name, + transformation_name=transformation_name) + + def load(self): + create_directory(self.s_download_dir) + s_model_path = download_data(self.MAP_DATA_MODEL[self.data_name].url, self.s_download_dir) + check_file_md5(s_model_path, self.MAP_DATA_MODEL[self.data_name].checksum) + if self.keras_model is None: + logger.debug("Loading VGG19 model for {} transformation with {} weights".format(self.transformation_name, self.data_name)) + self.keras_model = load_model(s_model_path) + + logger.debug("Layers of model {}".format([l.name for l in self.keras_model.layers])) + + if self.__cut_layer_index is not None: + cut_layer = self.keras_model.layers[-1] + self.__cut_layer_name = cut_layer.name + logger.debug( + "Found associated layer {} to layer index {}".format(self.__cut_layer_name, self.__cut_layer_index)) + + self.keras_model = Model(inputs=self.keras_model.input, + outputs=self.keras_model.get_layer(name=self.__cut_layer_name).output) + + else: + logger.debug("Skip loading model VGG19 for {} transformation with {} weights. Already there.".format( + self.transformation_name, + self.data_name)) + + +if __name__ == '__main__': + valsize = 10000 + d = Cifar10Dataset(validation_size=valsize) + + d.load() + d.to_image() + trans = VGG19Transformer(data_name="cifar10", cut_layer_name="block5_pool") + d.apply_transformer(transformer=trans)