Skip to content
Snippets Groups Projects
Commit efded584 authored by Luc Giffon's avatar Luc Giffon
Browse files

First commit: implementation with tensorflow of deepfriedconvnet with...

First commit: implementation with tensorflow of deepfriedconvnet with non-adaptative fastfood and no fht - only one stack of fastfood - mnist dataset
parents
No related branches found
No related tags found
No related merge requests found
.idea
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
.static_storage/
.media/
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
"""
Convolutional Neural Netwok implementation in tensorflow whith multiple representations possible after the convolution:
- Fully connected layer
- Random Fourier Features layer
- Fast Food layer where Fast Hadamard Transform has been replaced by dot product with Hadamard matrix.
"""
import tensorflow as tf
import numpy as np
import scipy.linalg
import scipy.stats
import time as t
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# --- Usual functions --- #
def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev=0.1)
return tf.Variable(initial, name="weights")
def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial, name="biases")
def conv2d(x, W):
return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1], padding='SAME')
def convolution(input):
with tf.name_scope("conv_pool_1"):
# 32 is the number of filter we'll use. e.g. the number of different
# shapes this layer is able to recognize
W_conv1 = weight_variable([5, 5, 1, 32])
tf.summary.histogram("weights conv1", W_conv1)
b_conv1 = bias_variable([32])
tf.summary.histogram("biases conv1", b_conv1)
# -1 is here to keep the total size constant (784)
h_conv1 = tf.nn.relu(conv2d(input, W_conv1) + b_conv1)
tf.summary.histogram("act conv1", h_conv1)
h_pool1 = max_pool_2x2(h_conv1)
with tf.name_scope("conv_pool_2"):
W_conv2 = weight_variable([5, 5, 32, 64])
tf.summary.histogram("weights conv2", W_conv2)
b_conv2 = bias_variable([64])
tf.summary.histogram("biases conv2", b_conv2)
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
tf.summary.histogram("act conv2", h_conv2)
h_pool2 = max_pool_2x2(h_conv2)
return h_pool2
# --- Random Fourier Features --- #
def random_variable(shape, sigma):
W = np.random.normal(size=shape, scale=sigma).astype(np.float32)
return tf.Variable(W, name="random_Weights", trainable=False)
def random_biases(shape):
b = np.random.uniform(0, 2 * np.pi, size=shape).astype(np.float32)
return tf.Variable(b, name="random_biase", trainable=False)
# --- Fast Food Naive --- #
def G_variable(d, diag=True):
"""
Return a Gaussian Random diagonal matrix converted into Tensorflow Variable.
:param d: The size of the diagonal
:type d: int
:return: tf.Variable object containing the diagonal and not trainable, The frobenius norm of this diagonal (float)
"""
if diag:
G = np.diag(np.random.normal(size=d)).astype(np.float32)
G_norm = np.linalg.norm(G, ord='fro')
else:
G = np.random.normal(size=d).astype(np.float32)
G_norm = np.linalg.norm(G, ord=2)
print("Norm of G is: {}".format(G_norm))
return tf.Variable(G, name="G", trainable=False), G_norm
def B_variable(d, diag=True):
"""
Return a random diagonal matrix of -1 and 1 picked uniformly into Tensorflow Variable.
:param d: The size of the diagonal
:type d: int
:return: tf.Variable object containing the diagonal and not trainable
"""
if diag:
B = np.diag(np.random.choice([-1, 1], size=d, replace=True)).astype(np.float32)
else:
B = np.random.choice([-1, 1], size=d, replace=True).astype(np.float32)
return tf.Variable(B, name="B", trainable=False)
def P_variable(d):
"""
Return a permutation matrix into Tensorflow Variable.
:param d: The size of the diagonal
:type d: int
:return: tf.Variable object containing the diagonal and not trainable
"""
idx = np.random.permutation(d)
P = np.random.permutation(np.eye(N=d))[idx].astype(np.float32)
return tf.Variable(P, name="P", trainable=False)
def H_variable(d):
"""
Return an Hadamard matrix into Tensorflow Variable.
d must be a power of two.
:param d: The size of the Hadamard matrix.
:type d: int
:return: tf.Variable object containing the diagonal and not trainable
"""
H = build_hadamard(d).astype(np.float32)
return tf.Variable(H, name="H", trainable=False)
def S_variable(d, G_norm, diag=True):
"""
Return a scaling diagonal matrix of random values picked from a chi distribution.
The values are re-scaled using the norm of the Gaussian Diagonal random matrix G.
:param d: The size of the diagonal.
:type d: int
:param G_norm: The norm of the Gaussian Diagonal random matrix G.
:type G_norm: float
:return: tf.Variable object containing the diagonal and not trainable.
"""
if diag:
S = np.diag((1 / G_norm) * scipy.stats.chi.rvs(d, size=d)).astype(np.float32)
else:
S = (1 / G_norm) * scipy.stats.chi.rvs(d, size=d).astype(np.float32)
return tf.Variable(S, name="S", trainable=False)
# --- Hadamard utils --- #
def dimensionality_constraints(d):
"""
Enforce d to be a power of 2
:param d: the original dimension
:return: the final dimension
"""
if not is_power_of_two(d):
# find d that fulfills 2^l
d = np.power(2, np.floor(np.log2(d)) + 1)
return d
def is_power_of_two(input_integer):
""" Test if an integer is a power of two. """
if input_integer == 1:
return False
return input_integer != 0 and ((input_integer & (input_integer - 1)) == 0)
def build_hadamard(n_neurons):
return scipy.linalg.hadamard(n_neurons)
# --- Representation Layer --- #
def random_features(conv_out, sigma):
with tf.name_scope("random_features"):
init_dim = np.prod([s.value for s in conv_out.shape if s.value is not None])
conv_out2 = tf.reshape(conv_out, [-1, init_dim])
W = random_variable((init_dim, init_dim), sigma)
b = random_biases(init_dim)
h1 = tf.matmul(conv_out2, W, name="Wx") + b
h1_cos = tf.cos(h1)
h1_final = tf.scalar_mul(np.sqrt(2.0 / init_dim).astype(np.float32), h1_cos)
return h1_final
def fast_food(conv_out, sigma, diag=True, trainable=False):
# todo use te trainable parameter
with tf.name_scope("fastfood"):
init_dim = np.prod([s.value for s in conv_out.shape if s.value is not None])
final_dim = int(dimensionality_constraints(init_dim))
padding = final_dim - init_dim
conv_out2 = tf.reshape(conv_out, [-1, init_dim])
paddings = tf.constant([[0, 0], [0, padding]])
conv_out2 = tf.pad(conv_out2, paddings, "CONSTANT")
G, G_norm = G_variable(final_dim, diag=diag)
tf.summary.histogram("weights G", G)
B = B_variable(final_dim, diag=diag)
tf.summary.histogram("weights B", B)
H = H_variable(final_dim)
tf.summary.histogram("weights H", H)
P = P_variable(final_dim)
tf.summary.histogram("weights P", P)
S = S_variable(final_dim, G_norm, diag=diag)
tf.summary.histogram("weights S", S)
if diag:
h_ff1 = tf.matmul(conv_out2, B, name="Bx")
h_ff2 = tf.matmul(h_ff1, H, name="HBx")
h_ff3 = tf.matmul(h_ff2, P, name="PHBx")
h_ff4 = tf.matmul(h_ff3, G, name="GPHBx")
h_ff5 = tf.matmul(h_ff4, H, name="HGPHBx")
h_ff6 = tf.scalar_mul((1/(sigma * np.sqrt(final_dim))), tf.matmul(h_ff5, S, name="SHGPHBx"))
h_ff7_1 = tf.cos(h_ff6)
h_ff7_2 = tf.sin(h_ff6)
h_ff7 = tf.scalar_mul(tf.sqrt(float(1 / final_dim)), tf.concat([h_ff7_1, h_ff7_2], axis=1))
else:
h_ff1 = tf.multiply(conv_out2, B, name="Bx")
h_ff2 = tf.matmul(h_ff1, H, name="HBx")
h_ff3 = tf.matmul(h_ff2, P, name="PHBx")
h_ff4 = tf.multiply(h_ff3, G, name="GPHBx")
h_ff5 = tf.matmul(h_ff4, H, name="HGPHBx")
h_ff6 = tf.scalar_mul((1/(sigma * np.sqrt(final_dim))), tf.multiply(h_ff5, S, name="SHGPHBx"))
h_ff7_1 = tf.cos(h_ff6)
h_ff7_2 = tf.sin(h_ff6)
h_ff7 = tf.scalar_mul(tf.sqrt(float(1 / final_dim)), tf.concat([h_ff7_1, h_ff7_2], axis=1))
return h_ff7
def fully_connected(conv_out):
with tf.name_scope("fc_1"):
h_pool2_flat = tf.reshape(conv_out, [-1, 7 * 7 * 64])
W_fc1 = weight_variable([7 * 7 * 64, 4096*2])
b_fc1 = bias_variable([4096*2])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
tf.summary.histogram("weights", W_fc1)
tf.summary.histogram("biases", b_fc1)
return h_fc1
if __name__ == '__main__':
SIGMA = 100.0
print("Sigma = {}".format(SIGMA))
with tf.Graph().as_default():
input_dim = int(mnist.train.images.shape[1])
output_dim = int(mnist.train.labels.shape[1])
side_size = int(np.sqrt(input_dim))
x = tf.placeholder(tf.float32, shape=[None, input_dim], name="x")
y_ = tf.placeholder(tf.float32, shape=[None, output_dim], name="labels")
x_image = tf.reshape(x, [-1, side_size, side_size, 1])
tf.summary.image("digit", x_image, max_outputs=3)
# Representation layer
h_conv = convolution(x_image)
# h_conv = x
# out_fc = fully_connected(h_conv) # 95% accuracy
# out_fc = fast_food(h_conv, SIGMA) # 83% accuracy (conv) | 56% accuracy (noconv)
# out_fc = fast_food(h_conv, SIGMA, diag=False) # 84% accuracy (conv) | 59% accuracy (noconv)
out_fc = random_features(h_conv, SIGMA) # 82% accuracy (conv) | 47% accuracy (noconv)
# classification
with tf.name_scope("fc_2"):
keep_prob = tf.placeholder(tf.float32, name="keep_prob")
h_fc1_drop = tf.nn.dropout(out_fc, keep_prob)
dim = np.prod([s.value for s in h_fc1_drop.shape if s.value is not None])
W_fc2 = weight_variable([dim, 10])
b_fc2 = bias_variable([10])
tf.summary.histogram("weights", W_fc2)
tf.summary.histogram("biases", b_fc2)
y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
# calcul de la loss
with tf.name_scope("xent"):
cross_entropy = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv, name="xentropy"),
name="xentropy_mean")
tf.summary.scalar('loss-xent', cross_entropy)
# calcul du gradient
with tf.name_scope("train"):
global_step = tf.Variable(0, name="global_step", trainable=False)
train_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(cross_entropy, global_step=global_step)
# calcul de l'accuracy
with tf.name_scope("accuracy"):
predictions = tf.argmax(y_conv, 1)
correct_prediction = tf.equal(predictions, tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar("accuracy", accuracy)
merged_summary = tf.summary.merge_all()
init = tf.global_variables_initializer()
# Create a session for running Ops on the Graph.
sess = tf.Session()
# Instantiate a SummaryWriter to output summaries and the Graph.
summary_writer = tf.summary.FileWriter("results_deepfried")
summary_writer.add_graph(sess.graph)
# Initialize all Variable objects
sess.run(init)
# actual learning
started = t.time()
for i in range(500):
batch = mnist.train.next_batch(50)
feed_dict = {x: batch[0], y_: batch[1], keep_prob: 0.5}
# le _ est pour capturer le retour de "train_optimizer" qu'il faut appeler
# pour calculer le gradient mais dont l'output ne nous interesse pas
_, loss = sess.run([train_optimizer, cross_entropy], feed_dict=feed_dict)
if i % 100 == 0:
print('step {}, loss {} (with dropout)'.format(i, loss))
summary_str = sess.run(merged_summary, feed_dict=feed_dict)
summary_writer.add_summary(summary_str, i)
stoped = t.time()
accuracy, preds = sess.run([accuracy, predictions], feed_dict={
x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})
print('test accuracy %g' % accuracy)
np.set_printoptions(threshold=np.nan)
print("Prediction sample: " + str(preds[:50]))
print("Actual values: " + str(np.argmax(mnist.test.labels[:50], 1)))
print("Elapsed time: %.4f s" % (stoped - started))
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment