From bc0cbecffd36ae83c6720a6d0385a87e6fbb144e Mon Sep 17 00:00:00 2001
From: Luc Giffon <luc.giffon@lif.univ-mrs.fr>
Date: Wed, 29 Nov 2017 17:43:17 +0100
Subject: [PATCH] removed the diagonal implementation of fastfood (only the
 hadamard product version remains) + implemented stacked fastfood for the
 hadamard product version - still needs command-line execution support

---
 .gitignore             |   1 +
 main/convnet_random.py | 167 ++++++++++++++++++++---------------
 2 files changed, 81 insertions(+), 87 deletions(-)

diff --git a/.gitignore b/.gitignore
index 25edce6..1df2466 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
 .idea
+MNIST_data/

 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/main/convnet_random.py b/main/convnet_random.py
index b6b11c8..2b1bc13 100644
--- a/main/convnet_random.py
+++ b/main/convnet_random.py
@@ -4,6 +4,10 @@ Convolutional Neural Network implementation in tensorflow with multiple represen
 - Random Fourier Features layer
 - Fast Food layer where Fast Hadamard Transform has been replaced by dot product with Hadamard matrix.

+See:
+"Deep Fried Convnets" by
+Zichao Yang, Marcin Moczulski, Misha Denil, Nando de Freitas, Alex Smola, Le Song, Ziyu Wang
+
 """

 import tensorflow as tf
@@ -77,59 +81,55 @@ def random_biases(shape):

 # --- Fast Food Naive --- #

-def G_variable(d, diag=True, trainable=False):
+def G_variable(shape, trainable=False):
     """
-    Return a Gaussian Random diagonal matrix converted into Tensorflow Variable.
+    Return a Gaussian random matrix converted into a Tensorflow Variable.

-    :param d: The size of the diagonal
-    :type d: int
-    :return: tf.Variable object containing the diagonal and not trainable, The frobenius norm of this diagonal (float)
+    :param shape: The shape of the matrix (number of fastfood stacks (v), dimension of the input space (d))
+    :type shape: int or tuple of int (tuple size = 2)
+    :return: tf.Variable object containing the matrix, and the L2 norm of each row (np.array of float)
     """
+    assert type(shape) == int or (type(shape) == tuple and len(shape) == 2)
+    G = np.random.normal(size=shape).astype(np.float32)
+    G_norms = np.linalg.norm(G, ord=2, axis=1)
+    return tf.Variable(G, name="G", trainable=trainable), G_norms

-    if diag:
-        G = np.diag(np.random.normal(size=d)).astype(np.float32)
-        G_norm = np.linalg.norm(G, ord='fro')
-    else:
-        G = np.random.normal(size=d).astype(np.float32)
-        G_norm = np.linalg.norm(G, ord=2)
-    return tf.Variable(G, name="G", trainable=trainable), G_norm
-
-def B_variable(d, diag=True, trainable=False):
+def B_variable(shape, trainable=False):
     """
-    Return a random diagonal matrix of -1 and 1 picked uniformly into Tensorflow Variable.
+    Return a random matrix of -1 and 1 picked uniformly and converted into a Tensorflow Variable.
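+    In the fastfood factorization V = S H G P H B used below, B supplies the
+    random sign flips; with v stacks, each row of B holds the signs of one
+    stack and is applied elementwise.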

-    :param d: The size of the diagonal
-    :type d: int
-    :return: tf.Variable object containing the diagonal and not trainable
+    :param shape: The shape of the matrix (number of fastfood stacks (v), dimension of the input space (d))
+    :type shape: int or tuple of int (tuple size = 2)
+    :return: tf.Variable object containing the matrix
     """
-    if diag:
-        B = np.diag(np.random.choice([-1, 1], size=d, replace=True)).astype(np.float32)
-    else:
-        B = np.random.choice([-1, 1], size=d, replace=True).astype(np.float32)
+    assert type(shape) == int or (type(shape) == tuple and len(shape) == 2)
+    B = np.random.choice([-1, 1], size=shape, replace=True).astype(np.float32)
     return tf.Variable(B, name="B", trainable=trainable)


-def P_variable(d):
+def P_variable(d, nbr_stack):
     """
-    Return a permutation matrix into Tensorflow Variable.
+    Return a permutation matrix converted into a Tensorflow Variable.

-    :param d: The size of the diagonal
+    :param d: The width of the matrix (dimension of the input space)
     :type d: int
-    :return: tf.Variable object containing the diagonal and not trainable
+    :param nbr_stack: The height of the matrix (nbr_stack x d is the dimension of the output space)
+    :type nbr_stack: int
+    :return: tf.Variable object containing the matrix
     """
-    idx = np.random.permutation(d)
-    P = np.random.permutation(np.eye(N=d))[idx].astype(np.float32)
+    idx = np.concatenate([(i * d) + np.random.permutation(d) for i in range(nbr_stack)])
+    P = np.random.permutation(np.eye(N=nbr_stack * d))[idx].astype(np.float32)
     return tf.Variable(P, name="P", trainable=False)


 def H_variable(d):
     """
-    Return an Hadamard matrix into Tensorflow Variable.
+    Return a Hadamard matrix converted into a Tensorflow Variable.

     d must be a power of two.

-    :param d: The size of the Hadamard matrix.
+    :param d: The size of the Hadamard matrix (dimension of the input space).
     :type d: int
     :return: tf.Variable object containing the diagonal and not trainable
     """
@@ -137,22 +137,20 @@ def H_variable(d):
     return tf.Variable(H, name="H", trainable=False)


-def S_variable(d, G_norm, diag=True, trainable=False):
+def S_variable(shape, G_norms, trainable=False):
     """
-    Return a scaling diagonal matrix of random values picked from a chi distribution.
+    Return a scaling matrix of random values picked from a chi distribution.

-    The values are re-scaled using the norm of the Gaussian Diagonal random matrix G.
+    The values are re-scaled using the norm of the associated Gaussian random matrix G. The associated
+    Gaussian vectors are the ones generated by the `G_variable` function.

-    :param d: The size of the diagonal.
-    :type d: int
-    :param G_norm: The norm of the Gaussian Diagonal random matrix G.
-    :type G_norm: float
-    :return: tf.Variable object containing the diagonal and not trainable.
+    :param shape: The shape of the matrix (number of fastfood stacks (v), dimension of the input space (d))
+    :type shape: int or tuple of int (tuple size = 2)
+    :param G_norms: The norms of the associated Gaussian random matrices G.
+    :type G_norms: np.array of floats
+    :return: tf.Variable object containing the matrix.
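+
+    Each row i of S is drawn from a chi distribution with d degrees of freedom
+    and rescaled by 1 / ||G_i||_2 (the norms returned by `G_variable`), as in
+    the fastfood construction.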
""" - if diag: - S = np.diag((1 / G_norm) * scipy.stats.chi.rvs(d, size=d)).astype(np.float32) - else: - S = (1 / G_norm) * scipy.stats.chi.rvs(d, size=d).astype(np.float32) + S = np.multiply((1 / G_norms.reshape((-1, 1))), scipy.stats.chi.rvs(shape[1], size=shape).astype(np.float32)) return tf.Variable(S, name="S", trainable=trainable) @@ -197,8 +195,23 @@ def random_features(conv_out, sigma): return h1_final -def fast_food(conv_out, sigma, nbr_stack=1, diag=True, trainable=False, name="fastfood"): - with tf.name_scope(name + "_diag=" + str(diag) + "_sigma=" + str(sigma)): +def fast_food(conv_out, sigma, nbr_stack=1, trainable=False): + """ + Return a fastfood transform op compatible with tensorflow graph. + + Implementation largely inspired from https://gist.github.com/dougalsutherland/1a3c70e57dd1f64010ab . + + See: + "Fastfood | Approximating Kernel Expansions in Loglinear Time" by + Quoc Le, Tamas Sarl and Alex Smola. + + :param conv_out: the input of the op + :param sigma: bandwith of the gaussian distribution + :param nbr_stack: number of fast food stacks + :param trainable: the diagonal matrices are trainable or not + :return: the output of the fastfood transform + """ + with tf.name_scope("fastfood" + "_sigma-"+str(sigma)): init_dim = np.prod([s.value for s in conv_out.shape if s.value is not None]) final_dim = int(dimensionality_constraints(init_dim)) padding = final_dim - init_dim @@ -206,40 +219,31 @@ def fast_food(conv_out, sigma, nbr_stack=1, diag=True, trainable=False, name="fa paddings = tf.constant([[0, 0], [0, padding]]) conv_out2 = tf.pad(conv_out2, paddings, "CONSTANT") - G, G_norm = G_variable(final_dim, diag=diag, trainable=trainable) + G, G_norm = G_variable((nbr_stack, final_dim), trainable=trainable) tf.summary.histogram("weights G", G) - B = B_variable(final_dim, diag=diag, trainable=trainable) + B = B_variable((nbr_stack, final_dim), trainable=trainable) tf.summary.histogram("weights B", B) H = H_variable(final_dim) tf.summary.histogram("weights H", H) - P = P_variable(final_dim) + P = P_variable(final_dim, nbr_stack) tf.summary.histogram("weights P", P) - S = S_variable(final_dim, G_norm, diag=diag, trainable=trainable) + S = S_variable((nbr_stack, final_dim), G_norm, trainable=trainable) tf.summary.histogram("weights S", S) - if diag: - h_ff1 = tf.matmul(conv_out2, B, name="Bx") - h_ff2 = tf.matmul(h_ff1, H, name="HBx") - h_ff3 = tf.matmul(h_ff2, P, name="PHBx") - h_ff4 = tf.matmul(h_ff3, G, name="GPHBx") - h_ff5 = tf.matmul(h_ff4, H, name="HGPHBx") - - h_ff6 = tf.scalar_mul((1/(sigma * np.sqrt(final_dim))), tf.matmul(h_ff5, S, name="SHGPHBx")) - h_ff7_1 = tf.cos(h_ff6) - h_ff7_2 = tf.sin(h_ff6) - h_ff7 = tf.scalar_mul(tf.sqrt(float(1 / final_dim)), tf.concat([h_ff7_1, h_ff7_2], axis=1)) - - else: - h_ff1 = tf.multiply(conv_out2, B, name="Bx") - h_ff2 = tf.matmul(h_ff1, H, name="HBx") - h_ff3 = tf.matmul(h_ff2, P, name="PHBx") - h_ff4 = tf.multiply(h_ff3, G, name="GPHBx") - h_ff5 = tf.matmul(h_ff4, H, name="HGPHBx") - - h_ff6 = tf.scalar_mul((1/(sigma * np.sqrt(final_dim))), tf.multiply(h_ff5, S, name="SHGPHBx")) - h_ff7_1 = tf.cos(h_ff6) - h_ff7_2 = tf.sin(h_ff6) - h_ff7 = tf.scalar_mul(tf.sqrt(float(1 / final_dim)), tf.concat([h_ff7_1, h_ff7_2], axis=1)) + conv_out2 = tf.reshape(conv_out2, (1, -1, 1, final_dim)) + h_ff1 = tf.multiply(conv_out2, B, name="Bx") + h_ff1 = tf.reshape(h_ff1, (-1, final_dim)) + h_ff2 = tf.matmul(h_ff1, H, name="HBx") + h_ff2 = tf.reshape(h_ff2, (-1, final_dim * nbr_stack)) + h_ff3 = tf.matmul(h_ff2, P, name="PHBx") + h_ff4 = 
tf.multiply(tf.reshape(h_ff3, (-1, final_dim * nbr_stack)), tf.reshape(G, (-1, final_dim * nbr_stack)), name="GPHBx") + h_ff4 = tf.reshape(h_ff4, (-1, final_dim)) + h_ff5 = tf.matmul(h_ff4, H, name="HGPHBx") + + h_ff6 = tf.scalar_mul((1/(sigma * np.sqrt(final_dim))), tf.multiply(tf.reshape(h_ff5, (-1, final_dim * nbr_stack)), tf.reshape(S, (-1, final_dim * nbr_stack)), name="SHGPHBx")) + h_ff7_1 = tf.cos(h_ff6) + h_ff7_2 = tf.sin(h_ff6) + h_ff7 = tf.scalar_mul(tf.sqrt(float(1 / final_dim)), tf.concat([h_ff7_1, h_ff7_2], axis=1)) return h_ff7 @@ -255,15 +259,6 @@ def fully_connected(conv_out): return h_fc1 -def stacked_fastfood(input_, nbr, sigma, diag=False, trainable=False): - l_outputs = [] - for i in range(nbr): - output = fast_food(input_, sigma, diag=diag, trainable=trainable, name="fastfood" + str(i)) - l_outputs.append(output) - outputs_stacked = tf.concat(l_outputs, axis=1) - return outputs_stacked - - if __name__ == '__main__': SIGMA = 5.0 print("Sigma = {}".format(SIGMA)) @@ -281,12 +276,10 @@ if __name__ == '__main__': h_conv = convolution(x_image) # h_conv = x # out_fc = fully_connected(h_conv) # 95% accuracy - # out_fc = tf.nn.relu(fast_food(h_conv, SIGMA)) # 83% accuracy (conv) | 56% accuracy (noconv) - # out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, diag=False)) # 84% accuracy (conv) | 59% accuracy (noconv) - # out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, diag=False, trainable=True)) # 84% accuracy (conv) | 59% accuracy (noconv) - # todo: faire une implémentation moins naive: il doit y avoir des blocs dans tf uniquement lorsque j'utilise des matrices - # diagonales, sinon je n'ai besoin que de plusieurs lignes pour la matrice de hadamard - out_fc = tf.nn.relu(stacked_fastfood(h_conv, 2, SIGMA, diag=False, trainable=True)) # 84% accuracy (conv) | 59% accuracy (noconv) + # out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, nbr_stack=1)) # 83% accuracy (conv) | 56% accuracy (noconv) + out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, nbr_stack=2)) + # out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, nbr_stack=2, trainable=True)) + # out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, trainable=True)) # 84% accuracy (conv) | 59% accuracy (noconv) # out_fc = fast_food(h_conv, SIGMA, diag=True, trainable=True) # 84% accuracy (conv) | 59% accuracy (noconv) # out_fc = random_features(h_conv, SIGMA) # 82% accuracy (conv) | 47% accuracy (noconv) @@ -333,7 +326,7 @@ if __name__ == '__main__': sess.run(init) # actual learning started = t.time() - for i in range(2000): + for i in range(20000): batch = mnist.train.next_batch(64) feed_dict = {x: batch[0], y_: batch[1], keep_prob: 0.5} # le _ est pour capturer le retour de "train_optimizer" qu'il faut appeler -- GitLab
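
Note (illustration, not part of the patch): each fastfood stack built in `fast_food` above computes V x = (1 / (sigma * sqrt(d))) S H G P H B x and maps the result to the random Fourier features sqrt(1/d) [cos(Vx), sin(Vx)]. Below is a minimal NumPy sketch of a single stack; the dimension d and bandwidth sigma are arbitrary placeholders, and the variable names are illustrative, not from the repository.

import numpy as np
import scipy.stats
from scipy.linalg import hadamard

d = 512        # dimension of the (padded) input; must be a power of two
sigma = 5.0    # bandwidth of the Gaussian kernel

B = np.random.choice([-1, 1], size=d)                   # random sign flips
H = hadamard(d).astype(np.float64)                      # dense Hadamard matrix
perm = np.random.permutation(d)                         # the permutation P
G = np.random.normal(size=d)                            # Gaussian weights
S = scipy.stats.chi.rvs(d, size=d) / np.linalg.norm(G)  # chi samples rescaled by 1/||G||_2

x = np.random.randn(d)                                  # one input row
v = H @ (B * x)                                         # HBx
v = v[perm]                                             # PHBx
v = H @ (G * v)                                         # HGPHBx
v = (S * v) / (sigma * np.sqrt(d))                      # SHGPHBx, rescaled
phi = np.sqrt(1.0 / d) * np.concatenate([np.cos(v), np.sin(v)])  # [cos, sin] features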