diff --git a/keras_kernel_functions.py b/keras_kernel_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ebd4cbe367f7d081db864d3457bde920a9c25ab
--- /dev/null
+++ b/keras_kernel_functions.py
@@ -0,0 +1,94 @@
+"""
+Module containing popular kernel functions using the keras API.
+"""
+
+from keras import backend as K
+from keras.activations import tanh
+import tensorflow as tf
+
+def replace_nan(tensor):
+    """Replace every NaN entry of `tensor` with 0, keeping the shape."""
+    return tf.where(tf.math.is_nan(tensor), tf.zeros_like(tensor), tensor)
+
+
+def keras_linear_kernel(args, normalize=True, tanh_activation=False):
+    """
+    Linear kernel:
+
+    $k(x, y) = x^Ty$
+
+    :param args: list of size 2 containing x and y
+    :param normalize: if True, normalize the input with l2 before computing the kernel function
+    :param tanh_activation: if True apply tanh activation to the output
+    :return:
+    """
+    X = args[0]
+    Y = args[1]
+    if normalize:
+        X = K.l2_normalize(X, axis=-1)
+        Y = K.l2_normalize(Y, axis=-1)
+    result = K.dot(X, K.transpose(Y))
+    if tanh_activation:
+        return tanh(result)
+    else:
+        return result
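+
+
+# Illustrative usage (a minimal sketch; the values are chosen for clarity):
+#
+#   >>> import numpy as np
+#   >>> x = K.constant(np.array([[1., 0.], [0., 1.]]))
+#   >>> y = K.constant(np.array([[1., 1.]]))
+#   >>> K.eval(keras_linear_kernel([x, y], normalize=False))
+#   array([[1.],
+#          [1.]], dtype=float32)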
+
+
+def keras_chi_square_CPD(args, epsilon=None, tanh_activation=True, normalize=False):
+    """
+    Additive chi-square kernel (conditionally positive definite):
+
+    $k(x, y) = - sum_i (x_i - y_i)^2 / (x_i + y_i)$
+
+    :param args: list of size 2 containing the tensors X and Y
+    :param epsilon: if not None, added to the denominator for numerical stability
+    :param tanh_activation: if True, apply a tanh activation to the output
+    :param normalize: if True, l2-normalize the inputs before computing the kernel function
+    :return: the kernel matrix between the rows of X and the rows of Y
+    """
+    X = args[0]
+    Y = args[1]
+    if normalize:
+        X = K.l2_normalize(X, axis=-1)
+        Y = K.l2_normalize(Y, axis=-1)
+    # the drawing of the matrix X expanded looks like a wall
+    wall = K.expand_dims(X, axis=1)
+    # the drawing of the matrix Y expanded looks like a floor
+    floor = K.expand_dims(Y, axis=0)
+    numerator = K.square((wall - floor))
+    denominator = wall + floor
+    if epsilon is not None:
+        quotient = numerator / (denominator + epsilon)
+    else:
+        quotient = numerator / denominator
+    quotient_without_nan = replace_nan(quotient)
+    result = - K.sum(quotient_without_nan, axis=2)
+    if tanh_activation:
+        return tanh(result)
+    else:
+        return result
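+
+
+# Illustrative check (a minimal sketch): for x = [1, 0] and y = [0, 1] the sum
+# is (1-0)^2/(1+0) + (0-1)^2/(0+1) = 2, so the kernel value is -2:
+#
+#   >>> import numpy as np
+#   >>> x = K.constant(np.array([[1., 0.]]))
+#   >>> y = K.constant(np.array([[0., 1.]]))
+#   >>> K.eval(keras_chi_square_CPD([x, y], tanh_activation=False))
+#   array([[-2.]], dtype=float32)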
+
+
+def keras_chi_square_CPD_exp(args, gamma, epsilon=None, tanh_activation=False, normalize=True):
+    """
+    Exponential chi-square kernel:
+
+    $k(x, y) = exp(- gamma * sum_i (x_i - y_i)^2 / (x_i + y_i))$
+
+    :param args: list of size 2 containing the tensors X and Y
+    :param gamma: the bandwidth of the kernel
+    :param epsilon: if not None, added to the denominator for numerical stability
+    :param tanh_activation: if True, apply a tanh activation to the chi-square result before the exponential
+    :param normalize: if True, l2-normalize the inputs before computing the kernel function
+    :return: the kernel matrix between the rows of X and the rows of Y
+    """
+    result = keras_chi_square_CPD(args, epsilon, tanh_activation, normalize)
+    result *= gamma
+    return K.exp(result)
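+
+
+# Illustrative check (sketch): reusing x = [1, 0] and y = [0, 1] from above,
+# the chi-square sum is 2, so with gamma=0.5 the kernel value is exp(-1):
+#
+#   >>> K.eval(keras_chi_square_CPD_exp([x, y], gamma=0.5, normalize=False))
+#   # -> approximately exp(-1) ~ 0.3679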
+
+
+def keras_rbf_kernel(args, gamma, normalize=True, tanh_activation=False):
+    """
+    Compute the rbf kernel between each entry of X and each line of Y.
+
+    tf_rbf_kernel(x, y, gamma) = exp(- (||x - y||^2 * gamma))
+
+    :param X: A tensor of size n times d
+    :param Y: A tensor of size m times d
+    :param gamma: The bandwith of the kernel
+    :return:
+    """
+    X = args[0]
+    Y = args[1]
+    if normalize:
+        X = K.l2_normalize(X, axis=-1)
+        Y = K.l2_normalize(Y, axis=-1)
+    # pairwise squared distances via ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2
+    r1 = K.sum(X * X, axis=1)
+    r1 = K.reshape(r1, [-1, 1])
+    r2 = K.sum(Y * Y, axis=1)
+    r2 = K.reshape(r2, [1, -1])
+    result = K.dot(X, K.transpose(Y))
+    result = r1 - 2 * result + r2
+    result *= -gamma
+    result = K.exp(result)
+    if tanh_activation:
+        return tanh(result)
+    else:
+        return result
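+
+
+# Illustrative check (sketch): identical rows are at squared distance 0, so the
+# kernel value is exp(0) = 1 regardless of gamma:
+#
+#   >>> import numpy as np
+#   >>> x = K.constant(np.array([[1., 2.]]))
+#   >>> K.eval(keras_rbf_kernel([x, x], gamma=1., normalize=False))
+#   array([[1.]], dtype=float32)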
diff --git a/nystrom_layer.py b/nystrom_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebeb7f5c4f1d44e758d366108e52b89372ed9555
--- /dev/null
+++ b/nystrom_layer.py
@@ -0,0 +1,136 @@
+import keras
+import numpy as np
+from keras.datasets import cifar10
+from keras.models import Sequential, Model
+from keras.layers import Dense, BatchNormalization, Flatten, Lambda, Input, concatenate, Activation
+from keras.layers import Conv2D, MaxPooling2D
+from keras.optimizers import Adam
+from keras.preprocessing.image import ImageDataGenerator
+
+from keras_kernel_functions import keras_linear_kernel
+
+
+def datagen_fixed_batch_size(x, y, batch_size, x_sub=None, p_datagen=ImageDataGenerator()):
+    """
+    Wrap `p_datagen.flow` so that every yielded batch has exactly `batch_size`
+    samples (incomplete trailing batches are skipped) and so that the fixed
+    subsample bases `x_sub` are appended to each batch's input list.
+    """
+    if x_sub is None:
+        x_sub = []
+    for x_batch, y_batch in p_datagen.flow(x, y, batch_size=batch_size):
+        if x_batch.shape[0] != batch_size:
+            continue
+        yield [x_batch] + x_sub, y_batch
+
+def build_conv_model(input_shape):
+    """
+    Create a simple sequential convolutional model
+
+    :param input_shape: tuple containing the expected input data shape
+    :return: keras model object
+    """
+
+    model = Sequential()
+
+    model.add(Conv2D(16, (3, 3), padding='same', input_shape=input_shape))
+    model.add(BatchNormalization())
+    model.add(MaxPooling2D((2, 2), strides=(2, 2)))
+    model.add(Activation('relu'))
+    model.add(Conv2D(16, (3, 3), padding='same'))
+    model.add(BatchNormalization())
+    model.add(Activation('relu'))
+    model.add(MaxPooling2D((2, 2), strides=(2, 2)))
+
+    return model
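+
+# Note (for CIFAR-10 inputs of shape 32x32x3): with 'same' padding, the two 2x2
+# max-poolings reduce the feature maps to 8x8x16, i.e. 1024 values once flattened.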
+
+def init_number_subsample_bases(nys_size, batch_size):
+    """
+    Return the number of bases and the size of the zero padding needed to split
+    the Nystrom subsample into batch-sized blocks.
+
+    :param nys_size: The number of subsample points in the Nystrom approximation.
+    :param batch_size: The batch size of the final model.
+    :return: number of bases, size of the zero padding
+    """
+    if nys_size == 0 or batch_size == 0:
+        raise ValueError("nys_size and batch_size must both be positive")
+    remaining = nys_size % batch_size
+    quotient = nys_size // batch_size
+    if remaining == 0:
+        return quotient, 0
+    else:
+        return quotient + 1, batch_size - remaining
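+
+# Illustrative values (sketch): with nys_size=8 and batch_size=128 the subsample
+# fits in a single zero-padded base; with nys_size=300 it needs three bases:
+#
+#   >>> init_number_subsample_bases(8, 128)
+#   (1, 120)
+#   >>> init_number_subsample_bases(300, 128)
+#   (3, 84)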
+
+if __name__ == "__main__":
+    batch_size = 128
+    epochs = 1
+    num_classes = 10
+    nys_size = 8
+
+    # data preparation
+    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
+    y_train = keras.utils.to_categorical(y_train, num_classes)
+    y_test = keras.utils.to_categorical(y_test, num_classes)
+    x_train = x_train.astype('float32')
+    x_test = x_test.astype('float32')
+    x_train /= 255
+    x_test /= 255
+    datagen = ImageDataGenerator(
+        rotation_range=20,
+        width_shift_range=0.2,
+        height_shift_range=0.2,
+        horizontal_flip=True)
+    datagen.fit(x_train)
+
+    # subsample for nystrom layer preparation
+    # ---------------------------------------
+    # keras needs all of its inputs to have the same shape. The subsample input to the model
+    # is therefore divided into so-called "bases" of the same size as the batch, all stored in a list.
+    # The last base may not be full of samples, so it must be padded with zeros. Those zeros
+    # are cut off inside the model computation.
+    # If you have a suggestion on how to better implement this, feel free to suggest it.
+    nb_subsample_bases, zero_padding_base = init_number_subsample_bases(nys_size, batch_size)
+    subsample_indexes = np.random.permutation(x_train.shape[0])[:nys_size]
+    nys_subsample = x_train[subsample_indexes]
+    zero_padding_subsample = np.zeros((zero_padding_base, *nys_subsample.shape[1:]))
+    nys_subsample = np.vstack([nys_subsample, zero_padding_subsample])
+    list_subsample_bases = [nys_subsample[i * batch_size:(i + 1) * batch_size] for i in range(nb_subsample_bases)]
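+    # sanity checks (illustrative): every base must be exactly one batch wide
+    assert len(list_subsample_bases) == nb_subsample_bases
+    assert all(base.shape[0] == batch_size for base in list_subsample_bases)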
+
+    # convolution layers preparation
+    # ------------------------------
+    convmodel_func = build_conv_model(x_train[0].shape)  # type: keras.models.Sequential
+    convmodel_func.add(Flatten())
+
+
+    # processing of the input by the convolution
+    # ------------------------------------------
+    input_x = Input(shape=x_train[0].shape, name="x")
+    conv_x = convmodel_func(input_x)
+
+    # processing of the subsample by the convolution
+    # ----------------------------------------------
+    # definition of the list of input bases
+    input_repr_subsample = [Input(batch_shape=(batch_size, *x_train[0].shape)) for _ in range(nb_subsample_bases)]
+
+    if nb_subsample_bases > 1:
+        input_subsample_concat = concatenate(input_repr_subsample, axis=0)
+    else:
+        input_subsample_concat = input_repr_subsample[0]
+
+    # remove the zero padding from the input subsample before the actual computation in the network
+    slice_layer = Lambda(lambda input: input[:nys_size], output_shape=lambda shape: (nys_size, *shape[1:]))
+    input_subsample_concat = slice_layer(input_subsample_concat)
+    conv_subsample = convmodel_func(input_subsample_concat)
+
+    # definition of the nystrom layer
+    # -------------------------------
+    kernel_function = lambda args: keras_linear_kernel(args, normalize=True)
+    # kernel function as a Lambda layer
+    kernel_layer = Lambda(kernel_function, output_shape=lambda shapes: (shapes[0][0], nys_size))
+    kernel_vector = kernel_layer([conv_x, conv_subsample])
+    # weight matrix (the "metric matrix") of the Nyström layer
+    input_classifier = Dense(nys_size, use_bias=False, activation='linear')(kernel_vector)
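+    # In the Nyström method, an input x is mapped to k(x, Z) @ W, where Z is the
+    # subsample; here W is learned end-to-end rather than fixed to K(Z, Z)^{-1/2}
+    # as in the classical approximation.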
+
+    classif = Dense(num_classes, activation="softmax")(input_classifier)
+
+    model = Model([input_x] + input_repr_subsample, [classif])
+    adam = Adam(lr=.1)
+    model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])
+    model.fit_generator(datagen_fixed_batch_size(x_train, y_train, batch_size, list_subsample_bases, datagen),
+                        steps_per_epoch=x_train.shape[0] // batch_size,
+                        epochs=epochs)