diff --git a/.gitignore b/.gitignore
index 25edce6a975cd781255e0291d7d4199b671a1b64..1df246617e40f1c0e38a0d80f3483a7ae009a721 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
 .idea
+MNIST_data/
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/main/convnet_random.py b/main/convnet_random.py
index b6b11c8256fa7d42cd7ae1351576215ae7c29c1b..2b1bc1332bf289997b64040155d686a4e2ea1dda 100644
--- a/main/convnet_random.py
+++ b/main/convnet_random.py
@@ -4,6 +4,10 @@ Convolutional Neural Network implementation in tensorflow with multiple represen
     - Random Fourier Features layer
     - Fast Food layer where Fast Hadamard Transform has been replaced by dot product with Hadamard matrix.
 
+See:
+"Deep Fried Convnets" by
+Zichao Yang, Marcin Moczulski, Misha Denil, Nando de Freitas, Alex Smola, Le Song, Ziyu Wang
+
 """
 
 import tensorflow as tf
@@ -77,59 +81,55 @@ def random_biases(shape):
 
 # --- Fast Food Naive --- #
 
-def G_variable(d, diag=True, trainable=False):
+def G_variable(shape, trainable=False):
     """
-    Return a Gaussian Random diagonal matrix converted into Tensorflow Variable.
+    Return a Gaussian Random matrix converted into Tensorflow Variable.
 
-    :param d: The size of the diagonal
-    :type d: int
-    :return: tf.Variable object containing the diagonal and not trainable, The frobenius norm of this diagonal (float)
+    :param shape: The shape of the matrix (number of fastfood stacks (v), dimension of the input space (d))
+    :type shape: int or tuple of int (tuple size = 2)
+    :return: tf.Variable object containing the matrix, and the L2 norm of each row (np.array of float)
     """
+    assert type(shape) == int or (type(shape) == tuple and len(shape) == 2)
+    G = np.random.normal(size=shape).astype(np.float32)
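+    # L2 norm of each row of G; S_variable uses it to rescale its chi-distributed entries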
+    G_norms = np.linalg.norm(G, ord=2, axis=1)
+    return tf.Variable(G, name="G", trainable=trainable), G_norms
 
-    if diag:
-        G = np.diag(np.random.normal(size=d)).astype(np.float32)
-        G_norm = np.linalg.norm(G, ord='fro')
-    else:
-        G = np.random.normal(size=d).astype(np.float32)
-        G_norm = np.linalg.norm(G, ord=2)
-    return tf.Variable(G, name="G", trainable=trainable), G_norm
 
-
-def B_variable(d, diag=True, trainable=False):
+def B_variable(shape, trainable=False):
     """
-    Return a random diagonal matrix of -1 and 1 picked uniformly into Tensorflow Variable.
+    Return a random matrix of -1 and 1 picked uniformly and converted into Tensorflow Variable.
 
-    :param d: The size of the diagonal
-    :type d: int
-    :return: tf.Variable object containing the diagonal and not trainable
+    :param shape: The shape of the matrix (number of fastfood stacks (v), dimension of the input space (d))
+    :type shape: int or tuple of int (tuple size = 2)
+    :return: tf.Variable object containing the matrix
     """
-    if diag:
-        B = np.diag(np.random.choice([-1, 1], size=d, replace=True)).astype(np.float32)
-    else:
-        B = np.random.choice([-1, 1], size=d, replace=True).astype(np.float32)
+    assert type(shape) == int or (type(shape) == tuple and len(shape) == 2)
+    B = np.random.choice([-1, 1], size=shape, replace=True).astype(np.float32)
     return tf.Variable(B, name="B", trainable=trainable)
 
 
-def P_variable(d):
+def P_variable(d, nbr_stack):
     """
-    Return a permutation matrix into Tensorflow Variable.
+    Return a permutation matrix converted into Tensorflow Variable.
 
-    :param d: The size of the diagonal
+    :param d: The dimension of the input space (the matrix is square of size nbr_stack * d)
     :type d: int
-    :return: tf.Variable object containing the diagonal and not trainable
+    :param nbr_stack: The number of stacked fastfood blocks (nbr_stack * d is the dimension of the output space)
+    :type nbr_stack: int
+    :return: tf.Variable object containing the matrix
     """
-    idx = np.random.permutation(d)
-    P = np.random.permutation(np.eye(N=d))[idx].astype(np.float32)
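+    # one permutation of 0..d-1 per stack, each offset into its own block of the nbr_stack * d index range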
+    idx = np.hstack([(i * d) + np.random.permutation(d) for i in range(nbr_stack)])
+    P = np.random.permutation(np.eye(N=nbr_stack * d))[idx].astype(np.float32)
     return tf.Variable(P, name="P", trainable=False)
 
 
 def H_variable(d):
     """
-    Return an Hadamard matrix into Tensorflow Variable.
+    Return an Hadamard matrix converted into Tensorflow Variable.
 
     d must be a power of two.
 
-    :param d: The size of the Hadamard matrix.
+    :param d: The size of the Hadamard matrix (dimension of the input space).
     :type d: int
     :return: tf.Variable object containing the diagonal and not trainable
     """
@@ -137,22 +137,20 @@ def H_variable(d):
     return tf.Variable(H, name="H", trainable=False)
 
 
-def S_variable(d, G_norm, diag=True, trainable=False):
+def S_variable(shape, G_norms, trainable=False):
     """
-    Return a scaling diagonal matrix of random values picked from a chi distribution.
+    Return a scaling matrix of random values picked from a chi distribution.
 
-    The values are re-scaled using the norm of the Gaussian Diagonal random matrix G.
+    Each row is re-scaled using the norm of the corresponding row of the Gaussian random matrix G
+    generated by the `G_variable` function.
 
-    :param d: The size of the diagonal.
-    :type d: int
-    :param G_norm: The norm of the Gaussian Diagonal random matrix G.
-    :type G_norm: float
-    :return: tf.Variable object containing the diagonal and not trainable.
+    :param shape: The shape of the matrix (number of fastfood stacks (v), dimension of the input space (d))
+    :type shape: int or tuple of int (tuple size = 2)
+    :param G_norms: The row norms of the associated Gaussian random matrix G.
+    :type G_norms: np.array of floats
+    :return: tf.Variable object containing the matrix.
     """
-    if diag:
-        S = np.diag((1 / G_norm) * scipy.stats.chi.rvs(d, size=d)).astype(np.float32)
-    else:
-        S = (1 / G_norm) * scipy.stats.chi.rvs(d, size=d).astype(np.float32)
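+    # each row holds chi(d) samples divided by the norm of the corresponding row of G (Fastfood scaling)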
+    S = np.multiply((1 / G_norms.reshape((-1, 1))),
+                    scipy.stats.chi.rvs(shape[1], size=shape).astype(np.float32))
     return tf.Variable(S, name="S", trainable=trainable)
 
 
@@ -197,8 +195,23 @@ def random_features(conv_out, sigma):
         return h1_final
 
 
-def fast_food(conv_out, sigma, nbr_stack=1, diag=True, trainable=False, name="fastfood"):
-    with tf.name_scope(name + "_diag=" + str(diag) + "_sigma=" + str(sigma)):
+def fast_food(conv_out, sigma, nbr_stack=1, trainable=False):
+    """
+    Return a fastfood transform op compatible with tensorflow graph.
+
+    Implementation largely inspired by https://gist.github.com/dougalsutherland/1a3c70e57dd1f64010ab .
+
+    See:
+    "Fastfood | Approximating Kernel Expansions in Loglinear Time" by
+    Quoc Le, Tamas Sarl and Alex Smola.
+
+    :param conv_out: the input of the op
+    :param sigma: bandwidth of the Gaussian distribution
+    :param nbr_stack: number of fast food stacks
+    :param trainable: whether the random matrices B, G and S are trainable
+    :return: the output of the fastfood transform
+    """
+    with tf.name_scope("fastfood" + "_sigma-"+str(sigma)):
         init_dim = np.prod([s.value for s in conv_out.shape if s.value is not None])
         final_dim = int(dimensionality_constraints(init_dim))
         padding = final_dim - init_dim
@@ -206,40 +219,31 @@ def fast_food(conv_out, sigma, nbr_stack=1, diag=True, trainable=False, name="fa
         paddings = tf.constant([[0, 0], [0, padding]])
         conv_out2 = tf.pad(conv_out2, paddings, "CONSTANT")
 
-        G, G_norm = G_variable(final_dim, diag=diag, trainable=trainable)
+        G, G_norm = G_variable((nbr_stack, final_dim), trainable=trainable)
         tf.summary.histogram("weights G", G)
-        B = B_variable(final_dim, diag=diag, trainable=trainable)
+        B = B_variable((nbr_stack, final_dim), trainable=trainable)
         tf.summary.histogram("weights B", B)
         H = H_variable(final_dim)
         tf.summary.histogram("weights H", H)
-        P = P_variable(final_dim)
+        P = P_variable(final_dim, nbr_stack)
         tf.summary.histogram("weights P", P)
-        S = S_variable(final_dim, G_norm, diag=diag, trainable=trainable)
+        S = S_variable((nbr_stack, final_dim), G_norm, trainable=trainable)
         tf.summary.histogram("weights S", S)
 
-        if diag:
-            h_ff1 = tf.matmul(conv_out2, B, name="Bx")
-            h_ff2 = tf.matmul(h_ff1, H, name="HBx")
-            h_ff3 = tf.matmul(h_ff2, P, name="PHBx")
-            h_ff4 = tf.matmul(h_ff3, G, name="GPHBx")
-            h_ff5 = tf.matmul(h_ff4, H, name="HGPHBx")
-
-            h_ff6 = tf.scalar_mul((1/(sigma * np.sqrt(final_dim))), tf.matmul(h_ff5, S, name="SHGPHBx"))
-            h_ff7_1 = tf.cos(h_ff6)
-            h_ff7_2 = tf.sin(h_ff6)
-            h_ff7 = tf.scalar_mul(tf.sqrt(float(1 / final_dim)), tf.concat([h_ff7_1, h_ff7_2], axis=1))
-
-        else:
-            h_ff1 = tf.multiply(conv_out2, B, name="Bx")
-            h_ff2 = tf.matmul(h_ff1, H, name="HBx")
-            h_ff3 = tf.matmul(h_ff2, P, name="PHBx")
-            h_ff4 = tf.multiply(h_ff3, G, name="GPHBx")
-            h_ff5 = tf.matmul(h_ff4, H, name="HGPHBx")
-
-            h_ff6 = tf.scalar_mul((1/(sigma * np.sqrt(final_dim))), tf.multiply(h_ff5, S, name="SHGPHBx"))
-            h_ff7_1 = tf.cos(h_ff6)
-            h_ff7_2 = tf.sin(h_ff6)
-            h_ff7 = tf.scalar_mul(tf.sqrt(float(1 / final_dim)), tf.concat([h_ff7_1, h_ff7_2], axis=1))
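+        # compute the fastfood product (1 / (sigma * sqrt(d))) * S H G P H B x; B, G and S are stored as (nbr_stack, d) rows rather than diagonal matrices
+        # reshape the padded input to (1, batch, 1, d) so the elementwise product with B of shape (nbr_stack, d) broadcasts over every stack at once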
+        conv_out2 = tf.reshape(conv_out2, (1, -1, 1, final_dim))
+        h_ff1 = tf.multiply(conv_out2, B, name="Bx")
+        h_ff1 = tf.reshape(h_ff1, (-1, final_dim))
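+        # rows are now (batch * nbr_stack, d): the single d x d Hadamard matrix H multiplies every stack in one matmul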
+        h_ff2 = tf.matmul(h_ff1, H, name="HBx")
+        h_ff2 = tf.reshape(h_ff2, (-1, final_dim * nbr_stack))
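+        # rows are now (batch, nbr_stack * d); apply the (nbr_stack * d) x (nbr_stack * d) permutation P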
+        h_ff3 = tf.matmul(h_ff2, P, name="PHBx")
+        h_ff4 = tf.multiply(tf.reshape(h_ff3, (-1, final_dim * nbr_stack)),
+                            tf.reshape(G, (-1, final_dim * nbr_stack)), name="GPHBx")
+        h_ff4 = tf.reshape(h_ff4, (-1, final_dim))
+        h_ff5 = tf.matmul(h_ff4, H, name="HGPHBx")
+
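+        # scale by S and 1 / (sigma * sqrt(d)), then take cos and sin to form the random Fourier features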
+        h_ff6 = tf.scalar_mul((1 / (sigma * np.sqrt(final_dim))),
+                              tf.multiply(tf.reshape(h_ff5, (-1, final_dim * nbr_stack)),
+                                          tf.reshape(S, (-1, final_dim * nbr_stack)), name="SHGPHBx"))
+        h_ff7_1 = tf.cos(h_ff6)
+        h_ff7_2 = tf.sin(h_ff6)
+        h_ff7 = tf.scalar_mul(tf.sqrt(float(1 / final_dim)), tf.concat([h_ff7_1, h_ff7_2], axis=1))
     return h_ff7
 
 
@@ -255,15 +259,6 @@ def fully_connected(conv_out):
     return h_fc1
 
 
-def stacked_fastfood(input_, nbr, sigma, diag=False, trainable=False):
-    l_outputs = []
-    for i in range(nbr):
-        output = fast_food(input_, sigma, diag=diag, trainable=trainable, name="fastfood" + str(i))
-        l_outputs.append(output)
-    outputs_stacked = tf.concat(l_outputs, axis=1)
-    return outputs_stacked
-
-
 if __name__ == '__main__':
     SIGMA = 5.0
     print("Sigma = {}".format(SIGMA))
@@ -281,12 +276,10 @@ if __name__ == '__main__':
         h_conv = convolution(x_image)
         # h_conv = x
         # out_fc = fully_connected(h_conv)  # 95% accuracy
-        # out_fc = tf.nn.relu(fast_food(h_conv, SIGMA))  # 83% accuracy (conv) | 56% accuracy (noconv)
-        # out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, diag=False))  # 84% accuracy (conv) | 59% accuracy (noconv)
-        # out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, diag=False, trainable=True))  # 84% accuracy (conv) | 59% accuracy (noconv)
-        # todo: faire une implémentation moins naive: il doit y avoir des blocs dans tf uniquement lorsque j'utilise des matrices
-        # diagonales, sinon je n'ai besoin que de plusieurs lignes pour la matrice de hadamard
-        out_fc = tf.nn.relu(stacked_fastfood(h_conv, 2, SIGMA, diag=False, trainable=True))  # 84% accuracy (conv) | 59% accuracy (noconv)
+        # out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, nbr_stack=1))  # 83% accuracy (conv) | 56% accuracy (noconv)
+        out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, nbr_stack=2))
+        # out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, nbr_stack=2, trainable=True))
+        # out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, trainable=True))  # 84% accuracy (conv) | 59% accuracy (noconv)
         # out_fc = fast_food(h_conv, SIGMA, diag=True, trainable=True)  # 84% accuracy (conv) | 59% accuracy (noconv)
         # out_fc = random_features(h_conv, SIGMA)  # 82% accuracy (conv) | 47% accuracy (noconv)
 
@@ -333,7 +326,7 @@ if __name__ == '__main__':
         sess.run(init)
         # actual learning
         started = t.time()
-        for i in range(2000):
+        for i in range(20000):
             batch = mnist.train.next_batch(64)
             feed_dict = {x: batch[0], y_: batch[1], keep_prob: 0.5}
             # le _ est pour capturer le retour de "train_optimizer" qu'il faut appeler