Skip to content
Snippets Groups Projects
Commit cdd82b14 authored by Luc Giffon's avatar Luc Giffon
Browse files

test des transformateurs sans modeles

todo: gerer les transformateurs deprecated
parent d912fe1a
No related branches found
No related tags found
No related merge requests found
import os
import numpy as np
from sklearn.cross_validation import train_test_split
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import LabelBinarizer
......
......@@ -46,7 +46,7 @@ class ImageDataset(Dataset):
# one need to apply it to the data
# then to save the transformation
logger.debug("Files {} don't exist or model md5 checksum doesn't match. Need to produce them".format(transform_filepaths))
logger.info("Apply convolution of {} to dataset {}".format(transformer_name, self.s_name))
logger.info("Apply transformation of {} to dataset {}".format(transformer_name, self.s_name))
for kw in self.data_groups_private:
data, labels = getattr(self, kw)
transformed_data, transformed_labels = transformer.transform(data, labels)
......
import os
import tensorflow as tf
import numpy as np
......@@ -7,13 +6,13 @@ from skluc.utils import logger, Singleton
class RescaleTransformer(Transformer, metaclass=Singleton):
def __init__(self, scaling_factor):
self.rescale_factor = scaling_factor
self.__name = os.path.join("resize", "{}".format(str(scaling_factor).replace(".", "-")))
def __init__(self, data_name, scaling_factor):
transformation_name = self.__class__.__name__ + "_" + "{}".format(str(scaling_factor).replace(".", "-"))
super().__init__(data_name=data_name,
transformation_name=transformation_name)
@property
def name(self):
return self.__name
self.rescale_factor = scaling_factor
def transform(self, data, labels):
if len(data.shape) != 4:
......@@ -25,10 +24,12 @@ class RescaleTransformer(Transformer, metaclass=Singleton):
sess = tf.InteractiveSession()
images_mat = data
output_shape = np.multiply(images_mat.shape[1:-1], (self.rescale_factor, self.rescale_factor))
float_output_shape = np.multiply(images_mat.shape[1:-1], (self.rescale_factor, self.rescale_factor))
output_shape = float_output_shape.astype(np.int)
labels = labels
logger.debug("Expected output shape: {}".format((data.shape[0], *output_shape, data.shape[-1])))
new_images = tf.image.resize_images(images_mat, output_shape).eval()
logger.debug("Shape of data after rescaling: {}".format(new_images.shape))
sess.close()
tf.reset_default_graph()
return np.array(new_images), labels
\ No newline at end of file
import os
import tensorflow as tf
import numpy as np
......@@ -7,15 +6,15 @@ from skluc.utils import logger, Singleton
class ResizeTransformer(Transformer, metaclass=Singleton):
def __init__(self, output_shape):
def __init__(self, data_name, output_shape):
if len(output_shape) != 2:
raise AssertionError("Output shape should be 2D and it is {}D: {}".format(len(output_shape), output_shape))
self.output_shape = output_shape
self.__name = os.path.join("resize", "{}x{}".format(output_shape[0], output_shape[1]))
@property
def name(self):
return self.__name
transformation_name = self.__class__.__name__ + "_" + "{}x{}".format(output_shape[0], output_shape[1])
super().__init__(data_name=data_name,
transformation_name=transformation_name)
self.output_shape = output_shape
def transform(self, data, labels):
if len(data.shape) != 4:
......@@ -29,10 +28,10 @@ class ResizeTransformer(Transformer, metaclass=Singleton):
sess = tf.InteractiveSession()
images_mat = data
labels = labels
lst_new_image = []
for image_mat in images_mat:
new_image = tf.image.resize_images(image_mat, self.output_shape).eval()
lst_new_image.append(new_image)
logger.debug("Shape data after resize: {}".format(np.array(lst_new_image).shape))
new_images = tf.image.resize_images(images_mat, self.output_shape).eval()
logger.debug("Shape data after resize: {}".format(new_images.shape))
sess.close()
return np.array(lst_new_image), labels
tf.reset_default_graph()
return np.array(new_images), labels
import unittest
from skluc.data.mldatasets import MnistDataset, Cifar10Dataset, Cifar100FineDataset, SVHNDataset
from skluc.data.transformation.RescaleTransformer import RescaleTransformer
from skluc.utils import logger
class TestResizeTransformer(unittest.TestCase):
def setUp(self):
self.dict_datasets = {
"mnist": MnistDataset,
"cifar10": Cifar10Dataset,
"cifar100": Cifar100FineDataset,
"svhn": SVHNDataset
}
self.lst_scales = [
0.5,
0.7,
1,
2
]
def test_transform(self):
valsize = 10000
for data_name in self.dict_datasets:
logger.info("Testing dataset {}".format(data_name))
for scale in self.lst_scales:
logger.info("Testing size {}".format(str(scale)))
dataset = self.dict_datasets[data_name]
d = dataset(validation_size=valsize)
d.load()
d.flatten()
d.to_image()
trans = RescaleTransformer(data_name=data_name, scaling_factor=scale)
d.apply_transformer(transformer=trans)
del trans
def test_init(self):
for data_name in self.dict_datasets:
for scale in self.lst_scales:
logger.info("Testing size {}".format(str(scale)))
trans = RescaleTransformer(data_name=data_name, scaling_factor=scale)
del trans
if __name__ == '__main__':
unittest.main()
import unittest
from skluc.data.mldatasets import MnistDataset, Cifar10Dataset, Cifar100FineDataset, SVHNDataset
from skluc.data.transformation.ResizeTransformer import ResizeTransformer
from skluc.utils import logger
class TestResizeTransformer(unittest.TestCase):
def setUp(self):
self.dict_datasets = {
"mnist": MnistDataset,
"cifar10": Cifar10Dataset,
"cifar100": Cifar100FineDataset,
"svhn": SVHNDataset
}
self.lst_sizes = [
(28, 32),
(32, 32),
(28, 28),
(32, 28)
]
def test_transform(self):
valsize = 10000
for data_name in self.dict_datasets:
logger.info("Testing dataset {}".format(data_name))
for size in self.lst_sizes:
logger.info("Testing size {}".format(str(size)))
dataset = self.dict_datasets[data_name]
d = dataset(validation_size=valsize)
d.load()
d.flatten()
d.to_image()
trans = ResizeTransformer(data_name=data_name, output_shape=size)
d.apply_transformer(transformer=trans)
del trans
def test_init(self):
for data_name in self.dict_datasets:
for size in self.lst_sizes:
logger.info("Testing size {}".format(str(size)))
trans = ResizeTransformer(data_name=data_name, output_shape=size)
del trans
if __name__ == '__main__':
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment