# test_GetMultiviewDB.py
# Author: Franck Dary
import os
import unittest
import h5py
import numpy as np
from multiview_platform.mono_multi_view_classifiers.utils import get_multiview_db
from multiview_platform.tests.utils import rm_tmp, tmp_path
class Test_get_classic_db_hdf5(unittest.TestCase):
    """Tests for ``get_multiview_db.get_classic_db_hdf5``.

    Every test runs against a small HDF5 file written by ``setUp``: three
    views of five examples with seven integer features each, a ``Labels``
    dataset and a ``Metadata`` group.
    """

    def setUp(self):
        """Recreate the tmp directory and write the HDF5 fixture into it."""
        rm_tmp()
        os.mkdir(tmp_path)
        self.rs = np.random.RandomState(42)
        self.nb_view = 3
        self.file_name = "test.hdf5"
        self.nb_examples = 5
        self.nb_class = 3
        # Views are drawn before the labels; keep this order so the seeded
        # random stream yields the same fixture values.
        self.views = [self.rs.randint(0, 10, size=(self.nb_examples, 7))
                      for _ in range(self.nb_view)]
        self.labels = self.rs.randint(0, self.nb_class, self.nb_examples)
        # Open explicitly in write mode: without a mode argument, h5py 3.0+
        # opens read-only, which fails for a file that does not exist yet.
        # The directory was just recreated, so "w" never truncates anything.
        self.dataset_file = h5py.File(
            os.path.join(tmp_path, self.file_name), "w")
        self.view_names = ["ViewN" + str(index)
                           for index in range(len(self.views))]
        self.are_sparse = [False for _ in self.views]
        for view_index, (view_name, view, is_sparse) in enumerate(
                zip(self.view_names, self.views, self.are_sparse)):
            # Datasets are keyed "View<i>"; the name the tests ask for
            # ("ViewN<i>") is stored in the "name" attribute.
            view_dataset = self.dataset_file.create_dataset(
                "View" + str(view_index),
                view.shape,
                data=view)
            view_dataset.attrs["name"] = view_name
            view_dataset.attrs["sparse"] = is_sparse
        labels_dataset = self.dataset_file.create_dataset(
            "Labels",
            shape=self.labels.shape,
            data=self.labels)
        self.labels_names = [str(index) for index in np.unique(self.labels)]
        # Label names are stored as bytes in the "names" attribute.
        labels_dataset.attrs["names"] = [label_name.encode()
                                         for label_name in self.labels_names]
        meta_data_grp = self.dataset_file.create_group("Metadata")
        meta_data_grp.attrs["nbView"] = len(self.views)
        meta_data_grp.attrs["nbClass"] = len(np.unique(self.labels))
        meta_data_grp.attrs["datasetLength"] = len(self.labels)

    def test_simple(self):
        """Asking for ``ViewN2`` only yields one view, 5 examples, 3 labels."""
        dataset, labels_dictionary, dataset_name = \
            get_multiview_db.get_classic_db_hdf5(
                ["ViewN2"], tmp_path, self.file_name.split(".")[0],
                self.nb_class, ["0", "2"],
                self.rs, path_for_new=tmp_path)
        self.assertEqual(dataset.nb_view, 1)
        self.assertEqual(labels_dictionary,
                         {0: "0", 1: "2", 2: "1"})
        self.assertEqual(dataset.get_nb_examples(), 5)
        self.assertEqual(len(np.unique(dataset.get_labels())), 3)

    def test_all_views_asked(self):
        """Passing ``None`` as the view list returns all three views."""
        dataset, labels_dictionary, dataset_name = \
            get_multiview_db.get_classic_db_hdf5(
                None, tmp_path, self.file_name.split(".")[0],
                self.nb_class, ["0", "2"],
                self.rs, path_for_new=tmp_path)
        self.assertEqual(dataset.nb_view, 3)
        self.assertEqual(dataset.get_view_dict(),
                         {'ViewN0': 0, 'ViewN1': 1, 'ViewN2': 2})

    def test_asked_the_whole_dataset(self):
        """With ``full=True`` the returned dataset wraps the fixture file."""
        dataset, labels_dictionary, dataset_name = \
            get_multiview_db.get_classic_db_hdf5(
                ["ViewN2"], tmp_path, self.file_name.split(".")[0],
                self.nb_class, ["0", "2"],
                self.rs, path_for_new=tmp_path, full=True)
        self.assertEqual(dataset.dataset, self.dataset_file)

    def tearDown(self):
        """Close the fixture file, then remove the tmp directory."""
        # Release the handle first so the file is no longer held open when
        # the directory is cleaned up.
        self.dataset_file.close()
        rm_tmp()
