tf.Print()导致渐变错误

时间:2016-01-13 01:27:41

标签: tensorflow gradients

我试图使用tf.Print调试语句来更好地理解来自compute_gradients()的报告渐变和变量的格式,但遇到了意外问题。训练例程和调试例程(gvdebug)如下:

def gvdebug(g, v):
    #g = tf.Print(g,[g],'G: ')
    #v = tf.Print(v,[v],'V: ')
    g2 = tf.zeros_like(g, dtype=tf.float32)
    v2 = tf.zeros_like(v, dtype=tf.float32)
    g2 = g
    v2 = v
    return g2,v2

# Define training operation
def training(loss, global_step, learning_rate=0.1):
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    grads_and_vars = optimizer.compute_gradients(loss)
    gv2 = [gvdebug(gv[0], gv[1]) for gv in grads_and_vars]
    train_op = optimizer.apply_gradients(gv2, global_step=global_step)
    return train_op

此代码工作正常(但不打印),但如果我取消注释gvdebug()中的两个tf.Print行,我会收到apply_gradients的错误消息:' TypeError:变量必须是tf .Variable&#39 ;.我以为tf.Print只是通过了张贴 - 我做错了什么?

1 个答案:

答案 0 :(得分:2)

<强> TL; DR

请勿尝试tf.Print gv[1],因为它是tf.Variable。它就像一个指向变量的指针,该变量在gradient中创建了gv[0]

更多信息

当您运行compute_gradients时,它会返回gradients列表及其对应的tf.Variables

grads_and_vars的每个元素都是Tensortf.Variable。重要的是要注意它是变量的值。

删除v = tf.Print(v,[v],'V: ')

后,您的代码适用于我