From c4fd292c6edc6cfc6f085b4fffe1caa82ef76669 Mon Sep 17 00:00:00 2001 From: Luc Giffon <luc.giffon@lis-lab.fr> Date: Thu, 29 Nov 2018 18:30:56 +0100 Subject: [PATCH] Fix issues deepfriedconvnet and mnist No more horizontal flip for mnist and svhn datasets (horizontal flips on number = not good) Now gamma is computed for deepfriedconvnet (when it was only done for deepstrom with rbf in the former version) --- .../deepstrom_classif_end_to_end.py | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/main/experiments/scripts/november_2018/end_to_end_with_augment/solve_mnist_and_deepfried_issues/deepstrom_classif_end_to_end.py b/main/experiments/scripts/november_2018/end_to_end_with_augment/solve_mnist_and_deepfried_issues/deepstrom_classif_end_to_end.py index 903d054..49f5953 100644 --- a/main/experiments/scripts/november_2018/end_to_end_with_augment/solve_mnist_and_deepfried_issues/deepstrom_classif_end_to_end.py +++ b/main/experiments/scripts/november_2018/end_to_end_with_augment/solve_mnist_and_deepfried_issues/deepstrom_classif_end_to_end.py @@ -54,21 +54,19 @@ Kernel related: """ -import skluc.main.data.mldatasets as dataset +import docopt import numpy as np import tensorflow as tf +import time as t from tensorflow.python.keras.layers import Dense -from tensorflow.python.keras.regularizers import l2 -from tensorflow.python.keras.initializers import he_normal from tensorflow.python.keras.preprocessing.image import ImageDataGenerator +import skluc.main.data.mldatasets as dataset from skluc.main.tensorflow_.kernel_approximation.fastfood_layer import FastFoodLayer from skluc.main.tensorflow_.kernel_approximation.nystrom_layer import DeepstromLayerEndToEnd from skluc.main.tensorflow_.models import build_lenet_model, build_vgg19_model -from skluc.main.utils import logger, memory_usage, ParameterManager, ResultManager, ResultPrinter from skluc.main.tensorflow_.utils import batch_generator -import time as t -import docopt +from skluc.main.utils import logger, memory_usage, ParameterManager, ResultManager, ResultPrinter class ParameterManagerMain(ParameterManager): @@ -101,7 +99,7 @@ class ParameterManagerMain(ParameterManager): return None def init_kernel_dict(self, data): - if self["kernel"] == "rbf": + if self["kernel"] == "rbf" or self["network"] == "deepfriedconvnet": GAMMA = self.get_gamma_value(data) self["--gamma"] = GAMMA self.__kernel_dict = {"gamma": GAMMA} @@ -136,15 +134,35 @@ def main(paraman, resman, printman): if paraman["dataset"] == "mnist": data = dataset.MnistDataset(validation_size=paraman["--validation-size"], seed=paraman["--seed"]) convmodel_func = build_lenet_model + datagen = ImageDataGenerator( + rotation_range=20, + width_shift_range=0.2, + height_shift_range=0.2, + horizontal_flip=False) elif paraman["dataset"] == "cifar10": data = dataset.Cifar10Dataset(validation_size=paraman["--validation-size"], seed=paraman["--seed"]) convmodel_func = build_vgg19_model + datagen = ImageDataGenerator( + rotation_range=20, + width_shift_range=0.2, + height_shift_range=0.2, + horizontal_flip=True) elif paraman["dataset"] == "cifar100": data = dataset.Cifar100FineDataset(validation_size=paraman["--validation-size"], seed=paraman["--seed"]) convmodel_func = build_vgg19_model + datagen = ImageDataGenerator( + rotation_range=20, + width_shift_range=0.2, + height_shift_range=0.2, + horizontal_flip=True) elif paraman["dataset"] == "svhn": data = dataset.SVHNDataset(validation_size=paraman["--validation-size"], seed=paraman["--seed"]) convmodel_func = build_vgg19_model + datagen = ImageDataGenerator( + rotation_range=20, + width_shift_range=0.2, + height_shift_range=0.2, + horizontal_flip=False) else: raise ValueError("Unknown dataset") @@ -159,12 +177,6 @@ def main(paraman, resman, printman): X_train, y_train = data.train.data, data.train.labels X_test, y_test = data.test.data, data.test.labels X_val, y_val = data.validation.data, data.validation.labels - - datagen = ImageDataGenerator( - rotation_range=20, - width_shift_range=0.2, - height_shift_range=0.2, - horizontal_flip=True) datagen.fit(X_train) paraman.init_kernel_dict(X_train) -- GitLab