"""
Benchmark VGG: benchmark deepstrom against other architectures based on the VGG network.

Usage:
    benchmark_vgg deepstrom [-r] [-a value] [-v size] [-e numepoch] [-s batchsize] [-D reprdim] [-m size] (-R|-L|-C|-E|-P|-S|-A|-T|-M) [-g gammavalue] [-c cvalue] [-n]

Options:
    --help -h                               Display help and exit.
    -e numepoch --num-epoch=numepoch        The number of training epochs.
    -s batchsize --batch-size=batchsize     The number of examples in each batch.
    -v size --validation-size size          The size of the validation set [default: 10000]
    -a value --seed value                   The seed value used for all randomization processes [default: 0]
    -D reprdim --out-dim=reprdim            The dimension of the final representation.
    -m size --nys-size size                 The number of examples in the Nystrom subsample.
    -n --non-linear                         Tell Nystrom to use a non-linear activation function on its output.
    -r --real-nystrom                       Use the real W matrix for Nystrom.
    -g gammavalue --gamma gammavalue        The value of gamma for the RBF, chi-square or hyperbolic tangent kernel (deepstrom and deepfriedconvnet).
    -c cvalue --intercept-constant cvalue   The value of the intercept constant for the hyperbolic tangent kernel.
    -R --rbf-kernel                         Use the RBF kernel for Nystrom.
    -L --linear-kernel                      Use the linear kernel for Nystrom.
    -C --chi-square-kernel                  Use the basic additive chi-square kernel for Nystrom.
    -E --exp-chi-square-kernel              Use the exponential chi-square kernel for Nystrom.
    -P --chi-square-PD-kernel               Use the positive definite version of the basic additive chi-square kernel for Nystrom.
    -S --sigmoid-kernel                     Use the sigmoid kernel for Nystrom.
    -A --laplacian-kernel                   Use the laplacian kernel for Nystrom.
    -T --stacked-kernel                     Use the laplacian, chi2 and RBF kernels in a stacked setting for Nystrom.
    -M --sumed-kernel                       Use the laplacian, chi2 and RBF kernels in a summed setting for Nystrom.
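
Example:
    An illustrative invocation (the hyperparameter values below are arbitrary and only
    show the expected syntax; any single kernel flag can replace -R):
        benchmark_vgg deepstrom -e 100 -s 128 -D 64 -m 64 -R -g 0.01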
"""
import sys
import os
import time as t
import numpy as np
import tensorflow as tf
import docopt
from keras import Model
from keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics.pairwise import rbf_kernel, linear_kernel, additive_chi2_kernel, chi2_kernel, laplacian_kernel
import skluc.data.mldatasets as dataset
from skluc.data.transformation import VGG19Cifar10Transformer
from skluc.tensorflow_.kernel_approximation import nystrom_layer, fastfood_layer
from skluc.tensorflow_.utils import fully_connected, batch_generator, classification_cifar, conv_relu_pool, conv2d, \
    max_pool
from skluc.tensorflow_.kernel import tf_rbf_kernel, tf_linear_kernel, tf_chi_square_CPD, tf_chi_square_CPD_exp, \
    tf_chi_square_PD, tf_sigmoid_kernel, tf_laplacian_kernel, tf_stack_of_kernels, tf_sum_of_kernels
from skluc.utils import logger, log_memory_usage
import keras
from keras.models import Sequential, load_model
from keras.layers import Activation
from keras.layers import Conv2D, MaxPooling2D
from keras.initializers import he_normal
from keras.layers.normalization import BatchNormalization


def VGG19(input_shape):
    # with tf.variable_scope("block1_conv1"):
    #     weights = tf.get_variable("weights", (3, 3, 3, 64), initializer=tf.random_normal_initializer(stddev=0.1), trainable=trainable)
    #     biases = tf.get_variable("biases", (64), initializer=tf.constant_initializer(0.0), trainable=trainable)
    #     regularizer = tf.contrib.layers.l2_regularizer(scale=0.1)
    #     conv = tf.nn.conv2d(input_, weights, strides=[1, 1, 1, 1], padding='SAME', kernel_regularizer=regularizer)
    #     batch_norm = tf.nn.batch_normalization(conv, variance_epsilon=1e-3)
    #     relu = tf.nn.relu(conv + biases)
    #     tf.summary.histogram("act", relu)
    #     in order to reduce dimensionality, use bigger pooling size
    #     pool = max_pool(relu, pool_size=pool_size)
    # with tf.variable_scope("conv_pool_2"):
    #     conv2 = conv_relu_pool(conv1, [5, 5, 6, 16], [16], pool_size=2, trainable=trainable)
    weight_decay = 0.0001
    # build model
    model = Sequential()

    # Block 1
    model.add(Conv2D(64, (3, 3), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer=he_normal(), name='block1_conv1', input_shape=input_shape))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Conv2D(64, (3, 3), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer=he_normal(), name='block1_conv2'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool'))

    # Block 2
    model.add(Conv2D(128, (3, 3), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer=he_normal(), name='block2_conv1'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Conv2D(128, (3, 3), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer=he_normal(), name='block2_conv2'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool'))

    # Block 3
    model.add(Conv2D(256, (3, 3), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer=he_normal(), name='block3_conv1'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Conv2D(256, (3, 3), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer=he_normal(), name='block3_conv2'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Conv2D(256, (3, 3), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer=he_normal(), name='block3_conv3'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Conv2D(256, (3, 3), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer=he_normal(), name='block3_conv4'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool'))

    # Block 4
    model.add(Conv2D(512, (3, 3), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer=he_normal(), name='block4_conv1'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Conv2D(512, (3, 3), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer=he_normal(), name='block4_conv2'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Conv2D(512, (3, 3), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer=he_normal(), name='block4_conv3'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Conv2D(512, (3, 3), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer=he_normal(), name='block4_conv4'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool'))

    # Block 5
    model.add(Conv2D(512, (3, 3), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer=he_normal(), name='block5_conv1'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Conv2D(512, (3, 3), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer=he_normal(), name='block5_conv2'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Conv2D(512, (3, 3), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer=he_normal(), name='block5_conv3'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Conv2D(512, (3, 3), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer=he_normal(), name='block5_conv4'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool'))

    return model


def VGG19_preload():
    logger.debug("filename: {}".format(os.path.abspath(__file__)))
    model = load_model(os.path.join(os.path.dirname(os.path.abspath(__file__)), "1522967518.1916964_vgg19_cifar10.h5"))
    vgg_conv_model = Model(inputs=model.input,
                           outputs=model.get_layer('block5_pool').output)
    return vgg_conv_model


def fct_deepstrom(input_, out_dim, subsample, kernel, kernel_params, w_matrix, non_linearity):
    """
    Wrap the computing of the deepstrom layer

    :param input_:
    :param out_dim:
    :param subsample:
    :param kernel:
    :param kernel_params:
    :return:
    """
    out_fc = nystrom_layer(input_, subsample, W_matrix=w_matrix, output_dim=out_dim, kernel=kernel, output_act=non_linearity, **kernel_params)
    return out_fc
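
# Illustrative usage sketch (hypothetical tensors and an arbitrary gamma, not executed
# here): given a convolutional output tensor conv_x and a subsample tensor
# conv_nys_subsample with matching feature dimensions, a call such as
#   fct_deepstrom(conv_x, OUT_DIM, conv_nys_subsample, tf_rbf_kernel, {"gamma": 0.01},
#                 w_matrix=None, non_linearity=tf.nn.relu)
# yields an OUT_DIM-dimensional representation of conv_x.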


if __name__ == '__main__':

    arguments = docopt.docopt(__doc__)
    NUM_EPOCH = int(arguments["--num-epoch"])
    BATCH_SIZE = int(arguments["--batch-size"])
    SEED_TRAIN_VALIDATION = 0
    SEED = int(arguments["--seed"])
    OUT_DIM = int(arguments["--out-dim"]) if arguments["--out-dim"] is not None else None
    VALIDATION_SIZE = int(arguments["--validation-size"])
    NYS_SUBSAMPLE_SIZE = int(arguments["--nys-size"])
    if OUT_DIM is None:
        OUT_DIM = NYS_SUBSAMPLE_SIZE
    KERNEL_NAME = None
    GAMMA = None
    CONST = None
    REAL_NYSTROM = arguments["--real-nystrom"]

    NON_LINEAR = tf.nn.relu if arguments["--non-linear"] else None

    RBF_KERNEL = arguments["--rbf-kernel"]
    LINEAR_KERNEL = arguments["--linear-kernel"]
    CHI2_KERNEL = arguments["--chi-square-kernel"]
    CHI2_EXP_KERNEL = arguments["--exp-chi-square-kernel"]
    CHI2_PD_KERNEL = arguments["--chi-square-PD-kernel"]
    SIGMOID_KERNEL = arguments["--sigmoid-kernel"]
    LAPLACIAN_KERNEL = arguments["--laplacian-kernel"]
    STACKED_KERNEL = arguments["--stacked-kernel"]
    SUMED_KERNEL = arguments["--sumed-kernel"]

    kernel_dict = {}

    data = dataset.Cifar10Dataset(validation_size=VALIDATION_SIZE, seed=SEED_TRAIN_VALIDATION)
    data.load()
    data.normalize()
    data.data_astype(np.float32)
    data.labels_astype(np.float32)
    data.to_image()
    data.to_one_hot()

    logger.debug("Start benchmark with parameters: {}".format(" ".join(sys.argv[1:])))
    logger.debug("Using dataset {} with validation size {} and seed for spliting set {}.".format(data.s_name, data.validation_size, data.seed))
    logger.debug("Shape of train set data: {}; shape of train set labels: {}".format(data.train[0].shape, data.train[1].shape))
    logger.debug("Shape of validation set data: {}; shape of validation set labels: {}".format(data.validation[0].shape, data.validation[1].shape))
    logger.debug("Shape of test set data: {}; shape of test set labels: {}".format(data.test[0].shape, data.test[1].shape))
    logger.debug("Sample of label: {}".format(data.train[1][0]))

    if RBF_KERNEL:
        KERNEL = tf_rbf_kernel
        KERNEL_NAME = "rbf"
        GAMMA = float(arguments["--gamma"])
        kernel_dict = {"gamma": GAMMA}
    elif LINEAR_KERNEL:
        KERNEL = tf_linear_kernel
        KERNEL_NAME = "linear"
    elif CHI2_KERNEL:
        KERNEL = tf_chi_square_CPD
        KERNEL_NAME = "chi2_cpd"
    elif CHI2_EXP_KERNEL:
        KERNEL = tf_chi_square_CPD_exp
        KERNEL_NAME = "chi2_exp_cpd"
        GAMMA = float(arguments["--gamma"])
        kernel_dict = {"gamma": GAMMA}
    elif CHI2_PD_KERNEL:
        KERNEL = tf_chi_square_PD
        KERNEL_NAME = "chi2_pd"
    elif SIGMOID_KERNEL:
        KERNEL = tf_sigmoid_kernel
        KERNEL_NAME = "sigmoid"
        GAMMA = float(arguments["--gamma"])
        CONST = float(arguments["--intercept-constant"])
        kernel_dict = {"gamma": GAMMA, "constant": CONST}
    elif LAPLACIAN_KERNEL:
        KERNEL = tf_laplacian_kernel
        KERNEL_NAME = "laplacian"
        GAMMA = float(arguments["--gamma"])
        kernel_dict = {"gamma": np.sqrt(GAMMA)}
    elif STACKED_KERNEL:
        # TODO: the stacked kernel setting does not work yet
        GAMMA = float(arguments["--gamma"])

        def KERNEL(X, Y):
            return tf_stack_of_kernels(X, Y,
                                       [tf_laplacian_kernel, tf_rbf_kernel, tf_chi_square_CPD],
                                       [{"gamma": GAMMA}, {"gamma": GAMMA}, {}])
        KERNEL_NAME = "stacked"
    elif SUMED_KERNEL:
        GAMMA = float(arguments["--gamma"])

        def KERNEL(X, Y):
            return tf_sum_of_kernels(X, Y,
                                     [tf_laplacian_kernel, tf_rbf_kernel, tf_chi_square_CPD],
                                     [{"gamma": GAMMA}, {"gamma": GAMMA}, {}])
        KERNEL_NAME = "summed"
    else:
        raise Exception("No kernel function specified for deepstrom")

    input_dim, output_dim = data.train[0].shape[1:], data.train[1].shape[1]
    with tf.Graph().as_default():
        np.random.seed(SEED)
        nys_subsample_index = np.random.permutation(data.train[0].shape[0])
        nys_subsample = data.train[0][nys_subsample_index[:NYS_SUBSAMPLE_SIZE]]

        nys_subsample_placeholder = tf.Variable(nys_subsample, dtype=tf.float32, name="nys_subsample", trainable=False)

        x = tf.placeholder(tf.float32, shape=[None, *input_dim], name="x")
        y = tf.placeholder(tf.float32, shape=[None, output_dim], name="label")
        # nys_subsample_placeholder = tf.placeholder(tf.float32, shape=[NYS_SUBSAMPLE_SIZE, *input_dim], name="nys_subsample")

        # vgg_conv_model = VGG19_preload()
        with tf.variable_scope("Convolution") as scope_convolution:
            vgg_conv_model = VGG19(input_dim)
            vgg_conv_model.trainable = False
            conv_x = vgg_conv_model(x)
            tf.summary.histogram("convolution_x", conv_x)
            vgg_conv_model_subsample = keras.Model(inputs=vgg_conv_model.inputs,
                                                   outputs=vgg_conv_model.outputs)
            vgg_conv_model_subsample.trainable = False
            conv_nys_subsample = vgg_conv_model_subsample(nys_subsample_placeholder)

        logger.debug("Selecting deepstrom layer function with "
                     "subsample size = {}, "
                     "output_dim = {}, "
                     "{} activation function "
                     "and kernel = {}"
                     .format(NYS_SUBSAMPLE_SIZE,
                             OUT_DIM,
                             "with" if NON_LINEAR else "without",
                             KERNEL_NAME))
        if OUT_DIM is not None and OUT_DIM > NYS_SUBSAMPLE_SIZE:
            logger.debug("Output dim is greater than deepstrom subsample size. Aborting.")
            # todo change this because it is copy-pasted (use function instead)

            global_acc_val = None
            global_acc_test = None
            training_time = None
            printed_r_list = [str(global_acc_val),
                              str(global_acc_test),
                              str(training_time),
                              str(NUM_EPOCH),
                              str(BATCH_SIZE),
                              str(OUT_DIM),
                              str(KERNEL_NAME),
                              str(GAMMA),
                              str(CONST),
                              str(NYS_SUBSAMPLE_SIZE),
                              str(VALIDATION_SIZE),
                              str(SEED),
                              str(NON_LINEAR),
                              ]
            print(",".join(printed_r_list))
            sys.exit(0)
        w_matrix = None
        if REAL_NYSTROM:
            init_dim = np.prod([s.value for s in conv_x.shape[1:] if s.value is not None])
            h_conv_nystrom_subsample_flat = tf.reshape(conv_nys_subsample, [conv_nys_subsample.shape[0], init_dim])

            K_matrix = KERNEL(h_conv_nystrom_subsample_flat, h_conv_nystrom_subsample_flat, **kernel_dict)
            # W = K^(-1/2): inverse square root of the subsample kernel matrix, obtained via SVD
            S, U, V = tf.svd(K_matrix)
            invert_root_K = tf.matmul(tf.matmul(U, tf.diag(1. / tf.sqrt(S))), tf.transpose(V))
            w_matrix = invert_root_K

        input_classif = fct_deepstrom(conv_x, OUT_DIM, conv_nys_subsample, KERNEL, kernel_dict, w_matrix=w_matrix, non_linearity=NON_LINEAR)

        classif, keep_prob = classification_cifar(input_classif, output_dim)

        # compute the loss
        with tf.name_scope("xent"):
            cross_entropy = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=classif, name="xentropy"),
                name="xentropy_mean")
            tf.summary.scalar('loss-xent', cross_entropy)

        # todo learning rate as hyperparameter
        # compute the gradient / optimization step
        with tf.name_scope("train"):
            global_step = tf.Variable(0, name="global_step", trainable=False)
            train_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(cross_entropy,
                                                                                  global_step=global_step)

        # compute the accuracy
        with tf.name_scope("accuracy"):
            predictions = tf.argmax(classif, 1)
            correct_prediction = tf.equal(predictions, tf.argmax(y, 1))
            accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
            tf.summary.scalar("accuracy", accuracy_op)

        merged_summary = tf.summary.merge_all()

        init = tf.global_variables_initializer()
        # Create a session for running Ops on the Graph.
        # Instantiate a SummaryWriter to output summaries and the Graph.
        # summary_writer = tf.summary.FileWriter("debug_benchmark_vgg")
        saver = tf.train.Saver()

        with tf.Session() as sess:
            logger.debug("trainable variables are: {}".format(tf.trainable_variables()))
            # summary_writer.add_graph(sess.graph)
            datagen = ImageDataGenerator(horizontal_flip=True,
                                         width_shift_range=0.125,
                                         height_shift_range=0.125,
                                         fill_mode='constant',
                                         cval=0.)
            datagen.fit(data.train[0])
            sess.run(init)
            # actual learning
            # feed_dict_val = {x: data.validation[0], y: data.validation[1], keep_prob: 1.0}
            global_start = t.time()
            feed_dict = {nys_subsample_placeholder: nys_subsample}
            feed_dict_val = {nys_subsample_placeholder: nys_subsample}
            feed_dict_test = {nys_subsample_placeholder: nys_subsample}
            start_time_int = int(t.time())
            for i in range(NUM_EPOCH):
                saver.save(sess, os.path.abspath('end_to_end_model'), global_step=start_time_int)
                start = t.time()
                # for X_batch, Y_batch in batch_generator(data.train[0], data.train[1], BATCH_SIZE, True):
                batchgen = datagen.flow(data.train[0], data.train[1], BATCH_SIZE, shuffle=False)
                j = 0
                log_memory_usage()
                while j < len(batchgen):
                    X_batch, Y_batch = next(batchgen)
                    # batch_generator(data.train[0], data.train[1], BATCH_SIZE, True):
                    # X_batch = tf.map_fn(lambda img: datagen.random_transform(img), X_batch)
                    feed_dict.update({x: X_batch, y: Y_batch, keep_prob: 0.5})
                    _, loss, acc = sess.run([train_optimizer, cross_entropy, accuracy_op], feed_dict=feed_dict)
                    if j % 100 == 0:
                        # summary_str = sess.run(merged_summary, feed_dict=feed_dict)
                        # summary_writer.add_summary(summary_str, j)
                        logger.debug("epoch: {}/{}; batch: {}/{}; loss: {}; acc: {}".format(i, NUM_EPOCH,
                                                                                            j, int(data.train[0].shape[0]/BATCH_SIZE),
                                                                                            loss, acc))
                    j += 1

            training_time = t.time() - global_start
            accuracies_val = []
            i = 0
            for X_batch, Y_batch in batch_generator(data.validation[0], data.validation[1], 1000, False):
                feed_dict_val.update({x: X_batch, y: Y_batch, keep_prob: 1.0})
                accuracy = sess.run([accuracy_op], feed_dict=feed_dict_val)
                accuracies_val.append(accuracy[0])
                i += 1

            accuracies_test = []
            i = 0
            for X_batch, Y_batch in batch_generator(data.test[0], data.test[1], 1000, False):
                feed_dict_test.update({x: X_batch, y: Y_batch, keep_prob: 1.0})
                accuracy = sess.run([accuracy_op], feed_dict=feed_dict_test)
                accuracies_test.append(accuracy[0])
                i += 1

        global_acc_val = sum(accuracies_val) / len(accuracies_val)
        global_acc_test = sum(accuracies_test) / len(accuracies_test)
        printed_r_list = [str(global_acc_val),
                          str(global_acc_test),
                          str(training_time),
                          str(NUM_EPOCH),
                          str(BATCH_SIZE),
                          str(OUT_DIM),
                          str(KERNEL_NAME),
                          str(GAMMA),
                          str(CONST),
                          str(NYS_SUBSAMPLE_SIZE),
                          str(VALIDATION_SIZE),
                          str(SEED),
                          str(NON_LINEAR),
                          ]
        print(",".join(printed_r_list))