Skip to content
Snippets Groups Projects
Commit 9dd0f2c9 authored by Luc Giffon's avatar Luc Giffon
Browse files

move get_next_batch to skluc.ml_datasets

parent 6229eff5
No related branches found
No related tags found
No related merge requests found
......@@ -13,7 +13,7 @@ Zichao Yang, Marcin Moczulski, Misha Denil, Nando de Freitas, Alex Smola, Le Son
import tensorflow as tf
import numpy as np
import skluc.mldatasets as dataset
from skluc.neural_networks import bias_variable, weight_variable, conv2d, max_pool_2x2
from skluc.neural_networks import bias_variable, weight_variable, conv2d, max_pool_2x2, get_next_batch
tf.logging.set_verbosity(tf.logging.ERROR)
......@@ -111,27 +111,6 @@ def fully_connected(conv_out):
return h_fc1
def get_next_batch(full_set, batch_nbr, batch_size):
"""
Return the next batch of a dataset.
This function assumes that all the previous batches of this dataset have been taken with the same size.
:param full_set: the full dataset from which the batch will be taken
:param batch_nbr: the number of the batch
:param batch_size: the size of the batch
:return:
"""
index_start = (batch_nbr * batch_size) % full_set.shape[0]
index_stop = ((batch_nbr + 1) * batch_size) % full_set.shape[0]
if index_stop > index_start:
return full_set[index_start:index_stop]
else:
part1 = full_set[index_start:]
part2 = full_set[:index_stop]
return np.vstack((part1, part2))
if __name__ == '__main__':
SIGMA = 5.0
print("Sigma = {}".format(SIGMA))
......@@ -150,11 +129,11 @@ if __name__ == '__main__':
# Representation layer
h_conv = convolution_mnist(x_image)
# h_conv = x
out_fc = fully_connected(h_conv) # 95% accuracy
# out_fc = fully_connected(h_conv) # 95% accuracy
# out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, nbr_stack=1)) # 83% accuracy (conv) | 56% accuracy (noconv)
# out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, nbr_stack=2))
# out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, nbr_stack=2, trainable=True))
# out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, trainable=True)) # 84% accuracy (conv) | 59% accuracy (noconv)
out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, trainable=True)) # 84% accuracy (conv) | 59% accuracy (noconv)
# out_fc = fast_food(h_conv, SIGMA, diag=True, trainable=True) # 84% accuracy (conv) | 59% accuracy (noconv)
# out_fc = random_features(h_conv, SIGMA) # 82% accuracy (conv) | 47% accuracy (noconv)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment