From c4fd292c6edc6cfc6f085b4fffe1caa82ef76669 Mon Sep 17 00:00:00 2001
From: Luc Giffon <luc.giffon@lis-lab.fr>
Date: Thu, 29 Nov 2018 18:30:56 +0100
Subject: [PATCH] Fix issues deepfriedconvnet and mnist

No more horizontal flip for mnist and svhn datasets (horizontal flips on number = not good)
Now gamma is computed for deepfriedconvnet (when it was only done for deepstrom with rbf in the former version)
---
 .../deepstrom_classif_end_to_end.py           | 38 ++++++++++++-------
 1 file changed, 25 insertions(+), 13 deletions(-)

diff --git a/main/experiments/scripts/november_2018/end_to_end_with_augment/solve_mnist_and_deepfried_issues/deepstrom_classif_end_to_end.py b/main/experiments/scripts/november_2018/end_to_end_with_augment/solve_mnist_and_deepfried_issues/deepstrom_classif_end_to_end.py
index 903d054..49f5953 100644
--- a/main/experiments/scripts/november_2018/end_to_end_with_augment/solve_mnist_and_deepfried_issues/deepstrom_classif_end_to_end.py
+++ b/main/experiments/scripts/november_2018/end_to_end_with_augment/solve_mnist_and_deepfried_issues/deepstrom_classif_end_to_end.py
@@ -54,21 +54,19 @@ Kernel related:
 
 """
 
-import skluc.main.data.mldatasets as dataset
+import docopt
 import numpy as np
 import tensorflow as tf
+import time as t
 from tensorflow.python.keras.layers import Dense
-from tensorflow.python.keras.regularizers import l2
-from tensorflow.python.keras.initializers import he_normal
 from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
 
+import skluc.main.data.mldatasets as dataset
 from skluc.main.tensorflow_.kernel_approximation.fastfood_layer import FastFoodLayer
 from skluc.main.tensorflow_.kernel_approximation.nystrom_layer import DeepstromLayerEndToEnd
 from skluc.main.tensorflow_.models import build_lenet_model, build_vgg19_model
-from skluc.main.utils import logger, memory_usage, ParameterManager, ResultManager, ResultPrinter
 from skluc.main.tensorflow_.utils import batch_generator
-import time as t
-import docopt
+from skluc.main.utils import logger, memory_usage, ParameterManager, ResultManager, ResultPrinter
 
 
 class ParameterManagerMain(ParameterManager):
@@ -101,7 +99,7 @@ class ParameterManagerMain(ParameterManager):
             return None
 
     def init_kernel_dict(self, data):
-        if self["kernel"] == "rbf":
+        if self["kernel"] == "rbf" or self["network"] == "deepfriedconvnet":
             GAMMA = self.get_gamma_value(data)
             self["--gamma"] = GAMMA
             self.__kernel_dict = {"gamma": GAMMA}
@@ -136,15 +134,35 @@ def main(paraman, resman, printman):
     if paraman["dataset"] == "mnist":
         data = dataset.MnistDataset(validation_size=paraman["--validation-size"], seed=paraman["--seed"])
         convmodel_func = build_lenet_model
+        datagen = ImageDataGenerator(
+            rotation_range=20,
+            width_shift_range=0.2,
+            height_shift_range=0.2,
+            horizontal_flip=False)
     elif paraman["dataset"] == "cifar10":
         data = dataset.Cifar10Dataset(validation_size=paraman["--validation-size"], seed=paraman["--seed"])
         convmodel_func = build_vgg19_model
+        datagen = ImageDataGenerator(
+            rotation_range=20,
+            width_shift_range=0.2,
+            height_shift_range=0.2,
+            horizontal_flip=True)
     elif paraman["dataset"] == "cifar100":
         data = dataset.Cifar100FineDataset(validation_size=paraman["--validation-size"], seed=paraman["--seed"])
         convmodel_func = build_vgg19_model
+        datagen = ImageDataGenerator(
+            rotation_range=20,
+            width_shift_range=0.2,
+            height_shift_range=0.2,
+            horizontal_flip=True)
     elif paraman["dataset"] == "svhn":
         data = dataset.SVHNDataset(validation_size=paraman["--validation-size"], seed=paraman["--seed"])
         convmodel_func = build_vgg19_model
+        datagen = ImageDataGenerator(
+            rotation_range=20,
+            width_shift_range=0.2,
+            height_shift_range=0.2,
+            horizontal_flip=False)
     else:
         raise ValueError("Unknown dataset")
 
@@ -159,12 +177,6 @@ def main(paraman, resman, printman):
     X_train, y_train = data.train.data, data.train.labels
     X_test, y_test = data.test.data, data.test.labels
     X_val, y_val = data.validation.data, data.validation.labels
-
-    datagen = ImageDataGenerator(
-        rotation_range=20,
-        width_shift_range=0.2,
-        height_shift_range=0.2,
-        horizontal_flip=True)
     datagen.fit(X_train)
 
     paraman.init_kernel_dict(X_train)
-- 
GitLab