From 9dd0f2c97c8d9c7b0881223dbb9ccb68a47eba11 Mon Sep 17 00:00:00 2001
From: Luc Giffon <luc.giffon@lif.univ-mrs.fr>
Date: Tue, 30 Jan 2018 18:18:39 +0100
Subject: [PATCH] move get_next_batch to skluc.ml_datasets

---
 main/deepfriedConvnetMnist.py | 27 +++------------------------
 1 file changed, 3 insertions(+), 24 deletions(-)

diff --git a/main/deepfriedConvnetMnist.py b/main/deepfriedConvnetMnist.py
index 8b5ecb0..c8a566c 100644
--- a/main/deepfriedConvnetMnist.py
+++ b/main/deepfriedConvnetMnist.py
@@ -13,7 +13,7 @@ Zichao Yang, Marcin Moczulski, Misha Denil, Nando de Freitas, Alex Smola, Le Son
 import tensorflow as tf
 import numpy as np
 import skluc.mldatasets as dataset
-from skluc.neural_networks import bias_variable, weight_variable, conv2d, max_pool_2x2
+from skluc.neural_networks import bias_variable, weight_variable, conv2d, max_pool_2x2, get_next_batch
 
 tf.logging.set_verbosity(tf.logging.ERROR)
 
@@ -111,27 +111,6 @@ def fully_connected(conv_out):
     return h_fc1
 
 
-def get_next_batch(full_set, batch_nbr, batch_size):
-    """
-    Return the next batch of a dataset.
-
-    This function assumes that all the previous batches of this dataset have been taken with the same size.
-
-    :param full_set: the full dataset from which the batch will be taken
-    :param batch_nbr: the number of the batch
-    :param batch_size: the size of the batch
-    :return:
-    """
-    index_start = (batch_nbr * batch_size) % full_set.shape[0]
-    index_stop = ((batch_nbr + 1) * batch_size) % full_set.shape[0]
-    if index_stop > index_start:
-        return full_set[index_start:index_stop]
-    else:
-        part1 = full_set[index_start:]
-        part2 = full_set[:index_stop]
-        return np.vstack((part1, part2))
-
-
 if __name__ == '__main__':
     SIGMA = 5.0
     print("Sigma = {}".format(SIGMA))
@@ -150,11 +129,11 @@ if __name__ == '__main__':
         # Representation layer
         h_conv = convolution_mnist(x_image)
         # h_conv = x
-        out_fc = fully_connected(h_conv)  # 95% accuracy
+        # out_fc = fully_connected(h_conv)  # 95% accuracy
         # out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, nbr_stack=1))  # 83% accuracy (conv) | 56% accuracy (noconv)
         # out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, nbr_stack=2))
         # out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, nbr_stack=2, trainable=True))
-        # out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, trainable=True))  # 84% accuracy (conv) | 59% accuracy (noconv)
+        out_fc = tf.nn.relu(fast_food(h_conv, SIGMA, trainable=True))  # 84% accuracy (conv) | 59% accuracy (noconv)
         # out_fc = fast_food(h_conv, SIGMA, diag=True, trainable=True)  # 84% accuracy (conv) | 59% accuracy (noconv)
         # out_fc = random_features(h_conv, SIGMA)  # 82% accuracy (conv) | 47% accuracy (noconv)
 
-- 
GitLab