使用张量流

时间:2017-02-14 18:48:38

标签: debugging tensorflow keras

我正在编写一个自定义目标来训练Keras(带有tensorflow后端)模型,但我需要调试一些中间计算。为简单起见,假设我有:

def custom_loss(y_pred, y_true):
    diff = y_pred - y_true
    return K.square(diff)

我无法找到一种简单的方法来访问,例如,训练期间中间变量diff或其形状。在这个简单的例子中,我知道我可以返回diff来打印它的值,但是我的实际损失更复杂,我不能在没有编译错误的情况下返回中间值。

在Keras中有一种简单的方法来调试中间变量吗?

2 个答案:

答案 0 :(得分:8)

据我所知,这不是Keras解决的问题,因此您必须采用特定于后端的功能。 TheanoTensorFlow都有Print个节点作为标识节点(即它们返回输入节点)并具有打印输入的副作用(或输入的某些张量)

Theano的例子:

diff = y_pred - y_true
diff = theano.printing.Print('shape of diff', attrs=['shape'])(diff)
return K.square(diff)

TensorFlow示例:

diff = y_pred - y_true
diff = tf.Print(diff, [tf.shape(diff)])
return K.square(diff)

请注意,这仅适用于中间值。 Keras希望传递给其他层的张量具有特定属性,例如_keras_shape。后端处理的值,即通过Print,通常没有该属性。要解决此问题,您可以将调试语句包装在Lambda层中,例如。

答案 1 :(得分:0)

在TensorFlow 2中,您现在可以在TensorFlow Keras模型/层/损耗中添加IDE断点,包括使用拟合,评估和预测方法时。但是,您必须在调用model.run_eagerly = True 之后添加model.compile(),以使张量的值在调试器中的断点处可用。例如,

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

def custom_loss(y_pred, y_true):
    diff = y_pred - y_true
    return tf.keras.backend.square(diff)  # Breakpoint in IDE here. =====

class SimpleModel(Model):

    def __init__(self):
        super().__init__()
        self.dense0 = Dense(2)
        self.dense1 = Dense(1)

    def call(self, inputs):
        z = self.dense0(inputs)
        z = self.dense1(z)
        return z

x = tf.convert_to_tensor([[1, 2, 3], [4, 5, 6]], dtype=tf.float32)
y = tf.convert_to_tensor([0, 1], dtype=tf.float32)

model0 = SimpleModel()
model0.run_eagerly = True
model0.compile(optimizer=Adam(), loss=custom_loss)
y0 = model0.fit(x, y, epochs=1)  # Values of diff *not* shown at breakpoint. =====

model1 = SimpleModel()
model1.compile(optimizer=Adam(), loss=custom_loss)
model1.run_eagerly = True
y1 = model1.fit(x, y, epochs=1)  # Values of diff shown at breakpoint. =====

这也可用于调试中间网络层的输出(例如,在SimpleModel的call中添加断点)。

注意:这已在TensorFlow 2.0.0-rc0中进行了测试。