Skip to content
Snippets Groups Projects
Commit 6344c6f0 authored by Luc Giffon's avatar Luc Giffon
Browse files

Fix normalization problem; add min and max properties to dataset

parent 16c2c1cd
No related branches found
No related tags found
No related merge requests found
...@@ -34,6 +34,14 @@ class Dataset(object): ...@@ -34,6 +34,14 @@ class Dataset(object):
self.permuted_index_validation = None self.permuted_index_validation = None
self.validation_size = validation_size self.validation_size = validation_size
@property
def min(self):
    """Return the smallest value in the training data.

    The value is recomputed on every access by applying ``np.min`` to
    ``self.train.data``.
    """
    return np.min(self.train.data)
@property
def max(self):
    """Return the largest value in the training data.

    The value is recomputed on every access by applying ``np.max`` to
    ``self.train.data``.
    """
    return np.max(self.train.data)
def reduce_data_size(self, new_size): def reduce_data_size(self, new_size):
logger.info("Reducing datasize of dataset {} to .".format(self.s_name, new_size)) logger.info("Reducing datasize of dataset {} to .".format(self.s_name, new_size))
kept_indices = self.get_uniform_class_rand_indices_train(new_size) kept_indices = self.get_uniform_class_rand_indices_train(new_size)
...@@ -217,8 +225,9 @@ class Dataset(object): ...@@ -217,8 +225,9 @@ class Dataset(object):
if len(datlab.labels) == 0: if len(datlab.labels) == 0:
continue continue
data = datlab.data data = datlab.data
_min = data.min() _min = self.min
_max = data.max() _max = self.max
logger.debug(f"Minimum value of train set is {_min}; max is {_max}")
data = (data - _min) / (_max - _min) data = (data - _min) / (_max - _min)
logger.debug("Apply normalization to {} data of {} dataset.".format(kw, self.s_name)) logger.debug("Apply normalization to {} data of {} dataset.".format(kw, self.s_name))
setattr(self, kw, LabeledData(data, datlab.labels)) setattr(self, kw, LabeledData(data, datlab.labels))
......
...@@ -24,6 +24,20 @@ class TestDataset(unittest.TestCase): ...@@ -24,6 +24,20 @@ class TestDataset(unittest.TestCase):
def setUp(self): def setUp(self):
self.dataset_classes = [FooDataset] self.dataset_classes = [FooDataset]
def test_min_max(self):
    """Check that a dataset's ``min``/``max`` match its train split.

    For every class in ``self.dataset_classes`` the dataset is built with
    ``validation_size=1000`` and ``seed=0``, loaded, and its ``min`` and
    ``max`` properties are compared with the extrema of ``train.data`` and
    of ``test.data``.
    """
    for d_class in self.dataset_classes:
        d1 = d_class(validation_size=1000, seed=0)
        d1.load()

        # The properties must agree with the extrema of the train split.
        mini_train = np.min(d1.train.data)
        maxi_train = np.max(d1.train.data)
        self.assertEqual(mini_train, d1.min)
        self.assertEqual(maxi_train, d1.max)

        # The properties must not coincide with the test split's extrema.
        # NOTE(review): this presumes the fixture's train and test extrema
        # differ -- confirm against the dataset used here.
        mini_test = np.min(d1.test.data)
        maxi_test = np.max(d1.test.data)
        # assertNotEqual replaces the deprecated assertNotEquals alias,
        # which no longer exists in Python 3.12+.
        self.assertNotEqual(mini_test, d1.min)
        self.assertNotEqual(maxi_test, d1.max)
def test_seed_train_val(self): def test_seed_train_val(self):
for d_class in self.dataset_classes: for d_class in self.dataset_classes:
d1 = d_class(validation_size=1000, seed=0) d1 = d_class(validation_size=1000, seed=0)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment