diff --git a/skluc/main/tensorflow_/kernel.py b/skluc/main/tensorflow_/kernel.py
index af5b20de76b11277d77dac18dd1b8a25792fbf3b..0f8f9a0c6b795d4fc45b7ca9dcfd9c41090cf46b 100644
--- a/skluc/main/tensorflow_/kernel.py
+++ b/skluc/main/tensorflow_/kernel.py
@@ -28,6 +28,10 @@ def tf_rbf_kernel(X, Y, gamma):
 def tf_linear_kernel(X, Y):
     return tf.matmul(X, tf.transpose(Y))
 
+def tf_polynomial_kernel(X, Y, degree=2, gamma=None, **kwargs):
+    if gamma is None:
+        gamma = tf.div(tf.constant(1, dtype=tf.float32), X.shape[-1].value)
+    return tf.pow(tf.add(tf.constant(1, dtype=tf.float32), gamma * tf_linear_kernel(X, Y)), tf.constant(degree, dtype=tf.float32))
 
 def tf_chi_square_PD(X, Y):
     # the drawing of the matrix X expanded looks like a wall
diff --git a/skluc/main/tensorflow_/kernel_approximation/fastfood_layer.py b/skluc/main/tensorflow_/kernel_approximation/fastfood_layer.py
index b410390cd98ed545a0d81157f63f81365cb9915f..d657dd75000565c8bab25f8662d6f5784a4c8a28 100644
--- a/skluc/main/tensorflow_/kernel_approximation/fastfood_layer.py
+++ b/skluc/main/tensorflow_/kernel_approximation/fastfood_layer.py
@@ -155,3 +155,89 @@ def is_power_of_two(input_integer):
 
 def build_hadamard(n_neurons):
     return scipy.linalg.hadamard(n_neurons)
+
+
+class FastFoodLayer(tf.keras.layers.Layer):
+    def __init__(self, sigma, nbr_stack, trainable=True):
+        super().__init__()
+        self.__sigma = sigma
+        self.__nbr_stack = nbr_stack
+        self.__trainable = trainable
+
+    def build(self, input_shape):
+        # todo replace these with proper initializers?
+        with tf.name_scope("fastfood" + "_sigma-" + str(self.__sigma)):
+            init_dim = np.prod([s.value for s in input_shape if s.value is not None])
+            final_dim = int(dimensionality_constraints(init_dim))
+            self.num_outputs = None
+
+            G, G_norm = G_variable((self.__nbr_stack, final_dim))
+            self.__G = self.add_variable(
+                name="G",
+                shape=(self.__nbr_stack, final_dim),
+                initializer=lambda *args, **kwargs: G,
+                trainable=self.__trainable
+            )
+
+            B = B_variable((self.__nbr_stack, final_dim))
+            self.__B = self.add_variable(
+                name="B",
+                shape=(self.__nbr_stack, final_dim),
+                initializer=lambda *args, **kwargs: B,
+                trainable=self.__trainable
+            )
+
+            H = H_variable(final_dim)
+            self.__H = self.add_variable(
+                name="H",
+                shape=(final_dim, final_dim),
+                initializer=lambda *args, **kwargs: H,
+                trainable=False
+            )
+
+            P = P_variable(final_dim, self.__nbr_stack)
+            self.__P = self.add_variable(
+                name="P",
+                shape=(final_dim * self.__nbr_stack, final_dim * self.__nbr_stack),
+                initializer=lambda *args, **kwargs: P,
+                trainable=False
+            )
+
+            S = S_variable((self.__nbr_stack, final_dim), G_norm)
+            self.__S = self.add_variable(
+                name="S",
+                shape=(final_dim * self.__nbr_stack, final_dim * self.__nbr_stack),
+                initializer=lambda *args, **kwargs: S,
+                trainable=self.__trainable
+            )
+
+            self.num_outputs = final_dim * self.__nbr_stack
+
+    def call(self, input, **kwargs):
+        init_dim = np.prod([s.value for s in input.shape if s.value is not None])
+        final_dim = int(dimensionality_constraints(init_dim))
+
+        padding = final_dim - init_dim
+        conv_out2 = tf.reshape(input, [-1, init_dim])
+        paddings = tf.constant([[0, 0], [0, padding]])
+        conv_out2 = tf.pad(conv_out2, paddings, "CONSTANT")
+        conv_out2 = tf.reshape(conv_out2, (1, -1, 1, final_dim))
+        h_ff1 = tf.multiply(conv_out2, self.__B, name="Bx")
+        h_ff1 = tf.reshape(h_ff1, (-1, final_dim))
+        h_ff2 = tf.matmul(h_ff1, self.__H, name="HBx")
+        h_ff2 = tf.reshape(h_ff2, (-1, final_dim * self.__nbr_stack))
+        h_ff3 = tf.matmul(h_ff2, self.__P, name="PHBx")
+        h_ff4 = tf.multiply(tf.reshape(h_ff3, (-1, final_dim * self.__nbr_stack)),
+                            tf.reshape(self.__G, (-1, final_dim * self.__nbr_stack)),
+                            name="GPHBx")
+        h_ff4 = tf.reshape(h_ff4, (-1, final_dim))
+        h_ff5 = tf.matmul(h_ff4, self.__H, name="HGPHBx")
+
+        h_ff6 = tf.scalar_mul((1 / (self.__sigma * np.sqrt(final_dim))),
+                              tf.multiply(tf.reshape(h_ff5, (-1, final_dim * self.__nbr_stack)),
+                                          tf.reshape(self.__S, (-1, final_dim * self.__nbr_stack)), name="SHGPHBx"))
+        h_ff7_1 = tf.cos(h_ff6)
+        h_ff7_2 = tf.sin(h_ff6)
+        h_ff7 = tf.scalar_mul(tf.sqrt(float(1 / final_dim)), tf.concat([h_ff7_1, h_ff7_2], axis=1))
+        return h_ff7
+
diff --git a/skluc/main/tensorflow_/kernel_approximation/nystrom_layer.py b/skluc/main/tensorflow_/kernel_approximation/nystrom_layer.py
index cc1cd51a3497b49e185ea653dc29dff681626e93..3ec4730ebbf2de3a5064572ed1a6bf8694ef9937 100644
--- a/skluc/main/tensorflow_/kernel_approximation/nystrom_layer.py
+++ b/skluc/main/tensorflow_/kernel_approximation/nystrom_layer.py
@@ -6,9 +6,10 @@
 import time as t
 
 import numpy as np
 import tensorflow as tf
+from sklearn.metrics.pairwise import rbf_kernel, linear_kernel, additive_chi2_kernel, chi2_kernel, laplacian_kernel
 
 import skluc.main.data.mldatasets as dataset
-from skluc.main.tensorflow_.kernel import tf_rbf_kernel
+from skluc.main.tensorflow_.kernel import tf_rbf_kernel, tf_linear_kernel, tf_chi_square_CPD, tf_chi_square_CPD_exp, tf_laplacian_kernel
 from skluc.main.tensorflow_.utils import get_next_batch, classification_mnist, convolution_mnist
 from skluc.main.utils import logger
@@ -157,5 +157,157 @@ def main():
     print("Elapsed time: %.4f s" % (stoped - started))
 
 
+class DeepstromLayer(tf.keras.layers.Layer):
+    def __init__(self,
+                 subsample,
+                 kernel_name,
+                 out_dim=None,
+                 activation=None,
+                 real_nystrom=False,
+                 sum_of_kernels=False,
+                 stack_of_kernels=False,
+                 kernel_dict={}
+                 ):
+
+        def init_kernel():
+            if kernel_name == "rbf":
+                kernel_fct = rbf_kernel
+                tf_kernel_fct = tf_rbf_kernel
+            elif kernel_name == "linear":
+                kernel_fct = linear_kernel
+                tf_kernel_fct = tf_linear_kernel
+            elif kernel_name == "chi2_cpd":
+                kernel_fct = additive_chi2_kernel
+                tf_kernel_fct = tf_chi_square_CPD
+            elif kernel_name == "chi2_exp_cpd":
+                kernel_fct = chi2_kernel
+                tf_kernel_fct = tf_chi_square_CPD_exp
+            elif kernel_name == "chi2_pd":
+                raise NotImplementedError("Double-check that this code does not make a mess")
+            elif kernel_name == "laplacian":
+                tf_kernel_fct = tf_laplacian_kernel
+                kernel_fct = laplacian_kernel
+            else:
+                raise ValueError("Unknown kernel name: {}".format(kernel_name))
+            return kernel_name, kernel_fct, tf_kernel_fct, kernel_dict
+
+        def init_output_dim(subsample_size):
+            if out_dim is not None and out_dim > subsample_size:
+                logger.debug("Output dim is greater than deepstrom subsample size. Aborting.")
+                exit()
+            elif out_dim is None:
+                return subsample_size
+            else:
+                return out_dim
+
+        def init_activation():
+            if activation == "tan":
+                activation_fct = tf.nn.tanh
+            elif activation == "relu":
+                activation_fct = tf.nn.relu
+            else:
+                activation_fct = activation
+
+            return activation_fct
+
+        def init_real_nystrom(subsample_size):
+            if out_dim != subsample_size and out_dim is not None and real_nystrom:
+                logger.warning("If real nystrom is used, the output dim can only be the same as subsample size: "
+                               "{} != {}".format(out_dim, subsample_size))
+
+            return real_nystrom
+
+        def init_subsample():
+            return subsample, len(subsample)
+
+        super().__init__()
+
+        self.__subsample, self.__subsample_size = init_subsample()
+
+        self.__sum_of_kernels = sum_of_kernels
+        self.__stack_of_kernels = stack_of_kernels
+
+        self.__kernel_name, self.__kernel_fct, self.__tf_kernel_fct, self.__kernel_dict = init_kernel()
+        self.__real_nystrom = init_real_nystrom(self.__subsample_size)
+        self.__output_dim = init_output_dim(self.__subsample_size)
+        self.__activation = init_activation()
+        self.__W_matrix = None
+
+        logger.info("Selecting {} deepstrom layer function with "
+                    "subsample size = {}, "
+                    "output_dim = {}, "
+                    "{} activation function "
+                    "and kernel = {}"
+                    .format("real" if self.__real_nystrom else "learned",
+                            self.__subsample_size,
+                            self.__output_dim,
+                            "with" if self.__activation else "without",
+                            self.__kernel_name))
+
+    def build(self, input_shape):
+        if self.__output_dim != 0:
+            # outputdim == 0 means there is no W matrix and the kernel vector is directly added as input to
+            # the next layer
+            if self.__real_nystrom:
+                W_matrix = None
+                logger.debug("Real nystrom asked: i.e. the W projection matrix has the vanilla formula")
+                if self.__sum_of_kernels:
+                    raise NotImplementedError("This has not been checked for a while. Must be updated.")
+                    # here K11 matrices are added before doing nystrom approximation
+                    # added_K11 = np.zeros((self.__subsample.shape[0], self.__subsample.shape[0]))
+                    # for g_value in GAMMA:  # only rbf kernel is considered
+                    # added_K11 = np.add(added_K11, rbf_kernel(self.__subsample, self.__subsample, gamma=g_value))
+                    # U, S, V = np.linalg.svd(added_K11)
+                    # invert_root_K11 = np.dot(U / np.sqrt(S), V).astype(np.float32)
+                    # W_matrix = stack_K11
+                elif self.__stack_of_kernels:
+                    raise NotImplementedError("This has not been checked for a while. Must be updated.")
+                    # here nystrom approximations are stacked
+                    # lst_invert_root_K11 = []
+                    # for g_value in GAMMA:
+                    # K11 = rbf_kernel(self.__subsample, self.__subsample, gamma=g_value)
+                    # U, S, V = np.linalg.svd(K11)
+                    # invert_root_K11 = np.dot(U / np.sqrt(S), V).astype(np.float32)
+                    # lst_invert_root_K11.append(invert_root_K11)
+                    # stack_K11 = np.vstack(lst_invert_root_K11)
+                    # W_matrix = stack_K11
+                else:
+                    K11 = self.__kernel_fct(self.__subsample, self.__subsample, **self.__kernel_dict)
+                    U, S, V = np.linalg.svd(K11)
+                    invert_root_K11 = np.dot(U / np.sqrt(S), V).astype(np.float32)
+                    W_matrix = invert_root_K11
+
+                self.__W_matrix = self.add_variable(
+                    name="W_nystrom",
+                    shape=[self.__subsample_size, self.__subsample_size],
+                    initializer=lambda *args, **kwargs: tf.Variable(initial_value=W_matrix),
+                    trainable=False
+                )
+            else:
+                self.__W_matrix = self.add_variable(
+                    name="W_nystrom",
+                    shape=[self.__subsample_size, self.__output_dim],
+                    initializer=tf.random_normal_initializer(stddev=0.1),
+                    trainable=True
+                )
+
+    def call(self, input_x, **kwargs):
+        with tf.name_scope("NystromLayer"):
+            init_dim = np.prod([s.value for s in input_x.shape[1:] if s.value is not None])
+            h_conv_flat = tf.reshape(input_x, [-1, init_dim])
+            h_conv_nystrom_subsample_flat = tf.reshape(self.__subsample, [self.__subsample_size, init_dim])
+            with tf.name_scope("kernel_vec"):
+                kernel_vector = self.__tf_kernel_fct(h_conv_flat, h_conv_nystrom_subsample_flat, **self.__kernel_dict)
+
+        if self.__output_dim != 0:
+            out = tf.matmul(kernel_vector, self.__W_matrix)
+        else:
+            out = kernel_vector
+        if self.__activation is not None:
+            out = self.__activation(out)
+        return out
+
+
+
 if __name__ == '__main__':
     main()
diff --git a/skluc/main/tensorflow_/utils.py b/skluc/main/tensorflow_/utils.py
index db494afcf96247ff9257789c25ff5918c638217b..c184d2760fd65deff2a8eaf36255b56ad4c021c5 100644
--- a/skluc/main/tensorflow_/utils.py
+++ b/skluc/main/tensorflow_/utils.py
@@ -178,10 +178,10 @@ def classification_mnist(input_, output_dim):
     :param output_dim: The returned output dimension
     :return:
     """
+    # todo remove this useless thing
     with tf.variable_scope("classification_mnist"):
         keep_prob = tf.placeholder(tf.float32, name="keep_prob")
         input_drop = tf.nn.dropout(input_, keep_prob)
-        # todo why no softmax activation, here?
         y_ = fully_connected(input_drop, output_dim)
     return y_, keep_prob
 
diff --git a/skluc/main/tools/experiences/cluger.py b/skluc/main/tools/experiences/cluger.py
index 9729850c037114c112bcdc82d9174afe45fc5246..77a26c0d914ae1badaaf440db952837b1dc9d2e5 100644
--- a/skluc/main/tools/experiences/cluger.py
+++ b/skluc/main/tools/experiences/cluger.py
@@ -11,7 +11,7 @@ Options:
   -a --array-params PARAMFILE   The path to a file containing an array of parameters
   -n --dry-run                  Tell the script not to be run but print the command lines instead
   -m --maximum-job integer      Tell how many simultaneous jobs should be launched.
-  -H --host str                 The name of the hosts on which to run the scripts separated by commas[default: see4c1].
+  -H --host str                 The name of the hosts on which to run the scripts separated by commas.
   -S --start integer            The number of the starting batch [default: 0].
   -t --walltime integer         The time in hour for each job.
""" @@ -34,7 +34,7 @@ if __name__ == '__main__': MAX_LINES = int(arguments["--maximum-job"]) if arguments["--maximum-job"] is not None else math.inf DRY_RUN = arguments["--dry-run"] INTERPRETER = arguments["--python"] - HOST = [h.strip() for h in arguments["--host"].split(",")] + HOST = [h.strip() for h in arguments["--host"].split(",")] if arguments["--host"] is not None else None START_LINE = int(arguments["--start"]) TIME = int(arguments["--walltime"]) diff --git a/skluc/main/tools/experiences/gather_results.py b/skluc/main/tools/experiences/gather_results.py index 70532286178c4416564a6491c00b9bc9ca13a7cb..92d8e28d42f0e36f78c775f82f04b68602fd545f 100644 --- a/skluc/main/tools/experiences/gather_results.py +++ b/skluc/main/tools/experiences/gather_results.py @@ -4,12 +4,14 @@ gather_results: gather OAR results from one dir to one file called gathered_resu The result files should consist of lines and should have "stdout" in their name. Usage: - gather_results -i IPATH [-p regex] + gather_results -i IPATH [-p regex] [--header] [--verbose] Options: -h --help Show this screen. -i --input-dir=<IPATH> Input directory wher to find results -p --patern=regex Specify the pattern of the files to be looked at [default: .+\.stdout]. + -r --header Says if there is a header in the result files. + -v --verbose Print the lines of the final file """ import os @@ -32,19 +34,35 @@ if __name__ == '__main__': count = 0 results = [] compiled_re = re.compile(pattern_to_recognize) + first_line = "" for f_name in onlyfiles: if compiled_re.match(f_name) is None: continue with open(f_name, "r") as f: - str_f = f.read().strip() - results.append(str_f) + lines = f.readlines() + try: + if not first_line and arguments["--header"]: + first_line = lines[0].strip() + if arguments["--header"]: + results.append(lines[1].strip()) + else: + results.append(lines[0].strip()) + except IndexError: + results.append("") with open(os.path.join(mypath, "gathered_results.csv"), 'w') as f_w: n_full = 0 n_empty = 0 + if first_line: + f_w.write(first_line) + f_w.write("\n") + if arguments["--verbose"]: + logger.debug(first_line) for s in results: f_w.write(s) f_w.write("\n") + if arguments["--verbose"]: + logger.debug(s) if s.strip() != "": n_full += 1 else: diff --git a/skluc/test/test_kernel.py b/skluc/test/test_kernel.py index b30af26b550389ec0fbe7a710d21bdf0c357ca0f..8bc432565624d072f2ddc8fdfc46eda945eebbcb 100644 --- a/skluc/test/test_kernel.py +++ b/skluc/test/test_kernel.py @@ -1,11 +1,11 @@ import unittest import tensorflow as tf import numpy as np -from sklearn.metrics.pairwise import rbf_kernel, chi2_kernel, additive_chi2_kernel, sigmoid_kernel, laplacian_kernel +from sklearn.metrics.pairwise import rbf_kernel, chi2_kernel, additive_chi2_kernel, sigmoid_kernel, laplacian_kernel, polynomial_kernel from skluc.main.data.mldatasets.MnistDataset import MnistDataset from skluc.main.tensorflow_.kernel import tf_rbf_kernel, tf_chi_square_CPD_exp, tf_chi_square_CPD, tf_sigmoid_kernel, \ - tf_laplacian_kernel, tf_sum_of_kernels, tf_stack_of_kernels + tf_laplacian_kernel, tf_sum_of_kernels, tf_stack_of_kernels, tf_polynomial_kernel class TestKernel(unittest.TestCase): @@ -34,7 +34,8 @@ class TestKernel(unittest.TestCase): "exp_chi2": tf_chi_square_CPD_exp, "chi2": tf_chi_square_CPD, "sigmoid": tf_sigmoid_kernel, - "laplacian": tf_laplacian_kernel + "laplacian": tf_laplacian_kernel, + "poly": tf_polynomial_kernel } self.custom_kernels_params = { @@ -50,7 +51,8 @@ class TestKernel(unittest.TestCase): "exp_chi2": chi2_kernel, "chi2": 
additive_chi2_kernel, "sigmoid": sigmoid_kernel, - "laplacian": laplacian_kernel + "laplacian": laplacian_kernel, + "poly2": polynomial_kernel } self.sklearn_kernels_params = { @@ -58,9 +60,33 @@ class TestKernel(unittest.TestCase): "exp_chi2": {"gamma": 0.01}, "chi2": {}, "sigmoid": {"gamma": 1 / self.val_size, "coef0": 1.}, - "laplacian": {"gamma": 1 / self.val_size} + "laplacian": {"gamma": 1 / self.val_size}, + "poly2": {"coef0": 1, "degree":2} } + def test_tf_polynomial2_kernel(self): + coef=1 + degrees = [2, 3] + + for degree in degrees: + print(f"Degree {degree}") + expected_rbf_kernel = polynomial_kernel(self.X_val, self.X_val, degree=degree, coef0=coef) + obtained_rbf_kernel = tf_polynomial_kernel(self.X_val, self.X_val, degree=degree).eval() + difference_rbf_kernel = np.linalg.norm(expected_rbf_kernel - obtained_rbf_kernel) / (self.val_size**2) + self.assertAlmostEqual(difference_rbf_kernel, 0, delta=1e-4) + + example1 = self.X_val[0].reshape((1, -1)) + example2 = self.X_val[1].reshape((1, -1)) + expected_rbf_kernel_value = polynomial_kernel(example1, example2, degree=degree, coef0=coef) + obtained_rbf_kernel_value = tf_polynomial_kernel(example1, example2, degree=degree).eval() + difference_rbf_kernel_value = np.linalg.norm(expected_rbf_kernel_value - obtained_rbf_kernel_value) + self.assertAlmostEqual(difference_rbf_kernel_value, 0, delta=1e-4) + + expected_rbf_kernel_vector = polynomial_kernel(example1, self.X_val, degree=degree, coef0=coef) + obtained_rbf_kernel_vector = tf_polynomial_kernel(example1, self.X_val, degree=degree).eval() + difference_rbf_kernel_vector = np.linalg.norm(expected_rbf_kernel_vector - obtained_rbf_kernel_vector) / (self.val_size) + self.assertAlmostEqual(difference_rbf_kernel_vector, 0, delta=1e-4) + def test_tf_rbf_kernel(self): gamma = 0.01 expected_rbf_kernel = rbf_kernel(self.X_val, self.X_val, gamma=gamma)
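
For reference (this note is not part of the patch): the new tf_polynomial_kernel computes K(X, Y) = (1 + gamma * X Y^T)^degree, with gamma defaulting to 1 / n_features, which is what sklearn.metrics.pairwise.polynomial_kernel returns for coef0=1; the added test_tf_polynomial2_kernel checks this correspondence for degrees 2 and 3 on the full Gram matrix, a single pair of examples, and a single row.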
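
As a minimal usage sketch (also not part of the patch), the two new layers can be plugged into a TF 1.x graph roughly as follows; the placeholder shape, sigma, gamma and the random subsample are illustrative assumptions rather than values taken from the repository:

    import numpy as np
    import tensorflow as tf

    from skluc.main.tensorflow_.kernel_approximation.fastfood_layer import FastFoodLayer
    from skluc.main.tensorflow_.kernel_approximation.nystrom_layer import DeepstromLayer

    # Illustrative input and Nystrom subsample (64 landmark points); real code would
    # take the subsample from the training set rather than at random.
    x = tf.placeholder(tf.float32, shape=[None, 784])
    subsample = np.random.rand(64, 784).astype(np.float32)

    # Fastfood random features approximating a Gaussian kernel of bandwidth sigma.
    fastfood_features = FastFoodLayer(sigma=5.0, nbr_stack=1)(x)

    # Nystrom features: the kernel vector k(x, subsample) projected through a learned
    # 64-dimensional W matrix, followed by a ReLU.
    deepstrom_features = DeepstromLayer(subsample=subsample,
                                        kernel_name="rbf",
                                        out_dim=64,
                                        activation="relu",
                                        kernel_dict={"gamma": 0.01})(x)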