"""
Convnet with nystrom approximation of the feature map.

"""
import time as t

import tensorflow as tf
import numpy as np

import skluc.mldatasets as dataset
from skluc.neural_networks import get_next_batch, classification_mnist, convolution_mnist

tf.logging.set_verbosity(tf.logging.ERROR)

val_size = 5000


def _prepare_mnist(validation_size):
    """Load the skluc MNIST dataset and run its preparation steps.

    Calls, in this order: load, to_one_hot, normalize, then casts data and
    labels to np.float32 (exact semantics are defined in skluc.mldatasets).
    Returns the prepared dataset object.
    """
    data = dataset.MnistDataset(validation_size=validation_size)
    data.load()
    data.to_one_hot()
    data.normalize()
    data.data_astype(np.float32)
    data.labels_astype(np.float32)
    return data


mnist = _prepare_mnist(val_size)

# Each split is a (data, labels) pair.
X_train, Y_train = mnist.train
X_val, Y_val = mnist.validation
X_test, Y_test = mnist.test


def tf_rbf_kernel(X, Y, gamma):
    """Compute the RBF (Gaussian) kernel matrix between the rows of X and Y.

    K[i, j] = exp(-gamma * ||X[i] - Y[j]||^2), where the squared distances are
    obtained for all pairs at once via ||x - y||^2 = ||x||^2 - 2<x, y> + ||y||^2.

    :param X: 2-D tensor, one example per row (n rows).
    :param Y: 2-D tensor with the same number of columns as X (m rows).
    :param gamma: kernel bandwidth coefficient.
    :return: tensor of shape (n, m) holding the kernel values.
    """
    sq_norms_x = tf.reshape(tf.reduce_sum(X * X, axis=1), [-1, 1])  # column (n, 1)
    sq_norms_y = tf.reshape(tf.reduce_sum(Y * Y, axis=1), [1, -1])  # row (1, m)
    sq_dists = sq_norms_x - 2 * tf.matmul(X, tf.transpose(Y)) + sq_norms_y
    # Floating-point cancellation can make (near-)zero squared distances
    # slightly negative, which would yield kernel values > 1; clamp at 0.
    sq_dists = tf.maximum(sq_dists, 0.)
    return tf.exp(-gamma * sq_dists)


def nystrom_layer(input_x, input_subsample, gamma, output_dim):
    """Nystrom-style layer: RBF kernel against a landmark set, then a learned projection.

    Both ``input_x`` and ``input_subsample`` are flattened to 2-D, the RBF kernel
    between every input row and every landmark row is computed with
    ``tf_rbf_kernel``, and the resulting (batch, nystrom_sample_size) matrix is
    multiplied by a trainable variable ``W`` of shape
    (nystrom_sample_size, output_dim).

    :param input_x: tensor whose first axis is the batch axis; all other
        dimensions must be statically known.
    :param input_subsample: landmark tensor with the same per-example shape as
        ``input_x``; its first dimension is the number of landmarks.
    :param gamma: RBF bandwidth passed to ``tf_rbf_kernel``.
    :param output_dim: number of output features.
    :return: tensor of shape (batch, output_dim).
    :raises ValueError: if a non-batch dimension of ``input_x`` is unknown.
    """
    nystrom_sample_size = input_subsample.shape[0]
    feature_dims = [s.value for s in input_x.shape[1:]]
    # Silently skipping unknown dimensions would compute a wrong flattened size
    # and make the reshape below reinterpret the data; fail loudly instead.
    if any(d is None for d in feature_dims):
        raise ValueError(
            "nystrom_layer needs fully-defined non-batch dimensions, got shape {}".format(input_x.shape))
    init_dim = int(np.prod(feature_dims))

    with tf.name_scope("nystrom"):
        h_conv_flat = tf.reshape(input_x, [-1, init_dim])
        h_conv_nystrom_subsample_flat = tf.reshape(input_subsample, [nystrom_sample_size, init_dim])
        with tf.name_scope("kernel_vec"):
            kernel_vector = tf_rbf_kernel(h_conv_flat, h_conv_nystrom_subsample_flat, gamma=gamma)

        # sklearn's Nystroem maps the kernel vector through a normalisation
        # matrix derived from the landmarks' Gram matrix; an earlier variant of
        # this layer approximated it with learned factors D and V
        # (kernel_vector @ ((D * V) @ V^T)). A single learned matrix W is simpler.
        W = tf.get_variable("W", [nystrom_sample_size, output_dim],
                            initializer=tf.random_normal_initializer(stddev=0.1))
        out_fc = tf.matmul(kernel_vector, W)

    return out_fc


def main():
    """Train and evaluate the convnet + Nystrom model on MNIST.

    Uses the module-level X_train/Y_train, X_val/Y_val and X_test/Y_test
    arrays. NYSTROM_SAMPLE_SIZE training images drawn at random serve as the
    Nystrom landmark set; they go through the convolutional stack reusing the
    input branch's variables, and the Nystrom layer output feeds the
    classification head. Training runs 10000 Adam steps (learning rate 1e-4)
    on batches of 64, printing the loss and validation accuracy and writing
    TensorBoard summaries every 100 steps. At the end it prints the test
    accuracy, a sample of predictions and the training wall-clock time.
    """
    import sys  # sys.maxsize is numpy's recommended "never truncate" threshold

    NYSTROM_SAMPLE_SIZE = 100
    # Draw NYSTROM_SAMPLE_SIZE distinct rows uniformly from the whole training
    # set. np.random.permutation(NYSTROM_SAMPLE_SIZE) on its own only shuffles
    # indices 0..NYSTROM_SAMPLE_SIZE-1, i.e. it always picks the first rows.
    landmark_idx = np.random.permutation(X_train.shape[0])[:NYSTROM_SAMPLE_SIZE]
    X_nystrom = X_train[landmark_idx]
    GAMMA = 0.001
    print("Gamma = {}".format(GAMMA))

    with tf.Graph().as_default():
        input_dim, output_dim = X_train.shape[1], Y_train.shape[1]

        x = tf.placeholder(tf.float32, shape=[None, input_dim], name="x")
        # Landmarks live in a non-trainable variable so they are part of the graph.
        x_nystrom = tf.Variable(X_nystrom, name="nystrom_subsample", trainable=False)
        y_ = tf.placeholder(tf.float32, shape=[None, output_dim], name="labels")

        # Flattened inputs are assumed to be square single-channel images;
        # side_size is their width (== height).
        side_size = int(np.sqrt(input_dim))
        x_image = tf.reshape(x, [-1, side_size, side_size, 1])
        x_nystrom_image = tf.reshape(x_nystrom, [NYSTROM_SAMPLE_SIZE, side_size, side_size, 1])
        tf.summary.image("digit", x_image, max_outputs=3)

        # Representation layer: the landmark branch reuses the input branch's variables.
        with tf.variable_scope("convolution_mnist") as scope_conv_mnist:
            h_conv = convolution_mnist(x_image)
            scope_conv_mnist.reuse_variables()
            h_conv_nystrom_subsample = convolution_mnist(x_nystrom_image, trainable=False)

        out_fc = nystrom_layer(h_conv, h_conv_nystrom_subsample, GAMMA, NYSTROM_SAMPLE_SIZE)

        # keep_prob is fed 0.5 during training and 1.0 for evaluation below.
        y_conv, keep_prob = classification_mnist(out_fc, output_dim=output_dim)

        # Loss computation
        with tf.name_scope("xent"):
            cross_entropy = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv, name="xentropy"),
                name="xentropy_mean")
            tf.summary.scalar('loss-xent', cross_entropy)

        # Gradient / optimisation step
        with tf.name_scope("train"):
            global_step = tf.Variable(0, name="global_step", trainable=False)
            train_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(cross_entropy, global_step=global_step)

        # Accuracy computation
        with tf.name_scope("accuracy"):
            predictions = tf.argmax(y_conv, 1)
            correct_prediction = tf.equal(predictions, tf.argmax(y_, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
            tf.summary.scalar("accuracy", accuracy)

        merged_summary = tf.summary.merge_all()
        init = tf.global_variables_initializer()

        # The context manager closes the session (and frees its resources) on exit.
        with tf.Session() as sess:
            # NOTE(review): the log directory name looks inherited from another
            # experiment ("deepfried_stacked") — confirm it is intended.
            summary_writer = tf.summary.FileWriter("results_deepfried_stacked")
            try:
                summary_writer.add_graph(sess.graph)
                sess.run(init)
                started = t.time()
                feed_dict_val = {x: X_val, y_: Y_val, keep_prob: 1.0}
                for i in range(10000):
                    X_batch = get_next_batch(X_train, i, 64)
                    Y_batch = get_next_batch(Y_train, i, 64)
                    feed_dict = {x: X_batch, y_: Y_batch, keep_prob: 0.5}
                    # Running train_optimizer applies the gradient update; only
                    # the loss is needed from this step's outputs.
                    _, loss = sess.run([train_optimizer, cross_entropy], feed_dict=feed_dict)
                    if i % 100 == 0:
                        print('step {}, loss {} (with dropout)'.format(i, loss))
                        val_accuracy = sess.run(accuracy, feed_dict=feed_dict_val)
                        print("accuracy: {} on validation set (without dropout).".format(val_accuracy))
                        summary_str = sess.run(merged_summary, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, i)
                stopped = t.time()
            finally:
                # Flush pending summaries to disk even if training fails.
                summary_writer.close()

            # Bind to new names: rebinding ``accuracy`` would shadow the tensor.
            test_accuracy, preds = sess.run([accuracy, predictions], feed_dict={
                x: X_test, y_: Y_test, keep_prob: 1.0})

        print('test accuracy %g' % test_accuracy)
        # threshold=np.nan is rejected by modern numpy; sys.maxsize means "never summarise".
        np.set_printoptions(threshold=sys.maxsize)
        print("Prediction sample: " + str(preds[:50]))
        print("Actual values: " + str(np.argmax(Y_test[:50], axis=1)))
        print("Elapsed time: %.4f s" % (stopped - started))


if __name__ == "__main__":  # run the experiment only when executed as a script
    main()