# -*- coding: utf-8 -*-
"""Download pretrained VGG19 / LeNet weight files listed in the tables below.

Usage: ``python <this script> ARCHITECTURE WEIGHTS OUTPUT_DIRPATH`` where
ARCHITECTURE is a key of ``MAP_NAME_MAP`` (``vgg19``, ``lenet``) or ``all`` and
WEIGHTS is a model name such as ``cifar10`` or ``all``. Each file is downloaded
into ``OUTPUT_DIRPATH/<architecture>/<weights>/``; a relative OUTPUT_DIRPATH is
resolved against the project root, not the current working directory.
"""
from pathlib import Path

import click
from dotenv import find_dotenv, load_dotenv

from skluc.main.utils import logger, silentremove, download_data, check_file_md5, DownloadableModel

MAP_NAME_MODEL_VGG19 = {
    "svhn": DownloadableModel(
        url="https://pageperso.lis-lab.fr/~luc.giffon/models/1529968150.5454917_vgg19_svhn.h5",
        checksum="563a9ec2aad37459bd1ed0e329441b05"
    ),
    "cifar100": DownloadableModel(
        url="https://pageperso.lis-lab.fr/~luc.giffon/models/1530965727.781668_vgg19_cifar100fine.h5",
        checksum="edf43e263fec05e2c013dd5a2128fc38"
    ),
    "cifar10": DownloadableModel(
        url="https://pageperso.lis-lab.fr/~luc.giffon/models/1544802301.9379897_vgg19_Cifar10Dataset.h5",
        checksum="45714ca91dbbcc1904bea1a10cdcfc7a"
    ),
    "siamese_omniglot_28x28": DownloadableModel(
        url="https://pageperso.lis-lab.fr/~luc.giffon/models/1536244775.6502118_siamese_vgg19_omniglot_28x28_conv.h5",
        checksum="90aec06e688ec3248ba89544a10c9f1f"
    ),
    "omniglot_28x28": DownloadableModel(
        url="https://pageperso.lis-lab.fr/~luc.giffon/models/1536764034.66037_vgg19_omniglot_28x28.h5",
        checksum="ef1272e9c7ce070e8f70889ec58d1c33"
    )
}

MAP_NAME_MODEL_LENET = {
    "mnist": DownloadableModel(
        url="https://pageperso.lis-lab.fr/~luc.giffon/models/1524640419.938414_lenet_mnist.h5",
        checksum="527d7235c213278df1d15d3fe685eb5c"),
    "siamese_omniglot_28x28": DownloadableModel(
        url="https://pageperso.lis-lab.fr/~luc.giffon/models/1536239708.891906_siamese_lenet_omniglot_conv.h5",
        checksum="5092edcb0be7b31b808e221afcede3e6"
    ),
    "omniglot_28x28": DownloadableModel(
        url="https://pageperso.lis-lab.fr/~luc.giffon/models/1536750152.6389275_lenet_omniglot_28x28.h5",
        checksum="c4f20b6dae0722234e1ec0bee85e3a4d"
    )
}

# Architecture name -> (weights name -> DownloadableModel).
MAP_NAME_MAP = {
    "vgg19": MAP_NAME_MODEL_VGG19,
    "lenet": MAP_NAME_MODEL_LENET
}


def _get_project_dir():
    """Return the directory two levels above the one containing this file (the project root)."""
    return Path(__file__).resolve().parents[2]


def _download_all_models(output_dirpath, map_name_model):
    """Download every model of ``map_name_model``, each into ``output_dirpath/<weights name>``."""
    for key in map_name_model:
        _download_single_model(Path(output_dirpath) / key, key, map_name_model)


def _download_single_model(output_dirpath, weights, map_name_model):
    """Download one model file and verify its MD5 checksum.

    :param output_dirpath: target directory. A relative path is joined onto the
        project root (``_get_project_dir``); an absolute path is used as-is,
        because ``Path / absolute`` yields the absolute operand.
    :param weights: key of ``map_name_model`` selecting the model to fetch.
    :param map_name_model: mapping of weights name to ``DownloadableModel``.

    On a checksum mismatch the downloaded file is deleted and an error is
    logged; no exception propagates, so a batch download keeps going.
    """
    model = map_name_model[weights]
    # Resolved here rather than through a global set only in the ``__main__``
    # guard, so the function also works when this module is imported.
    output_path = _get_project_dir() / output_dirpath
    s_model_path = download_data(model.url, output_path)
    try:
        # NOTE(review): check_file_md5 is assumed to raise ValueError on mismatch,
        # as the original handler implies — confirm in skluc.main.utils.
        check_file_md5(s_model_path, model.checksum)
    except ValueError:
        silentremove(s_model_path)
        logger.error(f"Checksum mismatch for {s_model_path} (weights {weights}); file removed")


def _download_single_architecture_weight(output_dirpath, architecture, weights):
    """Download ``weights`` (a weights name or ``"all"``) of ``architecture`` into ``output_dirpath``."""
    map_name_model = MAP_NAME_MAP[architecture]
    if weights == "all":
        _download_all_models(output_dirpath, map_name_model)
    else:
        _download_single_model(Path(output_dirpath) / weights, weights, map_name_model)


def _download_all_architecture_weight(output_dirpath, weights):
    """Download ``weights`` for every architecture, each into ``output_dirpath/<architecture>``.

    When ``weights`` is a specific name, architectures that do not provide it
    are skipped (with an info log) instead of raising KeyError.
    """
    for architecture, map_name_model in MAP_NAME_MAP.items():
        if weights != "all" and weights not in map_name_model:
            logger.info(f"Architecture {architecture} has no weights {weights}: skipping")
            continue
        _download_single_architecture_weight(Path(output_dirpath) / architecture, architecture, weights)


@click.command()
@click.argument('architecture', default="all", type=click.Choice(["all", *MAP_NAME_MAP]))
@click.argument('weights', default="all")
@click.argument('output_dirpath', type=click.Path())
def main(weights, architecture, output_dirpath):
    """Download pretrained WEIGHTS of ARCHITECTURE into OUTPUT_DIRPATH.

    ARCHITECTURE is "vgg19", "lenet" or "all"; WEIGHTS is a model name
    (e.g. "cifar10", "mnist") or "all".
    """
    logger.info(f"Downloading architecture {architecture} with weights {weights} to save to {output_dirpath}")

    # Validate WEIGHTS up front so a typo yields a usage error, not a KeyError traceback.
    candidates = MAP_NAME_MAP.values() if architecture == "all" else [MAP_NAME_MAP[architecture]]
    if weights != "all" and not any(weights in map_name_model for map_name_model in candidates):
        raise click.BadParameter(
            f"no weights named {weights!r} for architecture {architecture!r}",
            param_hint="'WEIGHTS'",
        )

    if architecture == "all":
        _download_all_architecture_weight(output_dirpath, weights)
    else:
        _download_single_architecture_weight(Path(output_dirpath) / architecture, architecture, weights)


if __name__ == '__main__':
    # find .env automagically by walking up directories until it's found, then
    # load up the .env entries as environment variables
    load_dotenv(find_dotenv())

    main()