Skip to content
Snippets Groups Projects
Commit 8c5a8b04 authored by Luc Giffon's avatar Luc Giffon
Browse files

example implementation of a deep nn on cifar dataset without distortion

parent 743e916d
Branches
Tags
No related merge requests found
"""
Convolutional Neural Netwok implementation in tensorflow whith fully connected layers.
The neural network is ran against the mnist dataset and we can see an example of distortion of input in the case
where the input comes from memory.
"""
import tensorflow as tf
import numpy as np
import skluc.mldatasets as dataset
from skluc.neural_networks import get_next_batch, inference_cifar10, batch_generator
import matplotlib.pyplot as plt
tf.logging.set_verbosity(tf.logging.ERROR)
import time as t
# Preparing the dataset #########################
val_size = 5000
cifar10 = dataset.Cifar10Dataset(validation_size=val_size)
cifar10.load()
cifar10.to_image()
cifar10.to_one_hot()
cifar10.normalize()
cifar10.data_astype(np.float32)
cifar10.labels_astype(np.float32)
X_train, Y_train = cifar10.train
X_val, Y_val = cifar10.validation
X_test, Y_test = cifar10.test
plt.imshow(X_train[0])
plt.show()
#################################################
def main():
with tf.Graph().as_default():
input_dim, output_dim = X_train.shape[1], Y_train.shape[1]
x = tf.placeholder(tf.float32, shape=[None, cifar10.HEIGHT, cifar10.WIDTH, cifar10.DEPTH], name="x")
y_ = tf.placeholder(tf.float32, shape=[None, output_dim], name="labels")
# side size is width or height of the images
x_image = x
tf.summary.image("digit", x_image, max_outputs=3)
# this is how we apply distortion but it is not used afterward
x_image_distorded = tf.image.random_brightness(x_image, max_delta=30)
tf.summary.image("digit_distorded", x_image_distorded, max_outputs=3)
y_conv, keep_prob = inference_cifar10(x_image, output_dim)
# calcul de la loss
with tf.name_scope("xent"):
cross_entropy = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv, name="xentropy"),
name="xentropy_mean")
tf.summary.scalar('loss-xent', cross_entropy)
# calcul du gradient
with tf.name_scope("train"):
global_step = tf.Variable(0, name="global_step", trainable=False)
train_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(cross_entropy, global_step=global_step)
# calcul de l'accuracy
with tf.name_scope("accuracy"):
predictions = tf.argmax(y_conv, 1)
correct_prediction = tf.equal(predictions, tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar("accuracy", accuracy)
merged_summary = tf.summary.merge_all()
init = tf.global_variables_initializer()
# Create a session for running Ops on the Graph.
sess = tf.Session()
# Instantiate a SummaryWriter to output summaries and the Graph.
summary_writer = tf.summary.FileWriter("results_fc_distorded")
summary_writer.add_graph(sess.graph)
# Initialize all Variable objects
sess.run(init)
# actual learning
started = t.time()
feed_dict_val = {x: X_val, y_: Y_val, keep_prob: 1.0}
for i in range(50):
j = 0
for X_batch, Y_batch in batch_generator(X_train, Y_train, 64, circle=True):
feed_dict = {x: X_batch, y_: Y_batch, keep_prob: 0.5}
# le _ est pour capturer le retour de "train_optimizer" qu'il faut appeler
# pour calculer le gradient mais dont l'output ne nous interesse pas
_, loss = sess.run([train_optimizer, cross_entropy], feed_dict=feed_dict)
if j % 100 == 0:
print('step {}, loss {} (with dropout)'.format(i, loss))
r_accuracy = sess.run([accuracy], feed_dict=feed_dict_val)
print("accuracy: {} on validation set (without dropout).".format(r_accuracy))
summary_str = sess.run(merged_summary, feed_dict=feed_dict)
summary_writer.add_summary(summary_str, i)
j += 1
stoped = t.time()
accuracy, preds = sess.run([accuracy, predictions], feed_dict={
x: X_test, y_: Y_test, keep_prob: 1.0})
print('test accuracy %g' % accuracy)
np.set_printoptions(threshold=np.nan)
print("Prediction sample: " + str(preds[:50]))
print("Actual values: " + str(np.argmax(Y_test[:50], axis=1)))
print("Elapsed time: %.4f s" % (stoped - started))
if __name__ == '__main__':
main()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment