diff --git a/main/deepfriedConvnetMnist.py b/main/deepfriedConvnetMnist.py
index 7f7c98280800c202087414dfac9ec87abab72fbb..2ebe4f62f8c344086ca96aadcaa7038defe6b1a1 100644
--- a/main/deepfriedConvnetMnist.py
+++ b/main/deepfriedConvnetMnist.py
@@ -12,11 +12,28 @@ Zichao Yang, Marcin Moczulski, Misha Denil, Nando de Freitas, Alex Smola, Le Son
 
 import tensorflow as tf
 import numpy as np
+import skluc.mldatasets as dataset
+
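+# Silence TensorFlow log messages below the ERROR level.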
+tf.logging.set_verbosity(tf.logging.ERROR)
 
 import time as t
 
-from tensorflow.examples.tutorials.mnist import input_data
-mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
+from sklearn.preprocessing import LabelBinarizer
+
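+# Load MNIST through skluc, scale pixel values to [0, 1] and one-hot encode the labels.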
+enc = LabelBinarizer()
+mnist = dataset.MnistDataset()
+mnist = mnist.load()
+X_train, Y_train = mnist["train"]
+X_train = np.array(X_train / 255)
+enc.fit(Y_train)
+Y_train = np.array(enc.transform(Y_train))
+X_test, Y_test = mnist["test"]
+X_test = np.array(X_test / 255)
+Y_test = np.array(enc.transform(Y_test))
+X_train = X_train.astype(np.float32)
+X_test = X_test.astype(np.float32)
+Y_train = Y_train.astype(np.float32)
+Y_test = Y_test.astype(np.float32)
 
 from fasfood_layer import fast_food
 
@@ -105,10 +122,25 @@ def fully_connected(conv_out):
     return h_fc1
 
 
-def mnist_dims():
-    input_dim = int(mnist.train.images.shape[1])
-    output_dim = int(mnist.train.labels.shape[1])
-    return input_dim, output_dim
+def get_next_batch(full_set, batch_nbr, batch_size):
+    """
+    Return the next batch of a dataset.
+
+    This function assumes that all previous batches of this dataset were taken with the same batch size.
+
+    :param full_set: the full dataset from which the batch will be taken
+    :param batch_nbr: the index of the batch to return
+    :param batch_size: the number of samples in the batch
+    :return: a batch of batch_size rows from full_set, wrapping around to the beginning of the dataset once the end is reached
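+
+    Example (illustrative)::
+
+        data = np.arange(10).reshape(5, 2)
+        get_next_batch(data, batch_nbr=2, batch_size=2)  # wraps around: rows [8, 9] then [0, 1]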
+    """
+    index_start = (batch_nbr * batch_size) % full_set.shape[0]
+    index_stop = ((batch_nbr + 1) * batch_size) % full_set.shape[0]
+    if index_stop > index_start:
+        return full_set[index_start:index_stop]
+    else:
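+        # The requested batch runs past the end of the dataset: wrap around and
+        # stitch the tail of the array to its head.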
+        part1 = full_set[index_start:]
+        part2 = full_set[:index_stop]
+        return np.vstack((part1, part2))
 
 
 if __name__ == '__main__':
@@ -116,8 +148,8 @@ if __name__ == '__main__':
     print("Sigma = {}".format(SIGMA))
 
     with tf.Graph().as_default():
-        # todo parametrize datset
-        input_dim, output_dim = mnist_dims()
+        # todo parametrize dataset
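+        # Dimensions are read from the preprocessed arrays: 784 flattened pixels and 10 one-hot classes for MNIST.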
+        input_dim, output_dim = X_train.shape[1], Y_train.shape[1]
 
         x = tf.placeholder(tf.float32, shape=[None, input_dim], name="x")
         y_ = tf.placeholder(tf.float32, shape=[None, output_dim], name="labels")
@@ -181,9 +213,10 @@ if __name__ == '__main__':
         sess.run(init)
         # actual learning
         started = t.time()
-        for i in range(20000):
-            batch = mnist.train.next_batch(64)
-            feed_dict = {x: batch[0], y_: batch[1], keep_prob: 0.5}
+        for i in range(1100):
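+            # Image and label batches stay aligned because both use the same batch index and batch size.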
+            X_batch = get_next_batch(X_train, i, 64)
+            Y_batch = get_next_batch(Y_train, i, 64)
+            feed_dict = {x: X_batch, y_: Y_batch, keep_prob: 0.5}
             # the _ captures the return value of "train_optimizer", which must be run
             # to compute the gradients but whose output we do not need
             _, loss = sess.run([train_optimizer, cross_entropy], feed_dict=feed_dict)
@@ -191,12 +224,12 @@ if __name__ == '__main__':
                 print('step {}, loss {} (with dropout)'.format(i, loss))
                 summary_str = sess.run(merged_summary, feed_dict=feed_dict)
                 summary_writer.add_summary(summary_str, i)
-        stoped = t.time()
 
+        stoped = t.time()
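+        # Evaluate on the full test set with dropout disabled (keep_prob = 1.0).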
         accuracy, preds = sess.run([accuracy, predictions], feed_dict={
-            x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})
+            x: X_test, y_: Y_test, keep_prob: 1.0})
         print('test accuracy %g' % accuracy)
         np.set_printoptions(threshold=np.nan)
         print("Prediction sample: " + str(preds[:50]))
-        print("Actual values: " + str(np.argmax(mnist.test.labels[:50], 1)))
+        print("Actual values: " + str(np.argmax(Y_test[:50], axis=1)))
         print("Elapsed time: %.4f s" % (stoped - started))
\ No newline at end of file