diff --git a/main/experiments/scripts/november_2018/end_to_end_with_augment/solve_mnist_and_deepfried_issues/deepstrom_classif_end_to_end.py b/main/experiments/scripts/november_2018/end_to_end_with_augment/solve_mnist_and_deepfried_issues/deepstrom_classif_end_to_end.py index 903d0541d61d7456954d343c467b09c76e4ef99b..49f5953b379d7733e9b22eb11eced95454ff729f 100644 --- a/main/experiments/scripts/november_2018/end_to_end_with_augment/solve_mnist_and_deepfried_issues/deepstrom_classif_end_to_end.py +++ b/main/experiments/scripts/november_2018/end_to_end_with_augment/solve_mnist_and_deepfried_issues/deepstrom_classif_end_to_end.py @@ -54,21 +54,19 @@ Kernel related: """ -import skluc.main.data.mldatasets as dataset +import docopt import numpy as np import tensorflow as tf +import time as t from tensorflow.python.keras.layers import Dense -from tensorflow.python.keras.regularizers import l2 -from tensorflow.python.keras.initializers import he_normal from tensorflow.python.keras.preprocessing.image import ImageDataGenerator +import skluc.main.data.mldatasets as dataset from skluc.main.tensorflow_.kernel_approximation.fastfood_layer import FastFoodLayer from skluc.main.tensorflow_.kernel_approximation.nystrom_layer import DeepstromLayerEndToEnd from skluc.main.tensorflow_.models import build_lenet_model, build_vgg19_model -from skluc.main.utils import logger, memory_usage, ParameterManager, ResultManager, ResultPrinter from skluc.main.tensorflow_.utils import batch_generator -import time as t -import docopt +from skluc.main.utils import logger, memory_usage, ParameterManager, ResultManager, ResultPrinter class ParameterManagerMain(ParameterManager): @@ -101,7 +99,7 @@ class ParameterManagerMain(ParameterManager): return None def init_kernel_dict(self, data): - if self["kernel"] == "rbf": + if self["kernel"] == "rbf" or self["network"] == "deepfriedconvnet": GAMMA = self.get_gamma_value(data) self["--gamma"] = GAMMA self.__kernel_dict = {"gamma": GAMMA} @@ -136,15 +134,35 @@ def main(paraman, resman, printman): if paraman["dataset"] == "mnist": data = dataset.MnistDataset(validation_size=paraman["--validation-size"], seed=paraman["--seed"]) convmodel_func = build_lenet_model + datagen = ImageDataGenerator( + rotation_range=20, + width_shift_range=0.2, + height_shift_range=0.2, + horizontal_flip=False) elif paraman["dataset"] == "cifar10": data = dataset.Cifar10Dataset(validation_size=paraman["--validation-size"], seed=paraman["--seed"]) convmodel_func = build_vgg19_model + datagen = ImageDataGenerator( + rotation_range=20, + width_shift_range=0.2, + height_shift_range=0.2, + horizontal_flip=True) elif paraman["dataset"] == "cifar100": data = dataset.Cifar100FineDataset(validation_size=paraman["--validation-size"], seed=paraman["--seed"]) convmodel_func = build_vgg19_model + datagen = ImageDataGenerator( + rotation_range=20, + width_shift_range=0.2, + height_shift_range=0.2, + horizontal_flip=True) elif paraman["dataset"] == "svhn": data = dataset.SVHNDataset(validation_size=paraman["--validation-size"], seed=paraman["--seed"]) convmodel_func = build_vgg19_model + datagen = ImageDataGenerator( + rotation_range=20, + width_shift_range=0.2, + height_shift_range=0.2, + horizontal_flip=False) else: raise ValueError("Unknown dataset") @@ -159,12 +177,6 @@ def main(paraman, resman, printman): X_train, y_train = data.train.data, data.train.labels X_test, y_test = data.test.data, data.test.labels X_val, y_val = data.validation.data, data.validation.labels - - datagen = ImageDataGenerator( - rotation_range=20, - width_shift_range=0.2, - height_shift_range=0.2, - horizontal_flip=True) datagen.fit(X_train) paraman.init_kernel_dict(X_train)