Skip to content
Snippets Groups Projects
Commit 9dd0f2c9 authored by Luc Giffon's avatar Luc Giffon
Browse files

move get_next_batch to skluc.ml_datasets

parent 6229eff5
No related branches found
No related tags found
No related merge requests found
...@@ -13,7 +13,7 @@ Zichao Yang, Marcin Moczulski, Misha Denil, Nando de Freitas, Alex Smola, Le Son ...@@ -13,7 +13,7 @@ Zichao Yang, Marcin Moczulski, Misha Denil, Nando de Freitas, Alex Smola, Le Son
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
import skluc.mldatasets as dataset import skluc.mldatasets as dataset
from skluc.neural_networks import bias_variable, weight_variable, conv2d, max_pool_2x2 from skluc.neural_networks import bias_variable, weight_variable, conv2d, max_pool_2x2, get_next_batch
tf.logging.set_verbosity(tf.logging.ERROR) tf.logging.set_verbosity(tf.logging.ERROR)
...@@ -111,27 +111,6 @@ def fully_connected(conv_out): ...@@ -111,27 +111,6 @@ def fully_connected(conv_out):
return h_fc1 return h_fc1
def get_next_batch(full_set, batch_nbr, batch_size):
"""
Return the next batch of a dataset.
This function assumes that all the previous batches of this dataset have been taken with the same size.
:param full_set: the full dataset from which the batch will be taken
:param batch_nbr: the number of the batch
:param batch_size: the size of the batch
:return:
"""
index_start = (batch_nbr * batch_size) % full_set.shape[0]
index_stop = ((batch_nbr + 1) * batch_size) % full_set.shape[0]
if index_stop > index_start:
return full_set[index_start:index_stop]
else:
part1 = full_set[index_start:]
part2 = full_set[:index_stop]
return np.vstack((part1, part2))
if __name__ == '__main__': if __name__ == '__main__':
SIGMA = 5.0 SIGMA = 5.0
print("Sigma = {}".format(SIGMA)) print("Sigma = {}".format(SIGMA))
...@@ -150,11 +129,11 @@ if __name__ == '__main__': ...@@ -150,11 +129,11 @@ if __name__ == '__main__':
# Representation layer # Representation layer
h_conv = convolution_mnist(x_image) h_conv = convolution_mnist(x_image)
# h_conv = x # h_conv = x
out_fc = fully_connected(h_conv) # 95% accuracy # out_fc = fully_connected(h_conv) # 95% accuracy
# out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, nbr_stack=1)) # 83% accuracy (conv) | 56% accuracy (noconv) # out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, nbr_stack=1)) # 83% accuracy (conv) | 56% accuracy (noconv)
# out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, nbr_stack=2)) # out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, nbr_stack=2))
# out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, nbr_stack=2, trainable=True)) # out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, nbr_stack=2, trainable=True))
# out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, trainable=True)) # 84% accuracy (conv) | 59% accuracy (noconv) out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, trainable=True)) # 84% accuracy (conv) | 59% accuracy (noconv)
# out_fc = fast_food(h_conv, SIGMA, diag=True, trainable=True) # 84% accuracy (conv) | 59% accuracy (noconv) # out_fc = fast_food(h_conv, SIGMA, diag=True, trainable=True) # 84% accuracy (conv) | 59% accuracy (noconv)
# out_fc = random_features(h_conv, SIGMA) # 82% accuracy (conv) | 47% accuracy (noconv) # out_fc = random_features(h_conv, SIGMA) # 82% accuracy (conv) | 47% accuracy (noconv)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment