Keras-在损失函数中打印中间张量(tf.Print和K.print_tensor不起作用...)

时间:2018-08-26 04:10:40

标签: python tensorflow keras

我为Keras模型编写了一个相当复杂的损失函数,并且在训练过程中不断返回nan。因此,我需要在训练时打印中间张量。我了解您无法在损失函数中执行K.eval,因为张量未初始化。但是,我尝试了K.print_tensor()和tf.Print(),但都没有用。

我几乎想做这样的事情:

def mean_squared_error(y_true, y_pred):
    print("mean_squared_error")
    loss = K.mean(K.square(y_pred - y_true), axis=-1)
    loss = tf.Print(loss, [loss])
    return loss
model.compile(optimizer=self.optimizer, loss=mean_squared_error)

实际上,我会用我的自定义损失替换mean_squared_error。将显示“ mean_squared_error”,但不会打印我尝试使用张量流打印(也不是Keras打印)的值。我还尝试了与Keras print inside loss function does not work中完全相同的代码,但仍然看不到控制台上打印任何内容。

此外,我还编写了一个单独的文件来测试某些内容。

import tensorflow as tf
import keras.backend as K

input1 = K.constant(1)
input2 = K.constant(2)
input3 = K.constant(3)

node1 = tf.add(input1, input2)
print_output = K.print_tensor(node1)
output = tf.multiply(print_output, input3)

也没有打印任何内容。

我是否使用tensorflow Print和Keras print_tensor错误?还是将结果打印在其他地方?我尝试使用print("test", file=sys.stderr)测试控制台的stderr,并获得正确的输出test

为澄清起见,我知道您可以使用K.eval来使测试代码打印出张量的值,但是由于我无法在损失函数中使用K.eval,因此我需要将tf.Print或K设置为。 print_tensor工作。

提前谢谢!

4 个答案:

答案 0 :(得分:1)

这里的问题是训练代码通常实际上并不取决于损耗张量的值!通常,您无需计算损失的实际值就可以计算损失的梯度,这意味着tensorflow的运行时可以自由地从图中修剪损失的实际执行。

您可以将损失函数包装在tf.contrib.eager.defun装饰器中,这样做的副作用是,即使您的向后传递不需要所有有状态的操作,它们也可以运行。

答案 1 :(得分:0)

如果要运行操作并在不通过会话的情况下打印结果,则必须使用tf.InteractiveSession-请查看详细信息here

因此,如果更改如下,您的测试代码将打印 node1 值:

import tensorflow as tf
import keras.backend as K

input1 = K.constant(1)
input2 = K.constant(2)
input3 = K.constant(3)

node1 = tf.add(input1, input2)
print_output = K.print_tensor(node1)
output = tf.multiply(print_output, input3)
sess = tf.InteractiveSession()
print("node1: ", node1.eval())
sess.close()

答案 2 :(得分:0)

您的代码建立了一个图形:

import tensorflow as tf
import keras.backend as K

input1 = K.constant(1)
input2 = K.constant(2)
input3 = K.constant(3)

node1 = tf.add(input1, input2)
print_output = K.print_tensor(node1)
output = tf.multiply(print_output, input3)

为了运行图形,您需要定义一个Session环境,在其中执行Operation对象并评估Tensor对象:

sess = tf.Session()

要评估张量输出:

sess.run(output)

最后,释放资源:

sess.close()

您的代码仅定义了图形。没有会话,也没有评估操作。

答案 3 :(得分:0)

在TensorFlow 2中,类似于this answer中建议的解决方案,您可以使用@tf.function装饰损失函数。