Keras-Tensorflow:在推理时打印张量的值

时间:2018-09-09 21:40:54

标签: python tensorflow keras

我正在尝试在Keras的自定义层中注入一些代码来计算和显示中间变量(在训练期间未使用),该变量在此处定义:https://github.com/sadeepj/crfasrnn_keras/blob/master/src/crfrnn_layer.py,其call函数的外观类似于以下内容(为了清晰起见,我删除了部分代码):

def call(self, inputs):
    unaries = tf.transpose(inputs[0][0, :, :, :], perm=(2, 0, 1))
    rgb = tf.transpose(inputs[1][0, :, :, :], perm=(2, 0, 1))

    ... some code here...

    q_values = unaries

    for i in range(self.num_iterations):
        softmax_out = tf.nn.softmax(q_values, 0)

        ... some code here...

        # Compatibility transform
        pairwise = tf.matmul(self.compatibility_matrix, message_passing)

        # Adding unary potentials
        pairwise = tf.reshape(pairwise, (c, h, w))
        q_values = unaries - pairwise

    return tf.transpose(tf.reshape(q_values, (1, c, h, w)), perm=(0, 2, 3, 1))

我的目标是计算一个名为energy的标量值,并使用以下代码显示(替换上面代码的最后四行):

        # Adding unary potentials
        pairwise = tf.reshape(pairwise, (c, h, w))
        q_values = unaries - pairwise

        # Compute and display the energy at the current iteration
        energy = tf.reduce_sum(tf.multiply(tf.subtract(0.5*pairwise, unaries), q_values))
        energy = keras.backend.print_tensor(energy, message='energy = ') # not working

    return tf.transpose(tf.reshape(q_values, (1, c, h, w)), perm=(0, 2, 3, 1))

使用keras.backend.print_tensor无效(未打印任何内容)。

有人可以告诉我如何显示energy的值吗?

非常感谢您的帮助!

0 个答案:

没有答案