# -*- coding: utf-8 -*-
import click
from pathlib import Path
from dotenv import find_dotenv, load_dotenv
from skluc.main.utils import logger, silentremove, download_data, check_file_md5, DownloadableModel

# Downloadable pretrained VGG19 weight files (.h5), keyed by a short name that
# the URLs suggest is the training dataset. Each entry pairs a download URL
# with the MD5 checksum verified (via check_file_md5) after download.
MAP_NAME_MODEL_VGG19 = {
    "svhn": DownloadableModel(
        url="https://pageperso.lis-lab.fr/~luc.giffon/models/1529968150.5454917_vgg19_svhn.h5",
        checksum="563a9ec2aad37459bd1ed0e329441b05"
    ),
    # NOTE(review): the file name says "cifar100fine", presumably the
    # fine-grained (100-class) CIFAR-100 labels — confirm.
    "cifar100": DownloadableModel(
        url="https://pageperso.lis-lab.fr/~luc.giffon/models/1530965727.781668_vgg19_cifar100fine.h5",
        checksum="edf43e263fec05e2c013dd5a2128fc38"
    ),
    "cifar10": DownloadableModel(
        url="https://pageperso.lis-lab.fr/~luc.giffon/models/1544802301.9379897_vgg19_Cifar10Dataset.h5",
        checksum="45714ca91dbbcc1904bea1a10cdcfc7a"
    ),
    # File name suggests the convolutional part of a siamese VGG19 network
    # trained on 28x28 Omniglot images — TODO confirm.
    "siamese_omniglot_28x28": DownloadableModel(
        url="https://pageperso.lis-lab.fr/~luc.giffon/models/1536244775.6502118_siamese_vgg19_omniglot_28x28_conv.h5",
        checksum="90aec06e688ec3248ba89544a10c9f1f"
    ),
    "omniglot_28x28": DownloadableModel(
        url="https://pageperso.lis-lab.fr/~luc.giffon/models/1536764034.66037_vgg19_omniglot_28x28.h5",
        checksum="ef1272e9c7ce070e8f70889ec58d1c33"
    )
}


# Downloadable pretrained LeNet weight files (.h5), keyed by a short name that
# the URLs suggest is the training dataset. Each entry pairs a download URL
# with the MD5 checksum verified (via check_file_md5) after download.
MAP_NAME_MODEL_LENET = {
    "mnist": DownloadableModel(
        url="https://pageperso.lis-lab.fr/~luc.giffon/models/1524640419.938414_lenet_mnist.h5",
        checksum="527d7235c213278df1d15d3fe685eb5c"),
    # File name suggests the convolutional part of a siamese LeNet network
    # trained on Omniglot — TODO confirm.
    "siamese_omniglot_28x28": DownloadableModel(
        url="https://pageperso.lis-lab.fr/~luc.giffon/models/1536239708.891906_siamese_lenet_omniglot_conv.h5",
        checksum="5092edcb0be7b31b808e221afcede3e6"
    ),
    "omniglot_28x28": DownloadableModel(
        url="https://pageperso.lis-lab.fr/~luc.giffon/models/1536750152.6389275_lenet_omniglot_28x28.h5",
        checksum="c4f20b6dae0722234e1ec0bee85e3a4d"
    )
}



# Architecture name -> mapping of weight-set name -> DownloadableModel.
# These keys are the values accepted for the ARCHITECTURE CLI argument.
MAP_NAME_MAP = dict(
    vgg19=MAP_NAME_MODEL_VGG19,
    lenet=MAP_NAME_MODEL_LENET,
)

def _download_all_models(output_dirpath, map_name_model):
    """Download every model listed in ``map_name_model``.

    Each model is saved under its own sub-directory ``output_dirpath / <name>``.

    :param output_dirpath: Base directory for one architecture's weights.
    :param map_name_model: Mapping of weight-set name -> ``DownloadableModel``.
    """
    base_dir = Path(output_dirpath)
    # Only the names are needed here; _download_single_model looks the
    # DownloadableModel up itself.
    for weights in map_name_model:
        _download_single_model(base_dir / weights, weights, map_name_model)


def _download_single_model(output_dirpath, weights, map_name_model):
    """Download one model file and verify its MD5 checksum.

    If verification fails, the downloaded file is removed and an error is
    logged. The exception is not propagated, so a bulk download continues
    with the remaining models.

    :param output_dirpath: Destination directory. A relative path is resolved
        against the project root, ``Path(__file__).resolve().parents[2]``. An
        absolute path is used as-is.
    :param weights: Key of the model to fetch in ``map_name_model``.
    :param map_name_model: Mapping of weight-set name -> ``DownloadableModel``.
    :raises KeyError: if ``weights`` is not a key of ``map_name_model``.
    """
    # Resolve the project root here instead of reading the ``project_dir``
    # global, which only exists when this module is executed as a script.
    project_root = Path(__file__).resolve().parents[2]
    output_path = project_root / output_dirpath
    model = map_name_model[weights]
    s_model_path = download_data(model.url, output_path)
    try:
        check_file_md5(s_model_path, model.checksum)
    except ValueError:
        # check_file_md5 is assumed to raise ValueError on a mismatch; drop the
        # bad file so it is not mistaken for a valid download, and say so.
        logger.error(f"Checksum verification failed for {weights} ({s_model_path}); file removed")
        silentremove(s_model_path)

def _download_single_architecture_weight(output_dirpath, architecture, weights):
    """Download one weight set, or all of them, for a single architecture.

    :param output_dirpath: Directory holding this architecture's weights.
    :param architecture: Key of ``MAP_NAME_MAP`` (e.g. ``"vgg19"``).
    :param weights: Weight-set name within that architecture, or ``"all"``.
    """
    if weights == "all":
        _download_all_models(output_dirpath, MAP_NAME_MAP[architecture])
        return

    # A single named weight set is stored in its own sub-directory.
    target_dir = Path(output_dirpath) / weights
    _download_single_model(target_dir, weights, MAP_NAME_MAP[architecture])

def _download_all_architecture_weight(output_dirpath, weights):
    """Download ``weights`` for every architecture in ``MAP_NAME_MAP``.

    Each architecture's files go under ``output_dirpath / <architecture>``.

    :param output_dirpath: Base output directory.
    :param weights: Weight-set name to fetch for every architecture, or
        ``"all"``. Architectures that do not provide a specifically named
        weight set are skipped with a warning rather than aborting the run.
    """
    for architecture, map_name_model in MAP_NAME_MAP.items():
        if weights != "all" and weights not in map_name_model:
            logger.warning(f"Architecture {architecture} has no weights named {weights}; skipping")
            continue
        # Arguments must follow the callee's (output_dirpath, architecture, weights) order.
        _download_single_architecture_weight(Path(output_dirpath) / architecture, architecture, weights)


# NOTE(review): ARCHITECTURE and WEIGHTS declare defaults, but click assigns
# positional arguments left to right and OUTPUT_DIRPATH is required, so in
# practice all three arguments must be given on the command line.
@click.command()
@click.argument('architecture', default="all", type=click.Choice(["all", *MAP_NAME_MAP]))
@click.argument('weights', default="all")
@click.argument('output_dirpath', type=click.Path())
def main(weights, architecture, output_dirpath):
    """Download pretrained model weights into OUTPUT_DIRPATH.

    ARCHITECTURE is "all" or one of the known architectures (vgg19, lenet).
    WEIGHTS is "all" or the name of one weight set of that architecture
    (e.g. mnist, cifar10).
    """
    logger.info(f"Downloading architecture {architecture} with weights {weights} to save to {output_dirpath}")

    if architecture == "all":
        _download_all_architecture_weight(output_dirpath, weights)
        return

    # Reject unknown weight names with a usage error instead of a KeyError traceback.
    known_weights = MAP_NAME_MAP[architecture]
    if weights != "all" and weights not in known_weights:
        raise click.BadParameter(
            f"{weights!r} is not available for {architecture}; "
            f"choose 'all' or one of: {', '.join(known_weights)}",
            param_hint="weights",
        )
    _download_single_architecture_weight(Path(output_dirpath) / architecture, architecture, weights)

if __name__ == '__main__':
    # Project root: two directory levels above the directory containing this
    # file. Kept as a module-level global because helpers in this module may
    # resolve relative output paths against it.
    project_dir = Path(__file__).resolve().parents[2]

    # find .env automagically by walking up directories until it's found, then
    # load up the .env entries as environment variables
    load_dotenv(find_dotenv())

    main()