Skip to content
Snippets Groups Projects
Commit c4fd292c authored by Luc Giffon's avatar Luc Giffon
Browse files

Fix issues deepfriedconvnet and mnist

No more horizontal flip for mnist and svhn datasets (horizontal flips on number = not good)
Now gamma is computed for deepfriedconvnet (when it was only done for deepstrom with rbf in the former version)
parent 9663721d
No related branches found
No related tags found
No related merge requests found
......@@ -54,21 +54,19 @@ Kernel related:
"""
import skluc.main.data.mldatasets as dataset
import docopt
import numpy as np
import tensorflow as tf
import time as t
from tensorflow.python.keras.layers import Dense
from tensorflow.python.keras.regularizers import l2
from tensorflow.python.keras.initializers import he_normal
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
import skluc.main.data.mldatasets as dataset
from skluc.main.tensorflow_.kernel_approximation.fastfood_layer import FastFoodLayer
from skluc.main.tensorflow_.kernel_approximation.nystrom_layer import DeepstromLayerEndToEnd
from skluc.main.tensorflow_.models import build_lenet_model, build_vgg19_model
from skluc.main.utils import logger, memory_usage, ParameterManager, ResultManager, ResultPrinter
from skluc.main.tensorflow_.utils import batch_generator
import time as t
import docopt
from skluc.main.utils import logger, memory_usage, ParameterManager, ResultManager, ResultPrinter
class ParameterManagerMain(ParameterManager):
......@@ -101,7 +99,7 @@ class ParameterManagerMain(ParameterManager):
return None
def init_kernel_dict(self, data):
if self["kernel"] == "rbf":
if self["kernel"] == "rbf" or self["network"] == "deepfriedconvnet":
GAMMA = self.get_gamma_value(data)
self["--gamma"] = GAMMA
self.__kernel_dict = {"gamma": GAMMA}
......@@ -136,15 +134,35 @@ def main(paraman, resman, printman):
if paraman["dataset"] == "mnist":
data = dataset.MnistDataset(validation_size=paraman["--validation-size"], seed=paraman["--seed"])
convmodel_func = build_lenet_model
datagen = ImageDataGenerator(
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
horizontal_flip=False)
elif paraman["dataset"] == "cifar10":
data = dataset.Cifar10Dataset(validation_size=paraman["--validation-size"], seed=paraman["--seed"])
convmodel_func = build_vgg19_model
datagen = ImageDataGenerator(
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
horizontal_flip=True)
elif paraman["dataset"] == "cifar100":
data = dataset.Cifar100FineDataset(validation_size=paraman["--validation-size"], seed=paraman["--seed"])
convmodel_func = build_vgg19_model
datagen = ImageDataGenerator(
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
horizontal_flip=True)
elif paraman["dataset"] == "svhn":
data = dataset.SVHNDataset(validation_size=paraman["--validation-size"], seed=paraman["--seed"])
convmodel_func = build_vgg19_model
datagen = ImageDataGenerator(
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
horizontal_flip=False)
else:
raise ValueError("Unknown dataset")
......@@ -159,12 +177,6 @@ def main(paraman, resman, printman):
X_train, y_train = data.train.data, data.train.labels
X_test, y_test = data.test.data, data.test.labels
X_val, y_val = data.validation.data, data.validation.labels
datagen = ImageDataGenerator(
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
horizontal_flip=True)
datagen.fit(X_train)
paraman.init_kernel_dict(X_train)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment