From 002851b1f07146b31272dc13bc4297d99049bc2b Mon Sep 17 00:00:00 2001
From: Luc Giffon <luc.giffon@lis-lab.fr>
Date: Sun, 2 Sep 2018 12:18:08 +0200
Subject: [PATCH] tests for transformation are now working

---
 skluc/main/data/mldatasets/Cifar100FineDataset.py    |  2 +-
 skluc/main/data/mldatasets/Cifar10Dataset.py         |  2 +-
 skluc/main/data/mldatasets/ImageDataset.py           |  2 +-
 skluc/main/data/mldatasets/MnistDataset.py           |  2 +-
 skluc/main/data/mldatasets/SVHNDataset.py            |  2 +-
 skluc/main/data/mldatasets/__init__.py               | 12 ++++++------
 .../data/transformation/KerasModelTransformer.py     |  6 ++----
 skluc/main/data/transformation/LeCunTransformer.py   | 10 +++++-----
 skluc/main/data/transformation/RescaleTransformer.py |  2 +-
 skluc/main/data/transformation/ResizeTransformer.py  |  2 +-
 .../data/transformation/VGG19Transformer/__init__.py | 12 ++++++------
 11 files changed, 26 insertions(+), 28 deletions(-)

diff --git a/skluc/main/data/mldatasets/Cifar100FineDataset.py b/skluc/main/data/mldatasets/Cifar100FineDataset.py
index ea0b98b..6bb7ff8 100644
--- a/skluc/main/data/mldatasets/Cifar100FineDataset.py
+++ b/skluc/main/data/mldatasets/Cifar100FineDataset.py
@@ -4,8 +4,8 @@ import tarfile
 
 import numpy as np
 
+from skluc.main.data.mldatasets.ImageDataset import ImageDataset
 from skluc.main.utils import LabeledData
-from skluc.main.data.mldatasets import ImageDataset
 from skluc.main.utils import logger, check_files
 
 
diff --git a/skluc/main/data/mldatasets/Cifar10Dataset.py b/skluc/main/data/mldatasets/Cifar10Dataset.py
index 38dc6f5..ffe2222 100644
--- a/skluc/main/data/mldatasets/Cifar10Dataset.py
+++ b/skluc/main/data/mldatasets/Cifar10Dataset.py
@@ -4,8 +4,8 @@ import tarfile
 
 import numpy as np
 
+from skluc.main.data.mldatasets.ImageDataset import ImageDataset
 from skluc.main.utils import LabeledData
-from skluc.main.data.mldatasets import ImageDataset
 from skluc.main.utils import logger, check_files
 import matplotlib.pyplot as plt
 
diff --git a/skluc/main/data/mldatasets/ImageDataset.py b/skluc/main/data/mldatasets/ImageDataset.py
index 1713e79..4151cd3 100644
--- a/skluc/main/data/mldatasets/ImageDataset.py
+++ b/skluc/main/data/mldatasets/ImageDataset.py
@@ -2,8 +2,8 @@ import os
 
 import numpy as np
 
+from skluc.main.data.mldatasets.Dataset import Dataset
 from skluc.main.utils import LabeledData
-from skluc.main.data.mldatasets import Dataset
 from skluc.main.utils import logger, create_directory, check_files
 
 
diff --git a/skluc/main/data/mldatasets/MnistDataset.py b/skluc/main/data/mldatasets/MnistDataset.py
index d39c223..f2adb8f 100644
--- a/skluc/main/data/mldatasets/MnistDataset.py
+++ b/skluc/main/data/mldatasets/MnistDataset.py
@@ -4,8 +4,8 @@ import struct
 
 import numpy as np
 
+from skluc.main.data.mldatasets.ImageDataset import ImageDataset
 from skluc.main.utils import LabeledData
-from skluc.main.data.mldatasets import ImageDataset
 from skluc.main.utils import logger
 
 
diff --git a/skluc/main/data/mldatasets/SVHNDataset.py b/skluc/main/data/mldatasets/SVHNDataset.py
index 4f4ebfe..45f7739 100644
--- a/skluc/main/data/mldatasets/SVHNDataset.py
+++ b/skluc/main/data/mldatasets/SVHNDataset.py
@@ -3,8 +3,8 @@ import os
 import numpy as np
 import scipy.io as sio
 
