Skip to content
Snippets Groups Projects
Commit 296d9c1c authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Added tests for hdf5 classic dataset loader, need to figure how to use only...

Added tests for hdf5 classic dataset loader, need to figure how to use only the class's setUp and not all
parent db1ea059
Branches
Tags
No related merge requests found
......@@ -81,6 +81,7 @@ def getPlausibleDBhdf5(features, pathF, name, NB_CLASS, LABELS_NAME, nbView=3,
labelsDset = datasetFile.create_dataset("Labels", CLASS_LABELS.shape)
labelsDset[...] = CLASS_LABELS
labelsDset.attrs["name"] = "Labels"
labelsDset.attrs["names"] = ["No", "Yes"]
metaDataGrp = datasetFile.create_group("Metadata")
metaDataGrp.attrs["nbView"] = nbView
metaDataGrp.attrs["nbClass"] = 2
......@@ -145,42 +146,128 @@ def getFakeDBhdf5(features, pathF, name, NB_CLASS, LABELS_NAME, randomState):
class DatasetError(Exception):
    """Error raised by the dataset loaders when a dataset cannot be used as asked.

    In this module it is raised when a dataset holds fewer than two distinct
    labels, or when the asked labels cannot be served by the dataset.
    """

    def __init__(self, *args, **kwargs):
        # Straight pass-through so the message stays available in ``self.args``.
        Exception.__init__(self, *args, **kwargs)
def getClasses(labels):
    """Return the set of distinct labels, requiring at least two of them.

    :param labels: iterable of hashable label values (list, 1-D numpy array...)
    :return: ``set`` holding each distinct value found in ``labels``
    :raises DatasetError: if ``labels`` holds fewer than two distinct values
    """
    labelsSet = set(labels)
    if len(labelsSet) < 2:
        raise DatasetError("Dataset must have at least two different labels")
    return labelsSet
# NOTE(review): this five-argument getClassicDBhdf5 looks like the pre-commit
# (removed) side of the diff: it stops in the middle of its body and a
# six-argument definition with the same name appears further down.
# Confirm against the repository before relying on or editing these lines.
def getClassicDBhdf5(views, pathF, nameDB, NB_CLASS, askedLabelsNames):
"""Used to load a hdf5 database"""
datasetFile = h5py.File(pathF + nameDB + ".hdf5", "r")
fullLabels = datasetFile.get("Labels")
temp_dataset = h5py.File("../Data/temp_"+nameDB+".hdf5", "w")
labelsSet = getClasses(fullLabels)
if len(labelsSet) > 2:
labelsNames = list(datasetFile.get("Labels").attrs["names"])
usedLabels = [labelsNames.index(askedLabelName) for askedLabelName in askedLabelsNames]
def allAskedLabelsAreAvailable(askedLabelsNamesSet, availableLabelsNames):
    """Tell whether every asked label name exists among the available ones.

    :param askedLabelsNamesSet: iterable of the label names that were asked for
    :param availableLabelsNames: container of the label names the dataset offers
    :return: ``True`` if each asked name is in ``availableLabelsNames``
             (also ``True`` when nothing was asked), ``False`` otherwise
    """
    return all(askedLabelName in availableLabelsNames
               for askedLabelName in askedLabelsNamesSet)
def fillLabelNames(NB_CLASS, askedLabelsNames, randomState, availableLabelsNames):
    """Adjust the asked label names so that exactly ``NB_CLASS`` of them remain.

    * Too few names asked: the missing ones are drawn without replacement among
      the available names that were not asked, and appended after the asked ones.
    * Too many names asked: ``NB_CLASS`` of them are drawn without replacement.
    * Exactly ``NB_CLASS`` names asked: they are returned untouched.

    :param NB_CLASS: number of label names wanted
    :param askedLabelsNames: sequence of label names asked by the user
    :param randomState: object exposing ``choice(a, size, replace=...)``
                        (the tests pass a ``numpy.random.RandomState``)
    :param availableLabelsNames: sequence of label names offered by the dataset
    :return: ``(askedLabelsNames, askedLabelsNamesSet)``, the adjusted names and
             the same names as a ``set``
    """
    nbAsked = len(askedLabelsNames)
    if nbAsked < NB_CLASS:
        # Set lookup instead of scanning the asked list for every available name.
        alreadyAsked = set(askedLabelsNames)
        labelsNamesToChoose = [availableLabelName for availableLabelName in availableLabelsNames
                               if availableLabelName not in alreadyAsked]
        addedLabelsNames = randomState.choice(labelsNamesToChoose, NB_CLASS - nbAsked, replace=False)
        askedLabelsNames = list(askedLabelsNames) + list(addedLabelsNames)
    elif nbAsked > NB_CLASS:
        askedLabelsNames = list(randomState.choice(askedLabelsNames, NB_CLASS, replace=False))
    # Built once here rather than once per branch.
    return askedLabelsNames, set(askedLabelsNames)
def getAllLabels(fullLabels, availableLabelsNames):
    """Identity selection: keep every example and every label name.

    :param fullLabels: sized container holding the label of each example
    :param availableLabelsNames: label names offered by the dataset
    :return: ``(fullLabels, availableLabelsNames, usedIndices)`` where
             ``usedIndices`` is ``np.arange(len(fullLabels))``
    """
    everyExampleIndex = np.arange(len(fullLabels))
    return fullLabels, availableLabelsNames, everyExampleIndex
def selectAskedLabels(askedLabelsNamesSet, availableLabelsNames, askedLabelsNames, fullLabels):
    """Keep only the examples whose label was asked for, and renumber the labels.

    :param askedLabelsNamesSet: set of asked label names, used for the availability check
    :param availableLabelsNames: list of the dataset label names; the position of a
                                 name in this list is the label value used in ``fullLabels``
    :param askedLabelsNames: ordered asked label names; the position of a name in
                             this list becomes its new label value
    :param fullLabels: iterable holding the label value of each example
    :return: ``(newLabels, newLabelsNames, usedIndices)``: renumbered labels of the
             kept examples, names of the kept labels, indices of the kept examples
    :raises DatasetError: if an asked label name is not available in the dataset
    """
    # Fail first: the previous error branch evaluated np.arange on a shape tuple
    # before raising, which could mask the DatasetError with a numpy error.
    if not allAskedLabelsAreAvailable(askedLabelsNamesSet, availableLabelsNames):
        raise DatasetError("Asked labels are not all available in the dataset")
    usedLabels = [availableLabelsNames.index(askedLabelName) for askedLabelName in askedLabelsNames]
    # old label value -> new label value; setdefault keeps the first position,
    # exactly like list.index would.
    newLabelOf = {}
    for newLabel, usedLabel in enumerate(usedLabels):
        newLabelOf.setdefault(usedLabel, newLabel)
    usedIndices = np.array([exampleIndex for exampleIndex, label in enumerate(fullLabels)
                            if label in newLabelOf])
    newLabels = np.array([newLabelOf[label] for label in fullLabels if label in newLabelOf])
    newLabelsNames = [availableLabelsNames[usedLabel] for usedLabel in usedLabels]
    return newLabels, newLabelsNames, usedIndices
def filterLabels(labelsSet, askedLabelsNamesSet, fullLabels, availableLabelsNames, askedLabelsNames):
    """Pick the labels and examples to keep, depending on what was asked.

    Everything is kept when the dataset has at most two labels, or when the asked
    names are exactly the available names. Otherwise only the asked labels are kept.

    :param labelsSet: set of the distinct label values of the dataset
    :param askedLabelsNamesSet: set of the asked label names
    :param fullLabels: label value of each example
    :param availableLabelsNames: list of the label names offered by the dataset
    :param askedLabelsNames: ordered asked label names
    :return: ``(newLabels, newLabelsNames, usedIndices)``
    :raises DatasetError: if more label names are asked than the dataset has labels
    """
    keepEverything = len(labelsSet) <= 2 or askedLabelsNames == availableLabelsNames
    if keepEverything:
        newLabels, newLabelsNames, usedIndices = getAllLabels(fullLabels, availableLabelsNames)
        return newLabels, newLabelsNames, usedIndices
    if len(askedLabelsNamesSet) > len(labelsSet):
        raise DatasetError("Asked more labels than available in the dataset. Available labels are : " +
                           ", ".join(availableLabelsNames))
    newLabels, newLabelsNames, usedIndices = selectAskedLabels(askedLabelsNamesSet, availableLabelsNames,
                                                               askedLabelsNames, fullLabels)
    return newLabels, newLabelsNames, usedIndices
def filterViews(datasetFile, temp_dataset, views, usedIndices):
    """Copy the asked views into ``temp_dataset``, renumbered from ``View0``.

    :param datasetFile: source file; reads ``Metadata.attrs["nbView"]`` and the
                        ``View<i>`` datasets with their ``name`` attribute
    :param temp_dataset: destination file; must already hold a ``Metadata`` group
    :param views: names of the views to keep
    :param usedIndices: indices of the examples (rows) to keep in each view
    """
    newViewIndex = 0
    for viewIndex in range(datasetFile.get("Metadata").attrs["nbView"]):
        sourceViewName = "View" + str(viewIndex)
        if datasetFile.get(sourceViewName).attrs["name"] in views:
            # Only the filtered, re-indexed copy is written. Also copying the view
            # under its original "View<i>" name would collide with the re-indexed
            # "View<k>" created for a later kept view.
            copyhdf5Dataset(datasetFile, temp_dataset, sourceViewName, "View" + str(newViewIndex), usedIndices)
            newViewIndex += 1
    # Record the number of views actually written, so the metadata still matches
    # the file content if an asked view name is absent from the source.
    temp_dataset.get("Metadata").attrs["nbView"] = newViewIndex
def getClassicDBhdf5(views, pathF, nameDB, NB_CLASS, askedLabelsNames, randomState):
    """Used to load a hdf5 database, keeping only the asked views and labels.

    Opens ``pathF + nameDB + ".hdf5"`` read-only and writes the selection to
    ``pathF + nameDB + "_temp_view_label_select.hdf5"``.

    :param views: names of the views to keep
    :param pathF: directory prefix of the dataset file (concatenated as is)
    :param nameDB: dataset file name without the ``.hdf5`` extension
    :param NB_CLASS: number of labels wanted
    :param askedLabelsNames: label names asked by the user
    :param randomState: random generator used to complete or reduce the asked names
    :return: ``(datasetFile, labelsDictionary)`` where ``labelsDictionary`` maps each
             new label index to its name
    :raises DatasetError: propagated from the label checks and filters
    """
    datasetFile = h5py.File(pathF + nameDB + ".hdf5", "r")
    labelsDataset = datasetFile.get("Labels")
    # Read the labels into memory; ``[...]`` replaces the ``.value`` accessor.
    fullLabels = labelsDataset[...]
    temp_dataset = h5py.File(pathF + nameDB + "_temp_view_label_select.hdf5", "w")
    datasetFile.copy("Metadata", temp_dataset)
    labelsSet = getClasses(fullLabels)
    availableLabelsNames = list(labelsDataset.attrs["names"])
    askedLabelsNames, askedLabelsNamesSet = fillLabelNames(NB_CLASS, askedLabelsNames,
                                                           randomState, availableLabelsNames)
    newLabels, newLabelsNames, usedIndices = filterLabels(labelsSet, askedLabelsNamesSet, fullLabels,
                                                          availableLabelsNames, askedLabelsNames)
    temp_dataset.get("Metadata").attrs["datasetLength"] = len(usedIndices)
    temp_dataset.create_dataset("Labels", data=newLabels)
    temp_dataset.get("Labels").attrs["names"] = newLabelsNames
    filterViews(datasetFile, temp_dataset, views, usedIndices)
    # The names are read back from the filtered file; ``fullLabels`` is a plain
    # array here and carries no attributes.
    labelsDictionary = dict(enumerate(temp_dataset.get("Labels").attrs["names"]))
    # NOTE(review): the unfiltered source file is returned while the dictionary
    # describes the filtered one, and ``temp_dataset`` is left open. Returning
    # ``temp_dataset`` may be what is intended; confirm with the callers.
    return datasetFile, labelsDictionary
def copyhdf5Dataset(sourceDataFile, destinationDataFile, sourceDatasetName, destinationDatasetName, usedIndices):
    """Used to copy a view in a new dataset file using only the examples of usedIndices, and copying the args

    :param sourceDataFile: file (or group) holding the dataset to copy
    :param destinationDataFile: file (or group) receiving the new dataset
    :param sourceDatasetName: name of the dataset to read in ``sourceDataFile``
    :param destinationDatasetName: name of the dataset to create in ``destinationDataFile``
    :param usedIndices: row indices to keep; used as ``data[usedIndices, :]``,
                        so the source data is indexed along two axes
    """
    sourceDset = sourceDataFile.get(sourceDatasetName)
    # Load the whole array first, then index in memory: usedIndices is applied
    # by numpy, in whatever order it is given. ``[...]`` replaces ``.value``.
    selectedData = sourceDset[...][usedIndices, :]
    newDset = destinationDataFile.create_dataset(destinationDatasetName, data=selectedData)
    if "sparse" in sourceDset.attrs.keys() and sourceDset.attrs["sparse"]:
        # TODO : Support sparse
        pass
    else:
        for key, value in sourceDset.attrs.items():
            newDset.attrs[key] = value
# def getLabelSupports(CLASS_LABELS):
# """Used to get the number of example for each label"""
# labels = set(CLASS_LABELS)
......
......@@ -39,7 +39,7 @@ def parseTheArgs(arguments):
groupClass = parser.add_argument_group('Classification arguments')
groupClass.add_argument('--CL_multiclassMethod', metavar='STRING', action='store',
help='Determine which multiclass method to use if the dataset is multiclass', type=float,
help='Determine which multiclass method to use if the dataset is multiclass',
default="oneVersusOne")
groupClass.add_argument('--CL_split', metavar='FLOAT', action='store',
help='Determine the split ratio between learning and validation sets', type=float,
......
import unittest
import h5py
import numpy as np
import os
from ...MonoMultiViewClassifiers.utils import GetMultiviewDb
class Test_copyhdf5Dataset(unittest.TestCase):
    """Checks that copyhdf5Dataset copies the selected rows and the attributes."""

    def setUp(self):
        self.random_state = np.random.RandomState(42)
        if not os.path.exists("Code/Tests/temp_tests"):
            os.mkdir("Code/Tests/temp_tests")
        self.dataset_file = h5py.File("Code/Tests/temp_tests/test_copy.hdf5", "w")
        self.dataset = self.dataset_file.create_dataset("test", data=self.random_state.randint(0, 100, (10, 20)))
        self.dataset.attrs["test_arg"] = "Am I copied"

    def test_simple_copy(self):
        GetMultiviewDb.copyhdf5Dataset(self.dataset_file, self.dataset_file, "test", "test_copy", np.arange(10))
        np.testing.assert_array_equal(self.dataset_file.get("test")[...],
                                      self.dataset_file.get("test_copy")[...])
        self.assertEqual("Am I copied", self.dataset_file.get("test_copy").attrs["test_arg"])

    def test_copy_only_some_indices(self):
        usedIndices = self.random_state.choice(10, 6, replace=False)
        GetMultiviewDb.copyhdf5Dataset(self.dataset_file, self.dataset_file, "test", "test_copy", usedIndices)
        np.testing.assert_array_equal(self.dataset_file.get("test")[...][usedIndices, :],
                                      self.dataset_file.get("test_copy")[...])
        self.assertEqual("Am I copied", self.dataset_file.get("test_copy").attrs["test_arg"])

    def tearDown(self):
        # Release the handle before deleting the file it points to.
        self.dataset_file.close()
        os.remove("Code/Tests/temp_tests/test_copy.hdf5")
        os.rmdir("Code/Tests/temp_tests")
class Test_filterViews(unittest.TestCase):
    """Checks that filterViews keeps the asked views, renumbered from View0."""

    def setUp(self):
        self.random_state = np.random.RandomState(42)
        self.views = ["test_view_1", "test_view_2"]
        if not os.path.exists("Code/Tests/temp_tests"):
            os.mkdir("Code/Tests/temp_tests")
        self.dataset_file = h5py.File("Code/Tests/temp_tests/test_copy.hdf5", "w")
        self.metadata_group = self.dataset_file.create_group("Metadata")
        self.metadata_group.attrs["nbView"] = 4
        for i in range(4):
            self.dataset = self.dataset_file.create_dataset("View" + str(i),
                                                            data=self.random_state.randint(0, 100, (10, 20)))
            self.dataset.attrs["name"] = "test_view_" + str(i)
        self.temp_dataset_file = h5py.File("Code/Tests/temp_tests/test_copy_temp.hdf5", "w")
        self.dataset_file.copy("Metadata", self.temp_dataset_file)

    def test_simple_filter(self):
        GetMultiviewDb.filterViews(self.dataset_file, self.temp_dataset_file, self.views, np.arange(10))
        self.assertEqual(self.dataset_file.get("View1").attrs["name"],
                         self.temp_dataset_file.get("View0").attrs["name"])
        np.testing.assert_array_equal(self.dataset_file.get("View2")[...],
                                      self.temp_dataset_file.get("View1")[...])
        self.assertEqual(self.temp_dataset_file.get("Metadata").attrs["nbView"], 2)

    def test_filter_view_and_examples(self):
        usedIndices = self.random_state.choice(10, 6, replace=False)
        GetMultiviewDb.filterViews(self.dataset_file, self.temp_dataset_file, self.views, usedIndices)
        np.testing.assert_array_equal(self.dataset_file.get("View1")[...][usedIndices, :],
                                      self.temp_dataset_file.get("View0")[...])

    def tearDown(self):
        # Release both handles before deleting the files they point to.
        self.temp_dataset_file.close()
        self.dataset_file.close()
        os.remove("Code/Tests/temp_tests/test_copy.hdf5")
        os.remove("Code/Tests/temp_tests/test_copy_temp.hdf5")
        os.rmdir("Code/Tests/temp_tests")
class Test_filterLabels(unittest.TestCase):
    """Checks filterLabels on multiclass, biclass, all-labels and error cases."""

    def setUp(self):
        self.random_state = np.random.RandomState(42)
        self.labelsSet = set(range(4))
        self.askedLabelsNamesSet = {"test_label_1", "test_label_3"}
        self.fullLabels = self.random_state.randint(0, 4, 10)
        self.availableLabelsNames = ["test_label_0", "test_label_1", "test_label_2", "test_label_3"]
        self.askedLabelsNames = ["test_label_1", "test_label_3"]

    def test_simple(self):
        newLabels, newLabelsNames, usedIndices = GetMultiviewDb.filterLabels(self.labelsSet,
                                                                             self.askedLabelsNamesSet,
                                                                             self.fullLabels,
                                                                             self.availableLabelsNames,
                                                                             self.askedLabelsNames)
        self.assertEqual(["test_label_1", "test_label_3"], newLabelsNames)
        np.testing.assert_array_equal(usedIndices, np.array([1, 5, 9]))
        np.testing.assert_array_equal(newLabels, np.array([1, 1, 0]))

    def test_biclasse(self):
        self.labelsSet = {0, 1}
        self.fullLabels = self.random_state.randint(0, 2, 10)
        self.availableLabelsNames = ["test_label_0", "test_label_1"]
        newLabels, newLabelsNames, usedIndices = GetMultiviewDb.filterLabels(self.labelsSet,
                                                                             self.askedLabelsNamesSet,
                                                                             self.fullLabels,
                                                                             self.availableLabelsNames,
                                                                             self.askedLabelsNames)
        self.assertEqual(self.availableLabelsNames, newLabelsNames)
        np.testing.assert_array_equal(usedIndices, np.arange(10))
        np.testing.assert_array_equal(newLabels, self.fullLabels)

    def test_asked_too_many_labels(self):
        self.askedLabelsNamesSet = {"test_label_0", "test_label_1", "test_label_2", "test_label_3",
                                    "chicken_is_heaven"}
        with self.assertRaises(GetMultiviewDb.DatasetError) as catcher:
            GetMultiviewDb.filterLabels(self.labelsSet,
                                        self.askedLabelsNamesSet,
                                        self.fullLabels,
                                        self.availableLabelsNames,
                                        self.askedLabelsNames)
        exception = catcher.exception
        # Compare against the message text: an exception instance is not a
        # container on Python 3, so ``in exception`` would raise TypeError there.
        self.assertTrue("Asked more labels than available in the dataset. Available labels are : test_label_0, test_label_1, test_label_2, test_label_3" in str(exception))

    def test_asked_all_labels(self):
        self.askedLabelsNamesSet = {"test_label_0", "test_label_1", "test_label_2", "test_label_3"}
        self.askedLabelsNames = ["test_label_0", "test_label_1", "test_label_2", "test_label_3"]
        newLabels, newLabelsNames, usedIndices = GetMultiviewDb.filterLabels(self.labelsSet,
                                                                             self.askedLabelsNamesSet,
                                                                             self.fullLabels,
                                                                             self.availableLabelsNames,
                                                                             self.askedLabelsNames)
        self.assertEqual(self.availableLabelsNames, newLabelsNames)
        np.testing.assert_array_equal(usedIndices, np.arange(10))
        np.testing.assert_array_equal(newLabels, self.fullLabels)
class Test_selectAskedLabels(unittest.TestCase):
    """Checks selectAskedLabels on a sub-selection, on all labels and on a missing label."""

    def setUp(self):
        self.random_state = np.random.RandomState(42)
        self.askedLabelsNamesSet = {"test_label_1", "test_label_3"}
        self.fullLabels = self.random_state.randint(0, 4, 10)
        self.availableLabelsNames = ["test_label_0", "test_label_1", "test_label_2", "test_label_3"]
        self.askedLabelsNames = ["test_label_1", "test_label_3"]

    def test_simple(self):
        newLabels, newLabelsNames, usedIndices = GetMultiviewDb.selectAskedLabels(self.askedLabelsNamesSet,
                                                                                  self.availableLabelsNames,
                                                                                  self.askedLabelsNames,
                                                                                  self.fullLabels)
        self.assertEqual(["test_label_1", "test_label_3"], newLabelsNames)
        np.testing.assert_array_equal(usedIndices, np.array([1, 5, 9]))
        np.testing.assert_array_equal(newLabels, np.array([1, 1, 0]))

    def test_asked_all_labels(self):
        self.askedLabelsNamesSet = {"test_label_0", "test_label_1", "test_label_2", "test_label_3"}
        self.askedLabelsNames = ["test_label_0", "test_label_1", "test_label_2", "test_label_3"]
        newLabels, newLabelsNames, usedIndices = GetMultiviewDb.selectAskedLabels(self.askedLabelsNamesSet,
                                                                                  self.availableLabelsNames,
                                                                                  self.askedLabelsNames,
                                                                                  self.fullLabels)
        self.assertEqual(self.availableLabelsNames, newLabelsNames)
        np.testing.assert_array_equal(usedIndices, np.arange(10))
        np.testing.assert_array_equal(newLabels, self.fullLabels)

    def test_asked_unavailable_labels(self):
        self.askedLabelsNamesSet = {"test_label_1", "test_label_3", "chicken_is_heaven"}
        with self.assertRaises(GetMultiviewDb.DatasetError) as catcher:
            GetMultiviewDb.selectAskedLabels(self.askedLabelsNamesSet,
                                             self.availableLabelsNames,
                                             self.askedLabelsNames,
                                             self.fullLabels)
        exception = catcher.exception
        # Compare against the message text: an exception instance is not a
        # container on Python 3, so ``in exception`` would raise TypeError there.
        self.assertTrue("Asked labels are not all available in the dataset" in str(exception))
class Test_getAllLabels(unittest.TestCase):
    """Checks that getAllLabels hands back every label, name and example index."""

    def setUp(self):
        self.random_state = np.random.RandomState(42)
        self.fullLabels = self.random_state.randint(0, 4, 10)
        self.availableLabelsNames = ["test_label_0", "test_label_1", "test_label_2", "test_label_3"]

    def test_simple(self):
        result = GetMultiviewDb.getAllLabels(self.fullLabels, self.availableLabelsNames)
        newLabels, newLabelsNames, usedIndices = result
        self.assertEqual(self.availableLabelsNames, newLabelsNames)
        np.testing.assert_array_equal(usedIndices, np.arange(10))
        np.testing.assert_array_equal(newLabels, self.fullLabels)
class Test_fillLabelNames(unittest.TestCase):
    """Checks fillLabelNames when enough, too few or too many names are asked."""

    def setUp(self):
        self.NB_CLASS = 2
        self.askedLabelsNames = ["test_label_1", "test_label_3"]
        self.randomState = np.random.RandomState(42)
        self.availableLabelsNames = ["test_label_" + str(labelIndex) for labelIndex in range(40)]

    def _fill(self):
        """Call fillLabelNames with the current fixture values."""
        return GetMultiviewDb.fillLabelNames(self.NB_CLASS,
                                             self.askedLabelsNames,
                                             self.randomState,
                                             self.availableLabelsNames)

    def test_simple(self):
        askedLabelsNames, askedLabelsNamesSet = self._fill()
        self.assertEqual(askedLabelsNames, self.askedLabelsNames)
        self.assertEqual(askedLabelsNamesSet, set(self.askedLabelsNames))

    def test_missing_labels_names(self):
        self.NB_CLASS = 39
        askedLabelsNames, askedLabelsNamesSet = self._fill()
        # Order drawn by RandomState(42) after the two asked names.
        expectedNames = [
            'test_label_1', 'test_label_3', 'test_label_35', 'test_label_38', 'test_label_6',
            'test_label_15', 'test_label_32', 'test_label_28', 'test_label_8', 'test_label_29',
            'test_label_26', 'test_label_17', 'test_label_19', 'test_label_10', 'test_label_18',
            'test_label_14', 'test_label_21', 'test_label_11', 'test_label_34', 'test_label_0',
            'test_label_27', 'test_label_7', 'test_label_13', 'test_label_2', 'test_label_39',
            'test_label_23', 'test_label_4', 'test_label_31', 'test_label_37', 'test_label_5',
            'test_label_36', 'test_label_25', 'test_label_33', 'test_label_12', 'test_label_24',
            'test_label_20', 'test_label_22', 'test_label_9', 'test_label_16',
        ]
        # Every available name except test_label_30.
        expectedSet = {"test_label_" + str(labelIndex) for labelIndex in range(40) if labelIndex != 30}
        self.assertEqual(askedLabelsNames, expectedNames)
        self.assertEqual(askedLabelsNamesSet, expectedSet)

    def test_too_many_label_names(self):
        self.NB_CLASS = 2
        self.askedLabelsNames = ["test_label_1", "test_label_3", "test_label_4", "test_label_6"]
        askedLabelsNames, askedLabelsNamesSet = self._fill()
        self.assertEqual(askedLabelsNames, ["test_label_3", "test_label_6"])
        self.assertEqual(askedLabelsNamesSet, {"test_label_3", "test_label_6"})
class Test_allAskedLabelsAreAvailable(unittest.TestCase):
    """Checks the availability test with and without an unknown label name."""

    def setUp(self):
        self.askedLabelsNamesSet = {"test_label_1", "test_label_3"}
        self.availableLabelsNames = ["test_label_0", "test_label_1", "test_label_2", "test_label_3"]

    def test_asked_available_labels(self):
        areAvailable = GetMultiviewDb.allAskedLabelsAreAvailable(self.askedLabelsNamesSet,
                                                                 self.availableLabelsNames)
        self.assertTrue(areAvailable)

    def test_asked_unavailable_label(self):
        self.askedLabelsNamesSet = {"test_label_1", "test_label_3", "chicken_is_heaven"}
        areAvailable = GetMultiviewDb.allAskedLabelsAreAvailable(self.askedLabelsNamesSet,
                                                                 self.availableLabelsNames)
        self.assertFalse(areAvailable)
class Test_getClasses(unittest.TestCase):
    """Checks getClasses on multiclass, biclass and single-class label vectors."""

    def setUp(self):
        self.random_state = np.random.RandomState(42)

    def test_multiclass(self):
        labelsSet = GetMultiviewDb.getClasses(self.random_state.randint(0, 5, 30))
        self.assertEqual(labelsSet, {0, 1, 2, 3, 4})

    def test_biclass(self):
        labelsSet = GetMultiviewDb.getClasses(self.random_state.randint(0, 2, 30))
        self.assertEqual(labelsSet, {0, 1})

    def test_one_class(self):
        with self.assertRaises(GetMultiviewDb.DatasetError) as catcher:
            GetMultiviewDb.getClasses(np.zeros(30, dtype=int))
        exception = catcher.exception
        # Compare against the message text: an exception instance is not a
        # container on Python 3, so ``in exception`` would raise TypeError there.
        self.assertTrue("Dataset must have at least two different labels" in str(exception))
class Test_getClassicDBhdf5(unittest.TestCase):
    """End-to-end check of getClassicDBhdf5 on a small generated dataset file."""

    def setUp(self):
        if not os.path.exists("Code/Tests/temp_tests"):
            os.mkdir("Code/Tests/temp_tests")
        self.dataset_file = h5py.File("Code/Tests/temp_tests/test_dataset.hdf5", "w")
        self.pathF = "Code/Tests/temp_tests/"
        self.nameDB = "test_dataset"
        self.NB_CLASS = 2
        self.askedLabelsNames = ["test_label_1", "test_label_3"]
        self.random_state = np.random.RandomState(42)
        self.views = ["test_view_1", "test_view_2"]
        self.metadata_group = self.dataset_file.create_group("Metadata")
        self.metadata_group.attrs["nbView"] = 4
        self.labels_dataset = self.dataset_file.create_dataset("Labels", data=self.random_state.randint(0, 4, 10))
        self.labels_dataset.attrs["names"] = ["test_label_0", "test_label_1", "test_label_2", "test_label_3"]
        for i in range(4):
            self.dataset = self.dataset_file.create_dataset("View" + str(i),
                                                            data=self.random_state.randint(0, 100, (10, 20)))
            self.dataset.attrs["name"] = "test_view_" + str(i)
        # Release the write handle so the loader reopens a fully written file
        # from the same path.
        self.dataset_file.close()
        self.loaded_file = None

    def test_simple(self):
        self.loaded_file, labels_dictionary = GetMultiviewDb.getClassicDBhdf5(self.views, self.pathF, self.nameDB,
                                                                             self.NB_CLASS, self.askedLabelsNames,
                                                                             self.random_state)
        # One entry per kept label, keyed by the new label index.
        self.assertEqual(sorted(labels_dictionary.keys()), list(range(self.NB_CLASS)))

    def tearDown(self):
        # Close what the loader returned before deleting the files.
        if self.loaded_file is not None:
            self.loaded_file.close()
        os.remove("Code/Tests/temp_tests/test_dataset_temp_view_label_select.hdf5")
        os.remove("Code/Tests/temp_tests/test_dataset.hdf5")
        os.rmdir("Code/Tests/temp_tests")
\ No newline at end of file
......@@ -21,6 +21,12 @@ class Test_parseTheArgs(unittest.TestCase):
class Test_initRandomState(unittest.TestCase):
def setUp(self):
os.mkdir("Code/Tests/temp_tests/")
def tearDown(self):
os.rmdir("Code/Tests/temp_tests/")
def test_random_state_42(self):
randomState_42 = np.random.RandomState(42)
randomState = execution.initRandomState("42", "Code/Tests/temp_tests/")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment