import os
import unittest

import h5py
import numpy as np

from ...mono_multi_view_classifiers.utils import get_multiview_db
from ..utils import rm_tmp, tmp_path, test_dataset

class Test_get_classic_db_hdf5(unittest.TestCase):
    """Tests for ``get_multiview_db.get_classic_db_hdf5``.

    Each test runs against a freshly written HDF5 file stored in
    ``tmp_path``. The file holds three integer views of shape (5, 7),
    a ``Labels`` dataset drawn from three classes and a ``Metadata`` group.
    """

    def setUp(self):
        """Recreate the temporary directory and write the fixture file."""
        rm_tmp()
        os.mkdir(tmp_path)
        self.rs = np.random.RandomState(42)
        self.nb_view = 3
        self.file_name = "test.hdf5"
        self.nb_examples = 5
        self.nb_class = 3
        self.views = [self.rs.randint(0, 10, size=(self.nb_examples, 7))
                      for _ in range(self.nb_view)]
        self.labels = self.rs.randint(0, self.nb_class, self.nb_examples)
        # The mode is given explicitly: h5py >= 3.0 opens files read-only
        # when no mode is passed, which fails for a file that does not
        # exist yet. The directory was just recreated, so "w" never
        # truncates anything.
        self.dataset_file = h5py.File(
            os.path.join(tmp_path, self.file_name), "w")
        self.view_names = ["ViewN" + str(index) for index in
                           range(len(self.views))]
        self.are_sparse = [False for _ in self.views]
        for view_index, (view_name, view, is_sparse) in enumerate(
                zip(self.view_names, self.views, self.are_sparse)):
            # Datasets are stored as "View<i>"; the name requested by the
            # tests ("ViewN<i>") lives in the "name" attribute.
            view_dataset = self.dataset_file.create_dataset(
                "View" + str(view_index),
                view.shape,
                data=view)
            view_dataset.attrs["name"] = view_name
            view_dataset.attrs["sparse"] = is_sparse
        labels_dataset = self.dataset_file.create_dataset(
            "Labels",
            shape=self.labels.shape,
            data=self.labels)
        self.labels_names = [str(index) for index in np.unique(self.labels)]
        labels_dataset.attrs["names"] = [label_name.encode()
                                         for label_name in self.labels_names]
        meta_data_grp = self.dataset_file.create_group("Metadata")
        meta_data_grp.attrs["nbView"] = len(self.views)
        meta_data_grp.attrs["nbClass"] = len(np.unique(self.labels))
        meta_data_grp.attrs["datasetLength"] = len(self.labels)

    def test_simple(self):
        """Asking for a single view keeps every example and every class."""
        dataset, labels_dictionary, dataset_name = get_multiview_db.get_classic_db_hdf5(
            ["ViewN2"], tmp_path, self.file_name.split(".")[0],
            self.nb_class, ["0", "2"],
            self.rs, path_for_new=tmp_path)
        self.assertEqual(dataset.nb_view, 1)
        self.assertEqual(labels_dictionary,
                         {0: "0", 1: "2", 2: "1"})
        self.assertEqual(dataset.get_nb_examples(), 5)
        self.assertEqual(len(np.unique(dataset.get_labels())), 3)

    def test_all_views_asked(self):
        """Passing ``None`` as the view list selects all three views."""
        dataset, labels_dictionary, dataset_name = get_multiview_db.get_classic_db_hdf5(
            None, tmp_path, self.file_name.split(".")[0],
            self.nb_class, ["0", "2"],
            self.rs, path_for_new=tmp_path)
        self.assertEqual(dataset.nb_view, 3)
        self.assertEqual(dataset.get_view_dict(),
                         {'ViewN0': 0, 'ViewN1': 1, 'ViewN2': 2})

    def test_asked_the_whole_dataset(self):
        """With ``full=True`` the returned dataset wraps the original file."""
        dataset, labels_dictionary, dataset_name = get_multiview_db.get_classic_db_hdf5(
            ["ViewN2"], tmp_path, self.file_name.split(".")[0],
            self.nb_class, ["0", "2"],
            self.rs, path_for_new=tmp_path, full=True)
        self.assertEqual(dataset.dataset, self.dataset_file)

    def tearDown(self):
        """Release the fixture file handle, then delete the temporary files."""
        # Close before deleting so no open handle outlives the directory.
        self.dataset_file.close()
        rm_tmp()


class Test_get_classic_db_csv(unittest.TestCase):
    """Tests for ``get_multiview_db.get_classic_db_csv``.

    The fixture written under ``tmp_path`` consists of a label-names file
    (four names), a labels file (ten random labels in [0, 4)) and a
    ``Views`` directory holding four CSV views of shape (10, 20).
    """

    def setUp(self):
        """Recreate the temporary directory and write the CSV fixture."""
        rm_tmp()
        os.mkdir(tmp_path)
        # NOTE(review): paths below are built by plain concatenation, so
        # tmp_path presumably ends with a path separator - confirm.
        self.pathF = tmp_path
        self.NB_CLASS = 2
        self.nameDB = "test_dataset"
        self.askedLabelsNames = ["test_label_1", "test_label_3"]
        self.random_state = np.random.RandomState(42)
        self.views = ["test_view_1", "test_view_3"]
        np.savetxt(self.pathF + self.nameDB + "-labels-names.csv",
                   np.array(["test_label_0", "test_label_1",
                             "test_label_2", "test_label_3"]), fmt="%s",
                   delimiter=",")
        np.savetxt(self.pathF + self.nameDB + "-labels.csv",
                   self.random_state.randint(0, 4, 10), delimiter=",")
        os.mkdir(self.pathF + "Views")
        self.datas = []
        for i in range(4):
            data = self.random_state.randint(0, 100, (10, 20))
            np.savetxt(self.pathF + "Views/test_view_" + str(i) + ".csv",
                       data, delimiter=",")
            self.datas.append(data)

    def test_simple(self):
        """Two requested views and two requested labels are loaded."""
        dataset, labels_dictionary, dataset_name = get_multiview_db.get_classic_db_csv(
            self.views, self.pathF, self.nameDB,
            self.NB_CLASS, self.askedLabelsNames,
            self.random_state, delimiter=",", path_for_new=tmp_path)
        self.assertEqual(dataset.nb_view, 2)
        self.assertEqual(dataset.get_view_dict(),
                         {'test_view_1': 0, 'test_view_3': 1})
        self.assertEqual(labels_dictionary,
                         {0: "test_label_1", 1: "test_label_3"})
        self.assertEqual(dataset.get_nb_examples(), 3)
        self.assertEqual(dataset.get_nb_class(), 2)

    def tearDown(self):
        """Delete everything the test wrote under ``tmp_path``."""
        # Use the same cleanup helper as setUp and the other test classes
        # instead of removing a fixed list of files: a test that fails
        # before every expected file exists must not raise a second error
        # from the cleanup and hide the real failure.
        rm_tmp()

class Test_get_plausible_db_hdf5(unittest.TestCase):
    """Tests for ``get_multiview_db.get_plausible_db_hdf5``.

    The fixture values (target path, sizes and a seeded random state) are
    created once for the whole class and shared by every test.
    """

    @classmethod
    def setUpClass(cls):
        """Clear the temporary directory and define the shared parameters."""
        rm_tmp()
        cls.path = tmp_path
        cls.nb_class = 3
        cls.rs = np.random.RandomState(42)
        cls.nb_view = 3
        cls.nb_examples = 5
        cls.nb_features = 4

    @classmethod
    def tearDownClass(cls):
        """Remove whatever the tests left in the temporary directory."""
        rm_tmp()

    def _generate(self, nb_class):
        """Call the function under test for ``nb_class`` classes.

        Returns only the dataset, the first element of the returned triple.
        """
        generated = get_multiview_db.get_plausible_db_hdf5(
            "", self.path, "",
            nb_class=nb_class,
            random_state=self.rs,
            nb_view=self.nb_view,
            nb_examples=self.nb_examples,
            nb_features=self.nb_features)
        return generated[0]

    def test_simple(self):
        """Three classes: all example indices are used, three classes found."""
        plausible = self._generate(self.nb_class)
        self.assertEqual(plausible.init_example_indces(),
                         range(self.nb_examples))
        self.assertEqual(plausible.get_nb_class(), self.nb_class)

    def test_two_class(self):
        """Two classes: all example indices are used, two classes found."""
        plausible = self._generate(2)
        self.assertEqual(plausible.init_example_indces(),
                         range(self.nb_examples))
        self.assertEqual(plausible.get_nb_class(), 2)


# Start unittest's command-line runner when this file is executed directly.
if __name__ == '__main__':
    unittest.main()