Commit 1aabf81f authored by Luc Giffon

add a separate module for fastfoodlayer

parent 4a92670e
@@ -12,14 +12,13 @@ Zichao Yang, Marcin Moczulski, Misha Denil, Nando de Freitas, Alex Smola, Le Song
import tensorflow as tf
import numpy as np
import scipy.linalg
import scipy.stats
import time as t
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
from fasfood_layer import fast_food
# --- Usual functions --- #
@@ -79,107 +78,6 @@ def random_biases(shape):
    return tf.Variable(b, name="random_biase", trainable=False)
# --- Fast Food Naive --- #
def G_variable(shape, trainable=False):
    """
    Return a Gaussian random matrix converted into a TensorFlow Variable.

    :param shape: The shape of the matrix (number of fastfood stacks (v), dimension of the input space (d))
    :type shape: int or tuple of int (tuple size = 2)
    :return: tf.Variable object containing the matrix, and the 2-norm of each row (np.array of float)
    """
    assert type(shape) == int or (type(shape) == tuple and len(shape) == 2)
    G = np.random.normal(size=shape).astype(np.float32)
    G_norms = np.linalg.norm(G, ord=2, axis=1)  # one norm per fastfood stack
    return tf.Variable(G, name="G", trainable=trainable), G_norms

def B_variable(shape, trainable=False):
    """
    Return a random matrix of -1 and 1 picked uniformly, converted into a TensorFlow Variable.

    :param shape: The shape of the matrix (number of fastfood stacks (v), dimension of the input space (d))
    :type shape: int or tuple of int (tuple size = 2)
    :return: tf.Variable object containing the matrix
    """
    assert type(shape) == int or (type(shape) == tuple and len(shape) == 2)
    B = np.random.choice([-1, 1], size=shape, replace=True).astype(np.float32)
    return tf.Variable(B, name="B", trainable=trainable)

def P_variable(d, nbr_stack):
    """
    Return a square permutation matrix of size (nbr_stack * d) x (nbr_stack * d),
    converted into a TensorFlow Variable.

    :param d: The dimension of the input space of each fastfood stack
    :type d: int
    :param nbr_stack: The number of fastfood stacks (nbr_stack * d is the dimension of the output space)
    :type nbr_stack: int
    :return: tf.Variable object containing the matrix
    """
    idx = np.hstack([(i * d) + np.random.permutation(d) for i in range(nbr_stack)])
    P = np.random.permutation(np.eye(N=nbr_stack * d))[idx].astype(np.float32)
    return tf.Variable(P, name="P", trainable=False)

def H_variable(d):
    """
    Return a Hadamard matrix converted into a TensorFlow Variable.
    d must be a power of two.

    :param d: The size of the Hadamard matrix (dimension of the input space).
    :type d: int
    :return: tf.Variable object containing the matrix, not trainable
    """
    H = build_hadamard(d).astype(np.float32)
    return tf.Variable(H, name="H", trainable=False)

def S_variable(shape, G_norms, trainable=False):
    """
    Return a scaling matrix of random values drawn from a chi distribution.

    The values are re-scaled by the row norms of the associated Gaussian random
    matrix G, i.e. the norms returned by the `G_variable` function.

    :param shape: The shape of the matrix (number of fastfood stacks (v), dimension of the input space (d))
    :type shape: int or tuple of int (tuple size = 2)
    :param G_norms: The row norms of the associated Gaussian random matrix G.
    :type G_norms: np.array of floats
    :return: tf.Variable object containing the matrix.
    """
    S = np.multiply((1 / G_norms.reshape((-1, 1))), scipy.stats.chi.rvs(shape[1], size=shape).astype(np.float32))
    return tf.Variable(S, name="S", trainable=trainable)

# --- Hadamard utils --- #

def dimensionality_constraints(d):
    """
    Round d up to the next power of two, as required by the Hadamard transform.

    :param d: the original dimension
    :return: the final dimension
    """
    if not is_power_of_two(d):
        # find the smallest power of two greater than d
        d = np.power(2, np.floor(np.log2(d)) + 1)
    return d


def is_power_of_two(input_integer):
    """ Test if an integer is a power of two. 1 is treated as not a power of two here, so dimensions are always padded to at least 2. """
    if input_integer == 1:
        return False
    return input_integer != 0 and ((input_integer & (input_integer - 1)) == 0)


def build_hadamard(n_neurons):
    """ Return the Walsh-Hadamard matrix of order n_neurons (must be a power of two). """
    return scipy.linalg.hadamard(n_neurons)

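# Quick sanity check of the utilities above (illustrative REPL session; the
# values assume MNIST's flattened 28 * 28 = 784 input, as used in this script):
#   >>> is_power_of_two(784)
#   False
#   >>> int(dimensionality_constraints(784))
#   1024
#   >>> build_hadamard(1024).shape
#   (1024, 1024)
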
# --- Representation Layer --- #
def random_features(conv_out, sigma):
@@ -195,58 +93,6 @@ def random_features(conv_out, sigma):
    return h1_final
def fast_food(conv_out, sigma, nbr_stack=1, trainable=False):
    """
    Return a fastfood transform op compatible with the tensorflow graph.

    Implementation largely inspired by https://gist.github.com/dougalsutherland/1a3c70e57dd1f64010ab .

    See:
    "Fastfood - Approximating Kernel Expansions in Loglinear Time" by
    Quoc Le, Tamas Sarlos and Alex Smola.

    :param conv_out: the input of the op
    :param sigma: bandwidth of the Gaussian distribution
    :param nbr_stack: number of fastfood stacks
    :param trainable: whether the diagonal matrices are trainable or not
    :return: the output of the fastfood transform
    """
    with tf.name_scope("fastfood" + "_sigma-" + str(sigma)):
        # flatten the input and pad it up to the next power of two
        init_dim = np.prod([s.value for s in conv_out.shape if s.value is not None])
        final_dim = int(dimensionality_constraints(init_dim))
        padding = final_dim - init_dim
        conv_out2 = tf.reshape(conv_out, [-1, init_dim])
        paddings = tf.constant([[0, 0], [0, padding]])
        conv_out2 = tf.pad(conv_out2, paddings, "CONSTANT")

        # random matrices of the transform V = (1 / (sigma * sqrt(d))) * S H G P H B
        G, G_norm = G_variable((nbr_stack, final_dim), trainable=trainable)
        tf.summary.histogram("weights G", G)
        B = B_variable((nbr_stack, final_dim), trainable=trainable)
        tf.summary.histogram("weights B", B)
        H = H_variable(final_dim)
        tf.summary.histogram("weights H", H)
        P = P_variable(final_dim, nbr_stack)
        tf.summary.histogram("weights P", P)
        S = S_variable((nbr_stack, final_dim), G_norm, trainable=trainable)
        tf.summary.histogram("weights S", S)

        # apply the chain B -> H -> P -> G -> H -> S, reshaping between the
        # per-stack layout (-1, d) and the stacked layout (-1, d * nbr_stack)
        conv_out2 = tf.reshape(conv_out2, (1, -1, 1, final_dim))
        h_ff1 = tf.multiply(conv_out2, B, name="Bx")
        h_ff1 = tf.reshape(h_ff1, (-1, final_dim))
        h_ff2 = tf.matmul(h_ff1, H, name="HBx")
        h_ff2 = tf.reshape(h_ff2, (-1, final_dim * nbr_stack))
        h_ff3 = tf.matmul(h_ff2, P, name="PHBx")
        h_ff4 = tf.multiply(tf.reshape(h_ff3, (-1, final_dim * nbr_stack)), tf.reshape(G, (-1, final_dim * nbr_stack)), name="GPHBx")
        h_ff4 = tf.reshape(h_ff4, (-1, final_dim))
        h_ff5 = tf.matmul(h_ff4, H, name="HGPHBx")
        h_ff6 = tf.scalar_mul((1 / (sigma * np.sqrt(final_dim))), tf.multiply(tf.reshape(h_ff5, (-1, final_dim * nbr_stack)), tf.reshape(S, (-1, final_dim * nbr_stack)), name="SHGPHBx"))

        # random Fourier features: concatenate the cos and sin of the projection
        h_ff7_1 = tf.cos(h_ff6)
        h_ff7_2 = tf.sin(h_ff6)
        h_ff7 = tf.scalar_mul(tf.sqrt(float(1 / final_dim)), tf.concat([h_ff7_1, h_ff7_2], axis=1))
    return h_ff7

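# Note: with the chi-distributed scaling S, the concatenated cos/sin features
# approximate the Gaussian RBF kernel k(x, y) = exp(-||x - y||^2 / (2 * sigma^2))
# in expectation (up to feature normalization); see the Fastfood paper cited in
# the docstring above.
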
def fully_connected(conv_out):
    with tf.name_scope("fc_1"):
        h_pool2_flat = tf.reshape(conv_out, [-1, 7 * 7 * 64])
@@ -264,6 +110,7 @@ def mnist_dims():
    output_dim = int(mnist.train.labels.shape[1])
    return input_dim, output_dim
if __name__ == '__main__':
    SIGMA = 5.0
    print("Sigma = {}".format(SIGMA))
fasfood_layer.py (new file)
import numpy as np
import tensorflow as tf
import scipy.linalg
import scipy.stats
# --- Fast Food Naive --- #
def G_variable(shape, trainable=False):
    """
    Return a Gaussian random matrix converted into a TensorFlow Variable.

    :param shape: The shape of the matrix (number of fastfood stacks (v), dimension of the input space (d))
    :type shape: int or tuple of int (tuple size = 2)
    :return: tf.Variable object containing the matrix, and the 2-norm of each row (np.array of float)
    """
    assert type(shape) == int or (type(shape) == tuple and len(shape) == 2)
    G = np.random.normal(size=shape).astype(np.float32)
    G_norms = np.linalg.norm(G, ord=2, axis=1)  # one norm per fastfood stack
    return tf.Variable(G, name="G", trainable=trainable), G_norms

def B_variable(shape, trainable=False):
    """
    Return a random matrix of -1 and 1 picked uniformly, converted into a TensorFlow Variable.

    :param shape: The shape of the matrix (number of fastfood stacks (v), dimension of the input space (d))
    :type shape: int or tuple of int (tuple size = 2)
    :return: tf.Variable object containing the matrix
    """
    assert type(shape) == int or (type(shape) == tuple and len(shape) == 2)
    B = np.random.choice([-1, 1], size=shape, replace=True).astype(np.float32)
    return tf.Variable(B, name="B", trainable=trainable)

def P_variable(d, nbr_stack):
    """
    Return a square permutation matrix of size (nbr_stack * d) x (nbr_stack * d),
    converted into a TensorFlow Variable.

    :param d: The dimension of the input space of each fastfood stack
    :type d: int
    :param nbr_stack: The number of fastfood stacks (nbr_stack * d is the dimension of the output space)
    :type nbr_stack: int
    :return: tf.Variable object containing the matrix
    """
    idx = np.hstack([(i * d) + np.random.permutation(d) for i in range(nbr_stack)])
    P = np.random.permutation(np.eye(N=nbr_stack * d))[idx].astype(np.float32)
    return tf.Variable(P, name="P", trainable=False)

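# Property check (illustrative, with hypothetical sizes d=4, nbr_stack=2): the
# block-wise index trick above still yields a valid permutation matrix, e.g.
#   >>> idx = np.hstack([(i * 4) + np.random.permutation(4) for i in range(2)])
#   >>> P = np.random.permutation(np.eye(8))[idx]
#   >>> bool((P.sum(axis=0) == 1).all()) and bool((P.sum(axis=1) == 1).all())
#   True
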
def H_variable(d):
    """
    Return a Hadamard matrix converted into a TensorFlow Variable.
    d must be a power of two.

    :param d: The size of the Hadamard matrix (dimension of the input space).
    :type d: int
    :return: tf.Variable object containing the matrix, not trainable
    """
    H = build_hadamard(d).astype(np.float32)
    return tf.Variable(H, name="H", trainable=False)

def S_variable(shape, G_norms, trainable=False):
    """
    Return a scaling matrix of random values drawn from a chi distribution.

    The values are re-scaled by the row norms of the associated Gaussian random
    matrix G, i.e. the norms returned by the `G_variable` function.

    :param shape: The shape of the matrix (number of fastfood stacks (v), dimension of the input space (d))
    :type shape: int or tuple of int (tuple size = 2)
    :param G_norms: The row norms of the associated Gaussian random matrix G.
    :type G_norms: np.array of floats
    :return: tf.Variable object containing the matrix.
    """
    S = np.multiply((1 / G_norms.reshape((-1, 1))), scipy.stats.chi.rvs(shape[1], size=shape).astype(np.float32))
    return tf.Variable(S, name="S", trainable=trainable)

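# For reference only (not used by the layer): a minimal dense NumPy sketch of
# the product computed by fast_food below, for a single stack and d a power of
# two: V = (1 / (sigma * sqrt(d))) * S H G P H B, followed by cos/sin features.
# The helper name and its seed argument are illustrative, not original code.
def _fastfood_dense_reference(x, sigma, seed=0):
    rng = np.random.RandomState(seed)
    d = x.shape[0]  # assumed to be a power of two
    H = scipy.linalg.hadamard(d).astype(np.float64)
    B = rng.choice([-1.0, 1.0], size=d)              # random signs (diagonal of B)
    P = np.eye(d)[rng.permutation(d)]                # permutation matrix
    G = rng.normal(size=d)                           # Gaussian weights (diagonal of G)
    S = scipy.stats.chi.rvs(d, size=d, random_state=rng) / np.linalg.norm(G)
    V = (S[:, None] * H) @ (G[:, None] * P) @ (H * B[None, :]) / (sigma * np.sqrt(d))
    z = V @ x
    # random Fourier features, matching fast_food's concatenated cos/sin output
    return np.concatenate([np.cos(z), np.sin(z)]) / np.sqrt(d)
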
def fast_food(conv_out, sigma, nbr_stack=1, trainable=False):
    """
    Return a fastfood transform op compatible with the tensorflow graph.

    Implementation largely inspired by https://gist.github.com/dougalsutherland/1a3c70e57dd1f64010ab .

    See:
    "Fastfood - Approximating Kernel Expansions in Loglinear Time" by
    Quoc Le, Tamas Sarlos and Alex Smola.

    :param conv_out: the input of the op
    :param sigma: bandwidth of the Gaussian distribution
    :param nbr_stack: number of fastfood stacks
    :param trainable: whether the diagonal matrices are trainable or not
    :return: the output of the fastfood transform
    """
    with tf.name_scope("fastfood" + "_sigma-" + str(sigma)):
        # flatten the input and pad it up to the next power of two
        init_dim = np.prod([s.value for s in conv_out.shape if s.value is not None])
        final_dim = int(dimensionality_constraints(init_dim))
        padding = final_dim - init_dim
        conv_out2 = tf.reshape(conv_out, [-1, init_dim])
        paddings = tf.constant([[0, 0], [0, padding]])
        conv_out2 = tf.pad(conv_out2, paddings, "CONSTANT")

        # random matrices of the transform V = (1 / (sigma * sqrt(d))) * S H G P H B
        G, G_norm = G_variable((nbr_stack, final_dim), trainable=trainable)
        tf.summary.histogram("weights G", G)
        B = B_variable((nbr_stack, final_dim), trainable=trainable)
        tf.summary.histogram("weights B", B)
        H = H_variable(final_dim)
        tf.summary.histogram("weights H", H)
        P = P_variable(final_dim, nbr_stack)
        tf.summary.histogram("weights P", P)
        S = S_variable((nbr_stack, final_dim), G_norm, trainable=trainable)
        tf.summary.histogram("weights S", S)

        # apply the chain B -> H -> P -> G -> H -> S, reshaping between the
        # per-stack layout (-1, d) and the stacked layout (-1, d * nbr_stack)
        conv_out2 = tf.reshape(conv_out2, (1, -1, 1, final_dim))
        h_ff1 = tf.multiply(conv_out2, B, name="Bx")
        h_ff1 = tf.reshape(h_ff1, (-1, final_dim))
        h_ff2 = tf.matmul(h_ff1, H, name="HBx")
        h_ff2 = tf.reshape(h_ff2, (-1, final_dim * nbr_stack))
        h_ff3 = tf.matmul(h_ff2, P, name="PHBx")
        h_ff4 = tf.multiply(tf.reshape(h_ff3, (-1, final_dim * nbr_stack)), tf.reshape(G, (-1, final_dim * nbr_stack)), name="GPHBx")
        h_ff4 = tf.reshape(h_ff4, (-1, final_dim))
        h_ff5 = tf.matmul(h_ff4, H, name="HGPHBx")
        h_ff6 = tf.scalar_mul((1 / (sigma * np.sqrt(final_dim))), tf.multiply(tf.reshape(h_ff5, (-1, final_dim * nbr_stack)), tf.reshape(S, (-1, final_dim * nbr_stack)), name="SHGPHBx"))

        # random Fourier features: concatenate the cos and sin of the projection
        h_ff7_1 = tf.cos(h_ff6)
        h_ff7_2 = tf.sin(h_ff6)
        h_ff7 = tf.scalar_mul(tf.sqrt(float(1 / final_dim)), tf.concat([h_ff7_1, h_ff7_2], axis=1))
    return h_ff7

# --- Hadamard utils --- #

def dimensionality_constraints(d):
    """
    Round d up to the next power of two, as required by the Hadamard transform.

    :param d: the original dimension
    :return: the final dimension
    """
    if not is_power_of_two(d):
        # find the smallest power of two greater than d
        d = np.power(2, np.floor(np.log2(d)) + 1)
    return d


def is_power_of_two(input_integer):
    """ Test if an integer is a power of two. 1 is treated as not a power of two here, so dimensions are always padded to at least 2. """
    if input_integer == 1:
        return False
    return input_integer != 0 and ((input_integer & (input_integer - 1)) == 0)


def build_hadamard(n_neurons):
    """ Return the Walsh-Hadamard matrix of order n_neurons (must be a power of two). """
    return scipy.linalg.hadamard(n_neurons)
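A minimal usage sketch of the new module (assumptions: TF1 graph mode, as used in this repository; the input shape, sigma, and batch values below are hypothetical):

import numpy as np
import tensorflow as tf
from fasfood_layer import fast_food

# build a fastfood feature map for flattened 28x28 inputs (784 is padded to 1024)
x = tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name="x")
phi = fast_food(x, sigma=5.0, nbr_stack=2)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = np.random.rand(8, 28, 28, 1).astype(np.float32)
    # cos and sin blocks x 2 stacks x 1024 dims -> (8, 4096)
    print(sess.run(phi, feed_dict={x: batch}).shape)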