Skip to content
Snippets Groups Projects
Commit 009ab0f5 authored by Luc Giffon's avatar Luc Giffon
Browse files

move stuff from old VGG19transformer folder in file VGG19transformer

parent 9382dc5a
No related branches found
No related tags found
No related merge requests found
from keras import Model
from keras.models import load_model
from skluc.main.data.mldatasets.Cifar10Dataset import Cifar10Dataset
from skluc.main.data.transformation.KerasModelTransformer import KerasModelTransformer
from skluc.main.utils import logger, create_directory, download_data, check_file_md5, Singleton, DownloadableModel
class VGG19Transformer(KerasModelTransformer, metaclass=Singleton):
"""
Uses the vgg19 convolution network to transform data.
"""
MAP_DATA_MODEL = {
"svhn": DownloadableModel(
url="https://pageperso.lis-lab.fr/~luc.giffon/models/1529968150.5454917_vgg19_svhn.h5",
checksum="563a9ec2aad37459bd1ed0e329441b05"
),
"cifar100": DownloadableModel(
url="https://pageperso.lis-lab.fr/~luc.giffon/models/1530965727.781668_vgg19_cifar100fine.h5",
checksum="edf43e263fec05e2c013dd5a2128fc38"
),
"cifar10": DownloadableModel(
url="https://pageperso.lis-lab.fr/~luc.giffon/models/1522967518.1916964_vgg19_cifar10.h5",
checksum="0dbb4f02ceb1f4acb6e24831758106e5"
),
"siamese_omniglot_28x28": DownloadableModel(
url="https://pageperso.lis-lab.fr/~luc.giffon/models/1536244775.6502118_siamese_vgg19_omniglot_28x28_conv.h5",
checksum="90aec06e688ec3248ba89544a10c9f1f"
),
"omniglot_28x28": DownloadableModel(
url="https://pageperso.lis-lab.fr/~luc.giffon/models/1536764034.66037_vgg19_omniglot.h5",
checksum="ef1272e9c7ce070e8f70889ec58d1c33"
)
}
def __init__(self, data_name, cut_layer_name=None, cut_layer_index=None):
if data_name not in self.MAP_DATA_MODEL.keys():
raise ValueError("Unknown data name. Can't load weights")
if cut_layer_name is None and cut_layer_index is None:
logger.warning(
"Cut layer chosen automatically but it eventually will lead to an error in future: block5_pool should be specified explicitly")
cut_layer_name = "block5_pool"
if cut_layer_name is not None:
transformation_name = str(data_name) + "_" + self.__class__.__name__ + "_" + str(cut_layer_name)
elif cut_layer_index is not None:
transformation_name = str(data_name) + "_" + self.__class__.__name__ \
+ "_" + str(cut_layer_index)
# todo sauvegarder index / nom dans le meme dossier si c'est les meme
else:
raise AttributeError("Cut layer name or cut_layer index must be given to init VGG19Transformer.")
self.__cut_layer_name = cut_layer_name
self.__cut_layer_index = cut_layer_index
self.keras_model = None
super().__init__(data_name=data_name,
transformation_name=transformation_name)
def load(self):
create_directory(self.s_download_dir)
s_model_path = download_data(self.MAP_DATA_MODEL[self.data_name].url, self.s_download_dir)
check_file_md5(s_model_path, self.MAP_DATA_MODEL[self.data_name].checksum)
if self.keras_model is None:
logger.debug("Loading VGG19 model for {} transformation with {} weights".format(self.transformation_name, self.data_name))
self.keras_model = load_model(s_model_path)
logger.debug("Layers of model {}".format([l.name for l in self.keras_model.layers]))
if self.__cut_layer_index is not None:
cut_layer = self.keras_model.layers[-1]
self.__cut_layer_name = cut_layer.name
logger.debug(
"Found associated layer {} to layer index {}".format(self.__cut_layer_name, self.__cut_layer_index))
self.keras_model = Model(inputs=self.keras_model.input,
outputs=self.keras_model.get_layer(name=self.__cut_layer_name).output)
else:
logger.debug("Skip loading model VGG19 for {} transformation with {} weights. Already there.".format(
self.transformation_name,
self.data_name))
if __name__ == '__main__':
valsize = 10000
d = Cifar10Dataset(validation_size=valsize)
d.load()
d.to_image()
trans = VGG19Transformer(data_name="cifar10", cut_layer_name="block5_pool")
d.apply_transformer(transformer=trans)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment