Skip to content
Snippets Groups Projects
timeit.py 3.65 KiB
Newer Older
  • Learn to ignore specific revisions
  • import tensorflow as tf
    import numpy as np
    from sklearn.metrics.pairwise import rbf_kernel
    
    import skluc.mldatasets as dataset
    from skluc.neural_networks import fully_connected, get_next_batch, tf_op
    from skluc.utils import time_fct
    from nystrom.nystrom_approx import tf_rbf_kernel, convolution_mnist
    import sklearn as sk
    
    # Only let TensorFlow log messages of severity ERROR through.
    tf.logging.set_verbosity(tf.logging.ERROR)

    # ---- Dataset preparation -------------------------------------------

    # Load MNIST through the skluc dataset helper and keep the first element
    # of the "train" pair (presumably the images; the second is discarded).
    mnist = dataset.MnistDataset().load()
    X_train, _ = mnist["train"]
    # Divide by 255 and store the result as a float32 array.
    X_train = np.array(X_train / 255).astype(np.float32)

    # --------------------------------------------------------------------
    
    
    if __name__ == '__main__':
        # Benchmark script: builds a TensorFlow graph containing fully
        # connected ops, convolution ops and an RBF kernel between a batch and
        # a subsample, then times each stage and prints the results.
        input_dim = X_train.shape[1]
        output_dim_fc = 4096*2
        batch_size = 10
        subsample_size = 100
        X_batch = get_next_batch(X_train, 0, batch_size)
        X_subsample = get_next_batch(X_train, 0, subsample_size)

        with tf.Graph().as_default():
            # inputs
            x = tf.placeholder(tf.float32, shape=[None, input_dim], name="x")
            x_subsample = tf.placeholder(tf.float32, shape=[None, input_dim], name="x_subsample")

            # reshape vector inputs to square, single-channel images
            side_size = int(np.sqrt(input_dim))
            x_image = tf.reshape(x, [-1, side_size, side_size, 1])
            # the subsample has a fixed, known number of rows
            x_subsample_image = tf.reshape(x_subsample, [subsample_size, side_size, side_size, 1])

            # fully connected ops
            out_fc_x = fully_connected(x, output_dim_fc, act=tf.nn.relu)
            out_fc_subsample = fully_connected(x_subsample, output_dim_fc, act=tf.nn.relu)

            # convolution ops
            out_conv_x = convolution_mnist(x_image)
            out_conv_subsample = convolution_mnist(x_subsample_image)

            # flatten the convolution outputs: the flat size is the product of
            # every statically known non-batch dimension
            init_dim = np.prod([s.value for s in out_conv_x.shape[1:] if s.value is not None])
            x_conv_flat = tf.reshape(out_conv_x, [-1, init_dim])
            subsample_conv_flat = tf.reshape(out_conv_subsample, [subsample_size, init_dim])

            # kernel computing ops: same computation pinned to CPU and to GPU
            with tf.device('/cpu:0'):
                kernel_cpu = tf_rbf_kernel(x_conv_flat, subsample_conv_flat, gamma=0.001)
            with tf.device('/device:GPU:0'):
                kernel_gpu = tf_rbf_kernel(x_conv_flat, subsample_conv_flat, gamma=0.001)

            feed_dict = {x: X_batch, x_subsample: X_subsample}

            def kernel_sklearn():
                """Compute the RBF kernel with scikit-learn on the conv features.

                The flattened convolution outputs are evaluated in a fresh
                TensorFlow session, then passed to sklearn's ``rbf_kernel``.
                The kernel value is discarded; the function exists only to be
                timed.
                NOTE(review): session creation and variable initialisation
                happen inside this function, so they are presumably included
                in the measured time -- confirm against ``time_fct``.
                """
                with tf.Session() as sess:
                    init = tf.global_variables_initializer()
                    sess.run([init])
                    x, y = sess.run([x_conv_flat, subsample_conv_flat], feed_dict=feed_dict)
                rbf_kernel(x, y, gamma=0.001)


            # todo look at the backpropagation time
            # todo kernel tensorflow on cpu
            # todo kernel tensorflow on gpu
            # todo kernel sklearn on cpu

            # Each entry maps a description of the evaluated pipeline to the
            # value returned by time_fct for it (presumably an elapsed time;
            # it is printed below as a float).
            d_time_results = {
                "fc_x": time_fct(lambda: tf_op(feed_dict, [out_fc_x])),
                "fc_subsample": time_fct(lambda: tf_op(feed_dict, [out_fc_subsample])),
                "reshape_x": time_fct(lambda: tf_op(feed_dict, [x_image])),
                "reshape_subsample": time_fct(lambda: tf_op(feed_dict, [x_subsample_image])),
                "reshape_x + conv_x": time_fct(lambda: tf_op(feed_dict, [out_conv_x])),
                # Time the subsample *convolution* here; the fully connected
                # subsample op is already measured under "fc_subsample".
                "reshape_subsample + conv_subsample": time_fct(lambda: tf_op(feed_dict, [out_conv_subsample])),
                "reshape_x + conv_x + reshape_subsample + conv_subsample + kernel_cpu": time_fct(lambda: tf_op(feed_dict, [kernel_cpu])),
                "reshape_x + conv_x + reshape_subsample + conv_subsample + kernel_gpu": time_fct(lambda: tf_op(feed_dict, [kernel_gpu])),
                "reshape_x + conv_x + reshape_subsample + conv_subsample + kernel_sklearn": time_fct(kernel_sklearn)
            }

            for key, value in d_time_results.items():
                print("{}:\t{:.4f}".format(key, value))