+from skluc.main.data.mldatasets.ImageDataset import ImageDataset
 from skluc.main.utils import LabeledData
-from skluc.main.data.mldatasets import ImageDataset
 from skluc.main.utils import logger
 
 
diff --git a/skluc/main/data/mldatasets/__init__.py b/skluc/main/data/mldatasets/__init__.py
index c747820..f65832a 100644
--- a/skluc/main/data/mldatasets/__init__.py
+++ b/skluc/main/data/mldatasets/__init__.py
@@ -10,12 +10,12 @@ The currently implemented datasets are:
 """
 
 
-from skluc.main.data.mldatasets import Cifar100FineDataset
-from skluc.main.data.mldatasets import Cifar10Dataset
-from skluc.main.data.mldatasets import MnistDataset
-from skluc.main.data.mldatasets import MovieReviewV1Dataset
-from skluc.main.data.mldatasets import OmniglotDataset
-from skluc.main.data.mldatasets import SVHNDataset
+from skluc.main.data.mldatasets.Cifar100FineDataset import Cifar100FineDataset
+from skluc.main.data.mldatasets.Cifar10Dataset import Cifar10Dataset
+from skluc.main.data.mldatasets.MnistDataset import MnistDataset
+from skluc.main.data.mldatasets.MovieReviewDataset import MovieReviewV1Dataset
+from skluc.main.data.mldatasets.OmniglotDataset import OmniglotDataset
+from skluc.main.data.mldatasets.SVHNDataset import SVHNDataset
 
 
 __all__ = ["Cifar10Dataset", "Cifar100FineDataset", "MnistDataset", "OmniglotDataset", "MovieReviewV1Dataset", "SVHNDataset"]
diff --git a/skluc/main/data/transformation/KerasModelTransformer.py b/skluc/main/data/transformation/KerasModelTransformer.py
index c689077..1aa35ea 100644
--- a/skluc/main/data/transformation/KerasModelTransformer.py
+++ b/skluc/main/data/transformation/KerasModelTransformer.py
@@ -1,8 +1,7 @@
 import os
 import numpy as np
-from keras import Model
 
-from skluc.main.data import Transformer
+from skluc.main.data.transformation.Transformer import Transformer
 from skluc.main.utils import check_file_md5, logger
 
 
@@ -33,11 +32,10 @@ class KerasModelTransformer(Transformer):
             raise AssertionError("Data shape should be of size 4 (image batch with channel dimension). "
                                  "It is {}: {}. Maybe have you forgotten to reshape it to an image format?"
                                  "".format(len(data.shape), data.shape))
-        model = Model(inputs=self.keras_model.input, outputs=self.keras_model.output)
         logger.debug("Type of data to transform: {}".format(type(data)))
         logger.debug("Length of data to transform: {}".format(len(data)))
         logger.debug("Transforming data using pretrained model")
-        transformed_data = np.array(model.predict(data)).reshape(-1, *model.output_shape[1:])
+        transformed_data = np.array(self.keras_model.predict(data)).reshape(-1, *self.keras_model.output_shape[1:])
         logger.debug("Type of transformed data: {}".format(type(transformed_data)))
         return transformed_data, labels
 
diff --git a/skluc/main/data/transformation/LeCunTransformer.py b/skluc/main/data/transformation/LeCunTransformer.py
index e8903c4..92e0ddf 100644
--- a/skluc/main/data/transformation/LeCunTransformer.py
+++ b/skluc/main/data/transformation/LeCunTransformer.py
@@ -2,7 +2,7 @@ from keras.models import load_model
 
 from keras import Model
 
-from skluc.main.data import KerasModelTransformer
+from skluc.main.data.transformation.KerasModelTransformer import KerasModelTransformer
 from skluc.main.utils import logger, create_directory, download_data, check_file_md5, DownloadableModel
 
 
@@ -22,20 +22,20 @@ class LecunTransformer(KerasModelTransformer):
             raise ValueError("Unknown data name. Can't load weights")
         transformation_name = self.__class__.__name__
 
+        self.keras_model = None
+
         super().__init__(data_name=data_name,
                          transformation_name=transformation_name)
 
-        self.keras_model = None
-
     def load(self):
         create_directory(self.s_download_dir)
         s_model_path = download_data(self.MAP_DATA_MODEL[self.data_name].url, self.s_download_dir)
         check_file_md5(s_model_path, self.MAP_DATA_MODEL[self.data_name].checksum)
         if self.keras_model is None:
-            logger.debug("Loading VGG19 model with cifar10 weights")
+            logger.debug("Loading Lecun model with {} weights".format(self.data_name))
             self.keras_model = load_model(s_model_path)
 
             self.keras_model = Model(inputs=self.keras_model.input,
                                      outputs=self.keras_model.get_layer('conv_pool_2').output)
         else:
-            logger.debug("Skip loading model Lecun model with mnist weights. Already there.")
+            logger.debug("Skip loading model Lecun model with {} weights. Already there.".format(self.data_name))
diff --git a/skluc/main/data/transformation/RescaleTransformer.py b/skluc/main/data/transformation/RescaleTransformer.py
index 6d89da9..2a9b4b9 100644
--- a/skluc/main/data/transformation/RescaleTransformer.py
+++ b/skluc/main/data/transformation/RescaleTransformer.py
@@ -1,7 +1,7 @@
 import tensorflow as tf
 import numpy as np
 
-from skluc.main.data import Transformer
+from skluc.main.data.transformation.Transformer import Transformer
 from skluc.main.utils import logger, Singleton
 
 
diff --git a/skluc/main/data/transformation/ResizeTransformer.py b/skluc/main/data/transformation/ResizeTransformer.py
index 867818e..2f5161f 100644
--- a/skluc/main/data/transformation/ResizeTransformer.py
+++ b/skluc/main/data/transformation/ResizeTransformer.py
@@ -1,7 +1,7 @@
 import tensorflow as tf
 import numpy as np
 
-from skluc.main.data import Transformer
+from skluc.main.data.transformation.Transformer import Transformer
 from skluc.main.utils import logger, Singleton
 
 
diff --git a/skluc/main/data/transformation/VGG19Transformer/__init__.py b/skluc/main/data/transformation/VGG19Transformer/__init__.py
index 5f93e21..9af31f4 100644
--- a/skluc/main/data/transformation/VGG19Transformer/__init__.py
+++ b/skluc/main/data/transformation/VGG19Transformer/__init__.py
@@ -1,8 +1,8 @@
 from keras import Model
 from keras.models import load_model
 
-from skluc.main.data.mldatasets import Cifar10Dataset
-from skluc.main.data import KerasModelTransformer
+from skluc.main.data.mldatasets.Cifar10Dataset import Cifar10Dataset
+from skluc.main.data.transformation.KerasModelTransformer import KerasModelTransformer
 from skluc.main.utils import logger, create_directory, download_data, check_file_md5, Singleton, DownloadableModel
 
 
@@ -42,17 +42,17 @@ class VGG19Transformer(KerasModelTransformer, metaclass=Singleton):
         self.__cut_layer_name = cut_layer_name
         self.__cut_layer_index = cut_layer_index
 
+        self.keras_model = None
+
         super().__init__(data_name=data_name,
                          transformation_name=transformation_name)
 
-        self.keras_model = None
-
     def load(self):
         create_directory(self.s_download_dir)
         s_model_path = download_data(self.MAP_DATA_MODEL[self.data_name].url, self.s_download_dir)
         check_file_md5(s_model_path, self.MAP_DATA_MODEL[self.data_name].checksum)
         if self.keras_model is None:
-            logger.debug("Loading VGG19 model for {} transformation".format(self.transformation_name))
+            logger.debug("Loading VGG19 model for {} transformation with {} weights".format(self.transformation_name, self.data_name))
             self.keras_model = load_model(s_model_path)
 
             if self.__cut_layer_name is not None:
@@ -63,7 +63,7 @@ class VGG19Transformer(KerasModelTransformer, metaclass=Singleton):
                                             outputs=self.keras_model.get_layer(name=self.__cut_layer_index).output)
 
         else:
-            logger.debug("Skip loading model VGG19 for {} transformation. Already there.".format(self.transformation_name))
+            logger.debug("Skip loading model VGG19 for {} transformation with {} weights. Already there.".format(self.transformation_name, self.data_name))
 
 
 if __name__ == '__main__':
-- 
GitLab