class Test_get_classic_db_csv(unittest.TestCase):
    """Tests for ``get_multiview_db.get_classic_db_csv``.

    ``setUp`` writes a CSV fixture under ``tmp_path``: a label-names file
    with four names, a labels file with ten entries, and a ``Views``
    directory holding four 10x20 integer views.
    """

    def setUp(self):
        """Recreate the tmp directory and write the CSV fixture into it."""
        rm_tmp()
        os.mkdir(tmp_path)
        self.pathF = tmp_path
        self.NB_CLASS = 2
        self.nameDB = "test_dataset"
        self.askedLabelsNames = ["test_label_1", "test_label_3"]
        self.random_state = np.random.RandomState(42)
        self.views = ["test_view_1", "test_view_3"]
        # NOTE(review): paths are built by plain concatenation, so this
        # relies on tmp_path ending with a path separator -- confirm.
        np.savetxt(self.pathF + self.nameDB + "-labels-names.csv",
                   np.array(["test_label_0", "test_label_1",
                             "test_label_2", "test_label_3"]), fmt="%s",
                   delimiter=",")
        # Labels are drawn before the views; keep this order so the seeded
        # random stream yields the same fixture values.
        np.savetxt(self.pathF + self.nameDB + "-labels.csv",
                   self.random_state.randint(0, 4, 10), delimiter=",")
        os.mkdir(self.pathF + "Views")
        self.datas = []
        for i in range(4):
            data = self.random_state.randint(0, 100, (10, 20))
            np.savetxt(self.pathF + "Views/test_view_" + str(i) + ".csv",
                       data, delimiter=",")
            self.datas.append(data)

    def test_simple(self):
        """Two asked views and two asked labels give a 2-view, 2-class set."""
        dataset, labels_dictionary, dataset_name = \
            get_multiview_db.get_classic_db_csv(
                self.views, self.pathF, self.nameDB,
                self.NB_CLASS, self.askedLabelsNames,
                self.random_state, delimiter=",", path_for_new=tmp_path)
        self.assertEqual(dataset.nb_view, 2)
        self.assertEqual(dataset.get_view_dict(),
                         {'test_view_1': 0, 'test_view_3': 1})
        self.assertEqual(labels_dictionary,
                         {0: "test_label_1", 1: "test_label_3"})
        self.assertEqual(dataset.get_nb_examples(), 3)
        self.assertEqual(dataset.get_nb_class(), 2)

    def tearDown(self):
        """Clean up the tmp directory after each test."""
        # Use the same cleanup helper as setUp and the other test classes
        # rather than deleting a hard-coded list of files: removing files
        # one by one raises FileNotFoundError when the test failed before
        # producing them, which hides the real failure.
        rm_tmp()
class Test_get_plausible_db_hdf5(unittest.TestCase):
    """Tests for ``get_multiview_db.get_plausible_db_hdf5``.

    The generation parameters and the seeded random state are set once
    for the whole class and shared by every test.
    """

    @classmethod
    def setUpClass(cls):
        """Clean the tmp directory and store the generation parameters."""
        rm_tmp()
        cls.rs = np.random.RandomState(42)
        cls.path = tmp_path
        cls.nb_view = 3
        cls.nb_class = 3
        cls.nb_examples = 5
        cls.nb_features = 4

    @classmethod
    def tearDownClass(cls):
        """Clean the tmp directory once all tests of the class have run."""
        rm_tmp()

    def _generate(self, nb_class):
        """Call ``get_plausible_db_hdf5`` and return only the dataset."""
        generated, _, _ = get_multiview_db.get_plausible_db_hdf5(
            "", self.path, "",
            nb_class=nb_class,
            random_state=self.rs,
            nb_view=self.nb_view,
            nb_examples=self.nb_examples,
            nb_features=self.nb_features)
        return generated

    def test_simple(self):
        """Three classes: indices cover every example, 3 classes reported."""
        generated = self._generate(self.nb_class)
        self.assertEqual(generated.init_example_indces(),
                         range(self.nb_examples))
        self.assertEqual(generated.get_nb_class(), self.nb_class)

    def test_two_class(self):
        """Two classes: indices cover every example, 2 classes reported."""
        generated = self._generate(2)
        self.assertEqual(generated.init_example_indces(),
                         range(self.nb_examples))
        self.assertEqual(generated.get_nb_class(), 2)
# Run the whole module's test suite when executed as a script.
if __name__ == "__main__":
    unittest.main()