Skip to content
Snippets Groups Projects
Commit 85230ebd authored by Luc Giffon's avatar Luc Giffon
Browse files

nystrom approx now use common function for conv relu pooling

parent 13fac76c
No related branches found
No related tags found
No related merge requests found
......@@ -5,10 +5,9 @@ Convnet with nystrom approximation of the feature map.
import tensorflow as tf
import numpy as np
from sklearn.metrics.pairwise import rbf_kernel
import skluc.mldatasets as dataset
from skluc.neural_networks import bias_variable, weight_variable, conv2d, max_pool_2x2, conv_relu_pool, get_next_batch
from skluc.neural_networks import bias_variable, weight_variable, conv_relu_pool, get_next_batch
tf.logging.set_verbosity(tf.logging.ERROR)
......@@ -43,29 +42,12 @@ NYSTROM_SAMPLE_SIZE = 500
X_nystrom = X_train[np.random.permutation(NYSTROM_SAMPLE_SIZE)]
def convolution_mnist(input, trainable=True):
with tf.name_scope("conv_pool_1"):
# 32 is the number of filter we'll use. e.g. the number of different
# shapes this layer is able to recognize
W_conv1 = weight_variable([5, 5, 1, 20], trainable=trainable)
tf.summary.histogram("weights conv1", W_conv1)
b_conv1 = bias_variable([20], trainable=trainable)
tf.summary.histogram("biases conv1", b_conv1)
# -1 is here to keep the total size constant (784)
h_conv1 = tf.nn.relu(conv2d(input, W_conv1) + b_conv1)
tf.summary.histogram("act conv1", h_conv1)
h_pool1 = max_pool_2x2(h_conv1)
with tf.name_scope("conv_pool_2"):
W_conv2 = weight_variable([5, 5, 20, 50], trainable=trainable)
tf.summary.histogram("weights conv2", W_conv2)
b_conv2 = bias_variable([50], trainable=trainable)
tf.summary.histogram("biases conv2", b_conv2)
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
tf.summary.histogram("act conv2", h_conv2)
h_pool2 = max_pool_2x2(h_conv2)
return h_pool2
def convolution_mnist(input_, trainable=True):
with tf.variable_scope("conv_pool_1"):
conv1 = conv_relu_pool(input_, [5, 5, 1, 20], [20], trainable=trainable)
with tf.variable_scope("conv_pool_2"):
conv2 = conv_relu_pool(conv1, [5, 5, 20, 50], [50], trainable=trainable)
return conv2
def fully_connected(conv_out):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment