Skip to content
Snippets Groups Projects
Commit b30e8f3f authored by Luc Giffon's avatar Luc Giffon
Browse files

ajout de la version avec cifar10 [WIP]

parent 02ce1166
No related branches found
No related tags found
No related merge requests found
import tensorflow as tf
import numpy as np
from sklearn.preprocessing import LabelBinarizer
import skluc.mldatasets as dataset
from fasfood_layer import fast_food
import matplotlib.pyplot as plt
IMAGE_SIZE = 24
enc = LabelBinarizer()
cifar = dataset.Cifar10Dataset()
cifar_d = cifar.load()
cifar_d = cifar.to_image()
X_train, Y_train = cifar_d["train"]
X_test, Y_test = cifar_d["test"]
X_train = np.array(X_train / 255)
enc.fit(Y_train)
Y_train = np.array(enc.transform(Y_train))
X_test = np.array(X_test / 255)
Y_test = np.array(enc.transform(Y_test))
X_train = X_train.astype(np.float32)
permut = np.random.permutation(X_train.shape[0])
val_size = 5000
X_val = X_train[permut[:val_size]]
X_train = X_train[permut[val_size:]]
Y_val = Y_train[permut[:val_size]]
Y_train = Y_train[permut[val_size:]]
X_test = X_test.astype(np.float32)
Y_train = Y_train.astype(np.float32)
Y_test = Y_test.astype(np.float32)
def distorded_inputs(image_tensor):
height = IMAGE_SIZE
width = IMAGE_SIZE
distorted_image = tf.random_crop(image_tensor, [height, width, 3])
distorted_image = tf.image.random_flip_left_right(distorted_image)
distorted_image = tf.image.random_brightness(distorted_image,
max_delta=63)
distorted_image = tf.image.random_contrast(distorted_image,
lower=0.2, upper=1.8)
float_image = tf.image.per_image_standardization(distorted_image)
return float_image
if __name__ == '__main__':
SIGMA = 5.0
print("Sigma = {}".format(SIGMA))
with tf.Graph().as_default():
output_dim = Y_train.shape[1]
input_dim = X_train.shape[1:]
x_image = tf.placeholder(tf.float32, shape=[None, *input_dim], name="x_image")
y_ = tf.placeholder(tf.float32, shape=[None, output_dim], name="labels")
tf.summary.image("cifarimage", x_image, max_outputs=10)
dist_x_images = distorded_inputs(x_image)
tf.summary.image("cifarimagedistorded", dist_x_images, max_outputs=10)
# out = fast_food(x_image, SIGMA)
merged_summary = tf.summary.merge_all()
init = tf.global_variables_initializer()
sess = tf.Session()
summary_writer = tf.summary.FileWriter("cifar")
feed_dict = {x_image: X_train[:10], y_: Y_train[:10]}
summary = sess.run([merged_summary], feed_dict=feed_dict)
summary_writer.add_summary(summary[0])
summary_writer.close()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment