Commit f89c3394 authored by Luc Giffon

create tensorflow directory + kernel approximation layers for tensorflow

parent 8c5a8b04
from skluc.fastfood_approximation import Fastfood as Fastfood
from skluc.tensorflow.kernel_approximation.fastfood_layer import fastfood_layer as fastfood_layer
from skluc.tensorflow.kernel_approximation.nystrom_approx import nystrom_layer as nystrom_layer
import numpy as np
import tensorflow as tf
import scipy.linalg
import scipy.stats
# --- Fast Food Naive --- #
def G_variable(shape, trainable=False):
"""
Return a Gaussian random matrix converted into a Tensorflow Variable.
:param shape: The shape of the matrix (number of fastfood stacks (v), dimension of the input space (d))
:type shape: int or tuple of int (tuple size = 2)
:return: tf.Variable object containing the matrix, and the 2-norm of each row (np.array of float)
"""
assert type(shape) == int or (type(shape) == tuple and len(shape) == 2)
G = np.random.normal(size=shape).astype(np.float32)
G_norms = np.linalg.norm(G, ord=2, axis=1)
return tf.Variable(G, name="G", trainable=trainable), G_norms
def B_variable(shape, trainable=False):
"""
Return a random matrix of -1 and 1 picked uniformly and converted into Tensorflow Variable.
:param shape: The shape of the matrix (number of fastfood stacks (v), dimension of the input space (d))
:type shape: int or tuple of int (tuple size = 2)
:return: tf.Variable object containing the matrix
"""
assert type(shape) == int or (type(shape) == tuple and len(shape) == 2)
B = np.random.choice([-1, 1], size=shape, replace=True).astype(np.float32)
return tf.Variable(B, name="B", trainable=trainable)
def P_variable(d, nbr_stack):
"""
Return a permutation matrix converted into Tensorflow Variable.
:param d: The width of the matrix (dimension of the input space)
:type d: int
:param nbr_stack: The height of the matrix (nbr_stack x d is the dimension of the output space)
:type nbr_stack: int
:return: tf.Variable object containing the matrix
"""
idx = np.hstack([(i * d) + np.random.permutation(d) for i in range(nbr_stack)])
P = np.random.permutation(np.eye(N=nbr_stack * d))[idx].astype(np.float32)
return tf.Variable(P, name="P", trainable=False)
def H_variable(d):
"""
Return an Hadamard matrix converted into Tensorflow Variable.
d must be a power of two.
:param d: The size of the Hadamard matrix (dimension of the input space).
:type d: int
:return: tf.Variable object containing the matrix (not trainable)
"""
H = build_hadamard(d).astype(np.float32)
return tf.Variable(H, name="H", trainable=False)
def S_variable(shape, G_norms, trainable=False):
"""
Return a scaling matrix of random values picked from a chi distribution.
The values are re-scaled using the norm of the associated Gaussian random matrix G. The associated Gaussian
vectors are the ones generated by the `G_variable` function.
:param shape: The shape of the matrix (number of fastfood stacks (v), dimension of the input space (d))
:type shape: int or tuple of int (tuple size = 2)
:param G_norms: The norms of the associated Gaussian random matrices G.
:type G_norms: np.array of floats
:return: tf.Variable object containing the matrix.
"""
S = np.multiply((1 / G_norms.reshape((-1, 1))), scipy.stats.chi.rvs(shape[1], size=shape).astype(np.float32))
return tf.Variable(S, name="S", trainable=trainable)
def fastfood_layer(conv_out, sigma, nbr_stack=1, trainable=False):
"""
Return a fastfood transform op compatible with tensorflow graph.
Implementation largely inspired by https://gist.github.com/dougalsutherland/1a3c70e57dd1f64010ab .
See:
"Fastfood - Approximating Kernel Expansions in Loglinear Time" by
Quoc Le, Tamas Sarlos and Alex Smola.
:param conv_out: the input of the op
:param sigma: bandwidth of the Gaussian distribution
:param nbr_stack: number of fastfood stacks
:param trainable: whether the random matrices (B, G, S) are trainable
:return: the output of the fastfood transform
"""
with tf.name_scope("fastfood" + "_sigma-"+str(sigma)):
init_dim = np.prod([s.value for s in conv_out.shape if s.value is not None])
final_dim = int(dimensionality_constraints(init_dim))
padding = final_dim - init_dim
conv_out2 = tf.reshape(conv_out, [-1, init_dim])
paddings = tf.constant([[0, 0], [0, padding]])
conv_out2 = tf.pad(conv_out2, paddings, "CONSTANT")
G, G_norm = G_variable((nbr_stack, final_dim), trainable=trainable)
tf.summary.histogram("weights_G", G)
B = B_variable((nbr_stack, final_dim), trainable=trainable)
tf.summary.histogram("weights_B", B)
H = H_variable(final_dim)
tf.summary.histogram("weights_H", H)
P = P_variable(final_dim, nbr_stack)
tf.summary.histogram("weights_P", P)
S = S_variable((nbr_stack, final_dim), G_norm, trainable=trainable)
tf.summary.histogram("weights_S", S)
conv_out2 = tf.reshape(conv_out2, (1, -1, 1, final_dim))
h_ff1 = tf.multiply(conv_out2, B, name="Bx")
h_ff1 = tf.reshape(h_ff1, (-1, final_dim))
h_ff2 = tf.matmul(h_ff1, H, name="HBx")
h_ff2 = tf.reshape(h_ff2, (-1, final_dim * nbr_stack))
h_ff3 = tf.matmul(h_ff2, P, name="PHBx")
h_ff4 = tf.multiply(tf.reshape(h_ff3, (-1, final_dim * nbr_stack)), tf.reshape(G, (-1, final_dim * nbr_stack)), name="GPHBx")
h_ff4 = tf.reshape(h_ff4, (-1, final_dim))
h_ff5 = tf.matmul(h_ff4, H, name="HGPHBx")
h_ff6 = tf.scalar_mul((1/(sigma * np.sqrt(final_dim))), tf.multiply(tf.reshape(h_ff5, (-1, final_dim * nbr_stack)), tf.reshape(S, (-1, final_dim * nbr_stack)), name="SHGPHBx"))
h_ff7_1 = tf.cos(h_ff6)
h_ff7_2 = tf.sin(h_ff6)
h_ff7 = tf.scalar_mul(tf.sqrt(float(1 / final_dim)), tf.concat([h_ff7_1, h_ff7_2], axis=1))
return h_ff7
# --- Hadamard utils --- #
def dimensionality_constraints(d):
"""
Enforce d to be a power of 2
:param d: the original dimension
:return: the final dimension
"""
if not is_power_of_two(d):
# find d that fulfills 2^l
d = np.power(2, np.floor(np.log2(d)) + 1)
return d
def is_power_of_two(input_integer):
""" Test if an integer is a power of two. """
if input_integer == 1:
return False
return input_integer != 0 and ((input_integer & (input_integer - 1)) == 0)
def build_hadamard(n_neurons):
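    """Return the Hadamard matrix of order n_neurons (n_neurons must be a power of 2)."""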
return scipy.linalg.hadamard(n_neurons)
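

if __name__ == "__main__":
    # Minimal usage sketch, assuming an arbitrary 784-dimensional input and
    # illustrative values for sigma and nbr_stack (not taken from the module).
    x = tf.placeholder(tf.float32, shape=[None, 784], name="x")
    out = fastfood_layer(x, sigma=5.0, nbr_stack=2)
    # 784 is not a power of two, so inputs are padded up to 1024
    # (dimensionality_constraints(784) == 1024); the output concatenates
    # cos and sin features, giving 2 * 1024 * nbr_stack columns.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        feats = sess.run(out, feed_dict={x: np.random.rand(8, 784).astype(np.float32)})
        print(feats.shape)  # expected: (8, 4096)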
"""
Convnet with nystrom approximation of the feature map.
"""
import time as t
import tensorflow as tf
import numpy as np
import skluc.mldatasets as dataset
from skluc.tensorflow.utils import get_next_batch, classification_mnist, convolution_mnist, tf_rbf_kernel
tf.logging.set_verbosity(tf.logging.ERROR)
val_size = 5000
mnist = dataset.MnistDataset(validation_size=val_size)
mnist.load()
mnist.to_one_hot()
mnist.normalize()
mnist.data_astype(np.float32)
mnist.labels_astype(np.float32)
X_train, Y_train = mnist.train
X_val, Y_val = mnist.validation
X_test, Y_test = mnist.test
def nystrom_layer(input_x, input_subsample, gamma, output_dim):
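    """
    Compute Nystrom features of input_x with respect to a fixed subsample.
    The RBF kernel vector k(x, z_i) between each input and the subsample points
    is multiplied by a trainable matrix W; this is a simplification of the exact
    Nystrom feature map, whose normalization factor appears in the commented-out
    sklearn-style formulation below.
    :param input_x: input tensor (batch of convolutional representations)
    :param input_subsample: tensor holding the Nystrom subsample
    :param gamma: gamma parameter of the RBF kernel
    :param output_dim: output dimension of the layer
    :return: the output tensor of the Nystrom layer
    """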
nystrom_sample_size = input_subsample.shape[0]
with tf.name_scope("nystrom"):
init_dim = np.prod([s.value for s in input_x.shape[1:] if s.value is not None])
h_conv_flat = tf.reshape(input_x, [-1, init_dim])
h_conv_nystrom_subsample_flat = tf.reshape(input_subsample, [nystrom_sample_size, init_dim])
with tf.name_scope("kernel_vec"):
kernel_vector = tf_rbf_kernel(h_conv_flat, h_conv_nystrom_subsample_flat, gamma=gamma)
# this is the initial formulation given by sklearn
# D = tf.get_variable("D", [nystrom_sample_size,], initializer=tf.random_normal_initializer(stddev=0.1))
# V = tf.get_variable("V", [nystrom_sample_size, nystrom_sample_size],
# initializer=tf.random_normal_initializer(stddev=0.1))
# out_fc = tf.matmul(kernel_vector, tf.matmul(tf.multiply(D, V), tf.transpose(V)))
# this is simpler
W = tf.get_variable("W", [nystrom_sample_size, output_dim],
initializer=tf.random_normal_initializer(stddev=0.1))
out_fc = tf.matmul(kernel_vector, W)
return out_fc
def main():
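    """Train the convolutional network with a Nystrom layer on MNIST and report its accuracy."""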
NYSTROM_SAMPLE_SIZE = 100
# draw the Nystrom subsample uniformly from the whole training set
X_nystrom = X_train[np.random.choice(X_train.shape[0], NYSTROM_SAMPLE_SIZE, replace=False)]
GAMMA = 0.001
print("Gamma = {}".format(GAMMA))
with tf.Graph().as_default():
input_dim, output_dim = X_train.shape[1], Y_train.shape[1]
x = tf.placeholder(tf.float32, shape=[None, input_dim], name="x")
x_nystrom = tf.Variable(X_nystrom, name="nystrom_subsample", trainable=False)
y_ = tf.placeholder(tf.float32, shape=[None, output_dim], name="labels")
# side size is width or height of the images
side_size = int(np.sqrt(input_dim))
x_image = tf.reshape(x, [-1, side_size, side_size, 1])
x_nystrom_image = tf.reshape(x_nystrom, [NYSTROM_SAMPLE_SIZE, side_size, side_size, 1])
tf.summary.image("digit", x_image, max_outputs=3)
# Representation layer
with tf.variable_scope("convolution_mnist") as scope_conv_mnist:
h_conv = convolution_mnist(x_image)
scope_conv_mnist.reuse_variables()
h_conv_nystrom_subsample = convolution_mnist(x_nystrom_image, trainable=False)
out_fc = nystrom_layer(h_conv, h_conv_nystrom_subsample, GAMMA, NYSTROM_SAMPLE_SIZE)
y_conv, keep_prob = classification_mnist(out_fc, output_dim=output_dim)
# # compute the loss
with tf.name_scope("xent"):
cross_entropy = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv, name="xentropy"),
name="xentropy_mean")
tf.summary.scalar('loss-xent', cross_entropy)
# # compute the gradient
with tf.name_scope("train"):
global_step = tf.Variable(0, name="global_step", trainable=False)
train_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(cross_entropy, global_step=global_step)
# # compute the accuracy
with tf.name_scope("accuracy"):
predictions = tf.argmax(y_conv, 1)
correct_prediction = tf.equal(predictions, tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar("accuracy", accuracy)
merged_summary = tf.summary.merge_all()
init = tf.global_variables_initializer()
# Create a session for running Ops on the Graph.
sess = tf.Session()
# Instantiate a SummaryWriter to output summaries and the Graph.
summary_writer = tf.summary.FileWriter("results_deepfried_stacked")
summary_writer.add_graph(sess.graph)
# Initialize all Variable objects
sess.run(init)
# actual learning
started = t.time()
feed_dict_val = {x: X_val, y_: Y_val, keep_prob: 1.0}
for i in range(10000):
X_batch = get_next_batch(X_train, i, 64)
Y_batch = get_next_batch(Y_train, i, 64)
feed_dict = {x: X_batch, y_: Y_batch, keep_prob: 0.5}
# the _ captures the return value of "train_optimizer", which must be run to
# compute the gradient but whose output is not of interest here
_, loss, y_result, x_exp = sess.run([train_optimizer, cross_entropy, y_conv, x_image], feed_dict=feed_dict)
if i % 100 == 0:
print('step {}, loss {} (with dropout)'.format(i, loss))
r_accuracy = sess.run([accuracy], feed_dict=feed_dict_val)
print("accuracy: {} on validation set (without dropout).".format(r_accuracy))
summary_str = sess.run(merged_summary, feed_dict=feed_dict)
summary_writer.add_summary(summary_str, i)
stopped = t.time()
accuracy, preds = sess.run([accuracy, predictions], feed_dict={
x: X_test, y_: Y_test, keep_prob: 1.0})
print('test accuracy %g' % accuracy)
np.set_printoptions(threshold=np.nan)
print("Prediction sample: " + str(preds[:50]))
print("Actual values: " + str(np.argmax(Y_test[:50], axis=1)))
print("Elapsed time: %.4f s" % (stoped - started))
if __name__ == '__main__':
main()
@@ -76,16 +76,16 @@ def batch_generator(X_train, Y_train, batch_size, circle=False):
        j += 1
def fully_connected(input_op, output_dim, act=None, variable_scope="fully_connected"):
    """
    Implement a fully connected layer of size output_dim.
    :param input_op:
    :param output_dim:
    :param act:
    :param variable_scope:
    :return:
    """
    with tf.variable_scope(variable_scope):
        init_dim = np.prod([s.value for s in input_op.shape if s.value is not None])
        h_pool2_flat = tf.reshape(input_op, [-1, init_dim])
        W_fc1 = tf.get_variable("weights", [init_dim, output_dim], initializer=tf.random_normal_initializer(stddev=0.1))
@@ -100,16 +100,17 @@ def fully_connected(input_op, output_dim, act=None, name_scope="fully_connected"
        return result
def conv_relu_pool(input_, kernel_shape, bias_shape, pool_size=2, trainable=True, variable_scope="convolution"):
    """
    Generic function for defining a convolutional layer.
    :param input_: The input tensor to be convolved
    :param kernel_shape: The shape of the kernels/filters
    :param bias_shape: The shape of the bias
    :param variable_scope: Name of the enclosing variable scope
    :return: The output tensor of the convolution
    """
    with tf.variable_scope(variable_scope):
        weights = tf.get_variable("weights", kernel_shape, initializer=tf.random_normal_initializer(stddev=0.1), trainable=trainable)
        biases = tf.get_variable("biases", bias_shape, initializer=tf.constant_initializer(0.0), trainable=trainable)
        tf.summary.histogram("weights", weights)
@@ -134,6 +135,7 @@ def tf_op(d_feed_dict, ops):
    init = tf.global_variables_initializer()
    sess.run([init])
    sess.run(ops, feed_dict=d_feed_dict)
    sess.close()

def convolution_mnist(input_, trainable=True):
@@ -253,6 +255,18 @@ def random_features(conv_out, sigma):
    return h1_final
def tf_rbf_kernel(X, Y, gamma):
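    """
    Compute the RBF (Gaussian) kernel matrix K[i, j] = exp(-gamma * ||X_i - Y_j||^2),
    using the expansion ||x - y||^2 = ||x||^2 - 2 <x, y> + ||y||^2.
    :param X: tensor of shape (n, d)
    :param Y: tensor of shape (m, d)
    :param gamma: kernel coefficient gamma of exp(-gamma * ||x - y||^2)
    :return: kernel matrix of shape (n, m)
    """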
r1 = tf.reduce_sum(X * X, axis=1)
r1 = tf.reshape(r1, [-1, 1])
r2 = tf.reduce_sum(Y * Y, axis=1)
r2 = tf.reshape(r2, [1, -1])
K = tf.matmul(X, tf.transpose(Y))
K = r1 - 2 * K + r2
K *= -gamma
K = tf.exp(K)
return K
if __name__ == "__main__": if __name__ == "__main__":
X_train = np.arange(12) X_train = np.arange(12)
Y_train = np.array(["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p"]) Y_train = np.array(["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p"])
......