From 002851b1f07146b31272dc13bc4297d99049bc2b Mon Sep 17 00:00:00 2001 From: Luc Giffon <luc.giffon@lis-lab.fr> Date: Sun, 2 Sep 2018 12:18:08 +0200 Subject: [PATCH] tests for transformation are now working --- skluc/main/data/mldatasets/Cifar100FineDataset.py | 2 +- skluc/main/data/mldatasets/Cifar10Dataset.py | 2 +- skluc/main/data/mldatasets/ImageDataset.py | 2 +- skluc/main/data/mldatasets/MnistDataset.py | 2 +- skluc/main/data/mldatasets/SVHNDataset.py | 2 +- skluc/main/data/mldatasets/__init__.py | 12 ++++++------ .../data/transformation/KerasModelTransformer.py | 6 ++---- skluc/main/data/transformation/LeCunTransformer.py | 10 +++++----- skluc/main/data/transformation/RescaleTransformer.py | 2 +- skluc/main/data/transformation/ResizeTransformer.py | 2 +- .../data/transformation/VGG19Transformer/__init__.py | 12 ++++++------ 11 files changed, 26 insertions(+), 28 deletions(-) diff --git a/skluc/main/data/mldatasets/Cifar100FineDataset.py b/skluc/main/data/mldatasets/Cifar100FineDataset.py index ea0b98b..6bb7ff8 100644 --- a/skluc/main/data/mldatasets/Cifar100FineDataset.py +++ b/skluc/main/data/mldatasets/Cifar100FineDataset.py @@ -4,8 +4,8 @@ import tarfile import numpy as np +from skluc.main.data.mldatasets.ImageDataset import ImageDataset from skluc.main.utils import LabeledData -from skluc.main.data.mldatasets import ImageDataset from skluc.main.utils import logger, check_files diff --git a/skluc/main/data/mldatasets/Cifar10Dataset.py b/skluc/main/data/mldatasets/Cifar10Dataset.py index 38dc6f5..ffe2222 100644 --- a/skluc/main/data/mldatasets/Cifar10Dataset.py +++ b/skluc/main/data/mldatasets/Cifar10Dataset.py @@ -4,8 +4,8 @@ import tarfile import numpy as np +from skluc.main.data.mldatasets.ImageDataset import ImageDataset from skluc.main.utils import LabeledData -from skluc.main.data.mldatasets import ImageDataset from skluc.main.utils import logger, check_files import matplotlib.pyplot as plt diff --git a/skluc/main/data/mldatasets/ImageDataset.py b/skluc/main/data/mldatasets/ImageDataset.py index 1713e79..4151cd3 100644 --- a/skluc/main/data/mldatasets/ImageDataset.py +++ b/skluc/main/data/mldatasets/ImageDataset.py @@ -2,8 +2,8 @@ import os import numpy as np +from skluc.main.data.mldatasets.Dataset import Dataset from skluc.main.utils import LabeledData -from skluc.main.data.mldatasets import Dataset from skluc.main.utils import logger, create_directory, check_files diff --git a/skluc/main/data/mldatasets/MnistDataset.py b/skluc/main/data/mldatasets/MnistDataset.py index d39c223..f2adb8f 100644 --- a/skluc/main/data/mldatasets/MnistDataset.py +++ b/skluc/main/data/mldatasets/MnistDataset.py @@ -4,8 +4,8 @@ import struct import numpy as np +from skluc.main.data.mldatasets.ImageDataset import ImageDataset from skluc.main.utils import LabeledData -from skluc.main.data.mldatasets import ImageDataset from skluc.main.utils import logger diff --git a/skluc/main/data/mldatasets/SVHNDataset.py b/skluc/main/data/mldatasets/SVHNDataset.py index 4f4ebfe..45f7739 100644 --- a/skluc/main/data/mldatasets/SVHNDataset.py +++ b/skluc/main/data/mldatasets/SVHNDataset.py @@ -3,8 +3,8 @@ import os import numpy as np import scipy.io as sio +from skluc.main.data.mldatasets.ImageDataset import ImageDataset from skluc.main.utils import LabeledData -from skluc.main.data.mldatasets import ImageDataset from skluc.main.utils import logger diff --git a/skluc/main/data/mldatasets/__init__.py b/skluc/main/data/mldatasets/__init__.py index c747820..f65832a 100644 --- a/skluc/main/data/mldatasets/__init__.py +++ b/skluc/main/data/mldatasets/__init__.py @@ -10,12 +10,12 @@ The currently implemented datasets are: """ -from skluc.main.data.mldatasets import Cifar100FineDataset -from skluc.main.data.mldatasets import Cifar10Dataset -from skluc.main.data.mldatasets import MnistDataset -from skluc.main.data.mldatasets import MovieReviewV1Dataset -from skluc.main.data.mldatasets import OmniglotDataset -from skluc.main.data.mldatasets import SVHNDataset +from skluc.main.data.mldatasets.Cifar100FineDataset import Cifar100FineDataset +from skluc.main.data.mldatasets.Cifar10Dataset import Cifar10Dataset +from skluc.main.data.mldatasets.MnistDataset import MnistDataset +from skluc.main.data.mldatasets.MovieReviewDataset import MovieReviewV1Dataset +from skluc.main.data.mldatasets.OmniglotDataset import OmniglotDataset +from skluc.main.data.mldatasets.SVHNDataset import SVHNDataset __all__ = ["Cifar10Dataset", "Cifar100FineDataset", "MnistDataset", "OmniglotDataset", "MovieReviewV1Dataset", "SVHNDataset"] diff --git a/skluc/main/data/transformation/KerasModelTransformer.py b/skluc/main/data/transformation/KerasModelTransformer.py index c689077..1aa35ea 100644 --- a/skluc/main/data/transformation/KerasModelTransformer.py +++ b/skluc/main/data/transformation/KerasModelTransformer.py @@ -1,8 +1,7 @@ import os import numpy as np -from keras import Model -from skluc.main.data import Transformer +from skluc.main.data.transformation.Transformer import Transformer from skluc.main.utils import check_file_md5, logger @@ -33,11 +32,10 @@ class KerasModelTransformer(Transformer): raise AssertionError("Data shape should be of size 4 (image batch with channel dimension). " "It is {}: {}. Maybe have you forgotten to reshape it to an image format?" "".format(len(data.shape), data.shape)) - model = Model(inputs=self.keras_model.input, outputs=self.keras_model.output) logger.debug("Type of data to transform: {}".format(type(data))) logger.debug("Length of data to transform: {}".format(len(data))) logger.debug("Transforming data using pretrained model") - transformed_data = np.array(model.predict(data)).reshape(-1, *model.output_shape[1:]) + transformed_data = np.array(self.keras_model.predict(data)).reshape(-1, *self.keras_model.output_shape[1:]) logger.debug("Type of transformed data: {}".format(type(transformed_data))) return transformed_data, labels diff --git a/skluc/main/data/transformation/LeCunTransformer.py b/skluc/main/data/transformation/LeCunTransformer.py index e8903c4..92e0ddf 100644 --- a/skluc/main/data/transformation/LeCunTransformer.py +++ b/skluc/main/data/transformation/LeCunTransformer.py @@ -2,7 +2,7 @@ from keras.models import load_model from keras import Model -from skluc.main.data import KerasModelTransformer +from skluc.main.data.transformation.KerasModelTransformer import KerasModelTransformer from skluc.main.utils import logger, create_directory, download_data, check_file_md5, DownloadableModel @@ -22,20 +22,20 @@ class LecunTransformer(KerasModelTransformer): raise ValueError("Unknown data name. Can't load weights") transformation_name = self.__class__.__name__ + self.keras_model = None + super().__init__(data_name=data_name, transformation_name=transformation_name) - self.keras_model = None - def load(self): create_directory(self.s_download_dir) s_model_path = download_data(self.MAP_DATA_MODEL[self.data_name].url, self.s_download_dir) check_file_md5(s_model_path, self.MAP_DATA_MODEL[self.data_name].checksum) if self.keras_model is None: - logger.debug("Loading VGG19 model with cifar10 weights") + logger.debug("Loading Lecun model with {} weights".format(self.data_name)) self.keras_model = load_model(s_model_path) self.keras_model = Model(inputs=self.keras_model.input, outputs=self.keras_model.get_layer('conv_pool_2').output) else: - logger.debug("Skip loading model Lecun model with mnist weights. Already there.") + logger.debug("Skip loading model Lecun model with {} weights. Already there.".format(self.data_name)) diff --git a/skluc/main/data/transformation/RescaleTransformer.py b/skluc/main/data/transformation/RescaleTransformer.py index 6d89da9..2a9b4b9 100644 --- a/skluc/main/data/transformation/RescaleTransformer.py +++ b/skluc/main/data/transformation/RescaleTransformer.py @@ -1,7 +1,7 @@ import tensorflow as tf import numpy as np -from skluc.main.data import Transformer +from skluc.main.data.transformation.Transformer import Transformer from skluc.main.utils import logger, Singleton diff --git a/skluc/main/data/transformation/ResizeTransformer.py b/skluc/main/data/transformation/ResizeTransformer.py index 867818e..2f5161f 100644 --- a/skluc/main/data/transformation/ResizeTransformer.py +++ b/skluc/main/data/transformation/ResizeTransformer.py @@ -1,7 +1,7 @@ import tensorflow as tf import numpy as np -from skluc.main.data import Transformer +from skluc.main.data.transformation.Transformer import Transformer from skluc.main.utils import logger, Singleton diff --git a/skluc/main/data/transformation/VGG19Transformer/__init__.py b/skluc/main/data/transformation/VGG19Transformer/__init__.py index 5f93e21..9af31f4 100644 --- a/skluc/main/data/transformation/VGG19Transformer/__init__.py +++ b/skluc/main/data/transformation/VGG19Transformer/__init__.py @@ -1,8 +1,8 @@ from keras import Model from keras.models import load_model -from skluc.main.data.mldatasets import Cifar10Dataset -from skluc.main.data import KerasModelTransformer +from skluc.main.data.mldatasets.Cifar10Dataset import Cifar10Dataset +from skluc.main.data.transformation.KerasModelTransformer import KerasModelTransformer from skluc.main.utils import logger, create_directory, download_data, check_file_md5, Singleton, DownloadableModel @@ -42,17 +42,17 @@ class VGG19Transformer(KerasModelTransformer, metaclass=Singleton): self.__cut_layer_name = cut_layer_name self.__cut_layer_index = cut_layer_index + self.keras_model = None + super().__init__(data_name=data_name, transformation_name=transformation_name) - self.keras_model = None - def load(self): create_directory(self.s_download_dir) s_model_path = download_data(self.MAP_DATA_MODEL[self.data_name].url, self.s_download_dir) check_file_md5(s_model_path, self.MAP_DATA_MODEL[self.data_name].checksum) if self.keras_model is None: - logger.debug("Loading VGG19 model for {} transformation".format(self.transformation_name)) + logger.debug("Loading VGG19 model for {} transformation with {} weights".format(self.transformation_name, self.data_name)) self.keras_model = load_model(s_model_path) if self.__cut_layer_name is not None: @@ -63,7 +63,7 @@ class VGG19Transformer(KerasModelTransformer, metaclass=Singleton): outputs=self.keras_model.get_layer(name=self.__cut_layer_index).output) else: - logger.debug("Skip loading model VGG19 for {} transformation. Already there.".format(self.transformation_name)) + logger.debug("Skip loading model VGG19 for {} transformation with {} weights. Already there.".format(self.transformation_name, self.data_name)) if __name__ == '__main__': -- GitLab