import unittest
import tensorflow as tf
import numpy as np
from sklearn.metrics.pairwise import rbf_kernel

import skluc.mldatasets as dataset

from main.nystrom.nystrom_approx import tf_rbf_kernel


class TestNystrom(unittest.TestCase):
    """Tests for ``tf_rbf_kernel`` using scikit-learn's ``rbf_kernel`` as the reference."""

    # Fixed seed for the train/validation split so a failing run can be replayed.
    SEED = 0
    # Number of training examples moved into the validation split.
    VAL_SIZE = 5000
    # Number of validation examples actually kept; this bounds the size of the
    # kernel matrices computed in the tests.
    N_VAL_KEPT = 100
    # Maximum allowed norm of (expected - obtained) kernel values.
    DELTA = 1e-5

    def setUp(self):
        """Load MNIST, build float32 train/val/test splits and open a TF session."""
        mnist = dataset.MnistDataset().load()
        X_train, Y_train = mnist["train"]
        X_test, Y_test = mnist["test"]

        # Divide pixel values by 255 (presumably raw 0..255 intensities --
        # TODO confirm against MnistDataset) and store everything as float32,
        # labels included, so every split shares the same dtype.
        X_train = np.asarray(X_train / 255, dtype=np.float32)
        X_test = np.asarray(X_test / 255, dtype=np.float32)
        Y_train = np.asarray(Y_train, dtype=np.float32)
        Y_test = np.asarray(Y_test, dtype=np.float32)

        # A dedicated seeded RNG makes the split (and therefore any failure)
        # reproducible, and leaves the global NumPy RNG untouched.
        permut = np.random.RandomState(self.SEED).permutation(X_train.shape[0])
        val_idx = permut[:self.VAL_SIZE]
        train_idx = permut[self.VAL_SIZE:]

        # TODO: remove this truncation. Only the first N_VAL_KEPT validation
        # examples are kept; features and labels are truncated together so
        # they stay aligned.
        self.X_val = X_train[val_idx][:self.N_VAL_KEPT]
        self.Y_val = Y_train[val_idx][:self.N_VAL_KEPT]
        self.X_train = X_train[train_idx]
        self.Y_train = Y_train[train_idx]
        self.X_test = X_test
        self.Y_test = Y_test

        # Installed as the default session so that ``.eval()`` works below.
        self.sess = tf.InteractiveSession()

    def _assert_kernel_matches_sklearn(self, X, Y, gamma):
        """Assert that ``tf_rbf_kernel(X, Y)`` matches sklearn's ``rbf_kernel(X, Y)``.

        The check passes when the norm of the difference is at most ``DELTA``.
        """
        expected = rbf_kernel(X, Y, gamma=gamma)
        obtained = tf_rbf_kernel(X, Y, gamma=gamma).eval()
        # np.linalg.norm on a 2-D array is the Frobenius norm, so this bounds
        # all entries of the kernel matrix at once.
        difference = np.linalg.norm(expected - obtained)
        self.assertAlmostEqual(difference, 0, delta=self.DELTA)

    def test_tf_rbf_kernel(self):
        """Compare tf_rbf_kernel to sklearn on matrix/matrix, row/row and row/matrix inputs."""
        gamma = 0.01
        example1 = self.X_val[0].reshape((1, -1))
        example2 = self.X_val[1].reshape((1, -1))

        # Full kernel matrix of the validation subset with itself.
        self._assert_kernel_matches_sklearn(self.X_val, self.X_val, gamma)
        # A single pair of examples (1x1 kernel).
        self._assert_kernel_matches_sklearn(example1, example2, gamma)
        # One example against the whole subset (1xn kernel row).
        self._assert_kernel_matches_sklearn(example1, self.X_val, gamma)

    def tearDown(self):
        """Close the TF session and discard the graph built during the test."""
        self.sess.close()
        # tf_rbf_kernel returns tensors that are evaluated with ``.eval()``, so
        # it presumably adds ops to the default graph. Reset the graph (after
        # the session is closed) so graphs do not accumulate across tests.
        tf.reset_default_graph()


if __name__ == "__main__":
    # Executed directly as a script: discover and run this module's test cases.
    unittest.main()