diff --git a/skluc/main/keras_/kernel.py b/skluc/main/keras_/kernel.py index 653be6cc6a7e517558259e069ddcf847fcccc7bc..807a3352e8546beb58462d2db507e4875aee970d 100644 --- a/skluc/main/keras_/kernel.py +++ b/skluc/main/keras_/kernel.py @@ -26,6 +26,23 @@ def keras_chi_square_CPD(args): K = - tf.reduce_sum(quotient_without_nan, axis=2) return K +def chi2_kernel(args): + x = args[0] + y = tf.concat(args[1:], 0) + + x = tf.nn.l2_normalize(x, axis=-1) + y = tf.nn.l2_normalize(y, axis=-1) + # the drawing of the matrix X expanded looks like a wall + wall = tf.expand_dims(x, axis=1) + # the drawing of the matrix Y expanded looks like a floor + floor = tf.expand_dims(y, axis=0) + numerator = tf.square(tf.subtract(wall, floor)) + denominator = tf.add(wall, floor) + 0.001 + quotient = numerator / denominator + quotient_without_nan = quotient #replace_nan(quotient) + + K = - tf.reduce_sum(quotient_without_nan, axis=2) + return tf.tanh(K) if __name__ == '__main__': a = tf.Constant(value=0) \ No newline at end of file