Commit 3dc1e08d authored by Luc Giffon

nystrom is now embedded in a function and in a namespace

parent 5a9c44bd
@@ -93,6 +93,21 @@ def tf_rbf_kernel(X, Y, gamma):
     return K
+def nystrom_layer(input_x, input_subsample, gamma):
+    with tf.name_scope("nystrom"):
+        init_dim = np.prod([s.value for s in input_x.shape[1:] if s.value is not None])
+        h_conv_flat = tf.reshape(input_x, [-1, init_dim])
+        h_conv_nystrom_subsample_flat = tf.reshape(input_subsample, [NYSTROM_SAMPLE_SIZE, init_dim])
+        with tf.name_scope("kernel_vec"):
+            kernel_vector = tf_rbf_kernel(h_conv_flat, h_conv_nystrom_subsample_flat, gamma=gamma)
+        D = weight_variable((NYSTROM_SAMPLE_SIZE,))
+        V = weight_variable((NYSTROM_SAMPLE_SIZE, NYSTROM_SAMPLE_SIZE))
+        out_fc = tf.matmul(kernel_vector, tf.matmul(tf.multiply(D, V), tf.transpose(V)))
+        return out_fc
 def main():
     GAMMA = 0.001
     print("Gamma = {}".format(GAMMA))
@@ -116,16 +131,7 @@ def main():
         scope_conv_mnist.reuse_variables()
         h_conv_nystrom_subsample = convolution_mnist(x_nystrom_image, trainable=False)
-    init_dim = np.prod([s.value for s in h_conv.shape[1:] if s.value is not None])
-    h_conv_flat = tf.reshape(h_conv, [-1, init_dim])
-    h_conv_nystrom_subsample_flat = tf.reshape(h_conv_nystrom_subsample, [NYSTROM_SAMPLE_SIZE, init_dim])
-    with tf.name_scope("kernel_vec"):
-        kernel_vector = tf_rbf_kernel(h_conv_flat, h_conv_nystrom_subsample_flat, gamma=GAMMA)
-    D = weight_variable((NYSTROM_SAMPLE_SIZE,))
-    V = weight_variable((NYSTROM_SAMPLE_SIZE, NYSTROM_SAMPLE_SIZE))
-    out_fc = tf.matmul(kernel_vector, tf.matmul(tf.multiply(D, V), tf.transpose(V)))
+    out_fc = nystrom_layer(h_conv, h_conv_nystrom_subsample, GAMMA)
     # classification
     with tf.name_scope("fc_2"):
@@ -177,10 +183,8 @@ def main():
             feed_dict = {x: X_batch, y_: Y_batch, keep_prob: 0.5}
             # the _ is there to capture the return value of "train_optimizer", which has to be run
             # to compute the gradient but whose output is of no interest to us
-            _, loss, y_result, x_exp, k_vec, eigenvec = sess.run([train_optimizer, cross_entropy, y_conv, x_image, kernel_vector, V], feed_dict=feed_dict)
+            _, loss, y_result, x_exp = sess.run([train_optimizer, cross_entropy, y_conv, x_image], feed_dict=feed_dict)
             if i % 100 == 0:
-                print(k_vec[0])
-                print("Difference with identity:", np.linalg.norm(eigenvec - np.eye(*eigenvec.shape)))
                 print('step {}, loss {} (with dropout)'.format(i, loss))
                 r_accuracy = sess.run([accuracy], feed_dict=feed_dict_val)
                 print("accuracy: {} on validation set (without dropout).".format(r_accuracy))
@@ -198,4 +202,4 @@ def main():
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()
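The layer relies on tf_rbf_kernel, whose body is not visible in this diff (only its final return K appears in the first hunk). As context, here is a plausible TensorFlow 1.x sketch of a Gaussian RBF kernel with that signature; this is an assumption about what the file implements, not the actual code.

import tensorflow as tf

def tf_rbf_kernel_sketch(X, Y, gamma):
    # k(x, y) = exp(-gamma * ||x - y||^2), with the pairwise squared distances
    # expanded as ||x||^2 - 2 x.y + ||y||^2.
    sq_x = tf.reduce_sum(tf.square(X), axis=1)               # shape (n,)
    sq_y = tf.reduce_sum(tf.square(Y), axis=1)               # shape (m,)
    cross = tf.matmul(X, Y, transpose_b=True)                # shape (n, m)
    sq_dists = tf.expand_dims(sq_x, 1) - 2.0 * cross + tf.expand_dims(sq_y, 0)
    return tf.exp(-gamma * sq_dists)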