TypeError:预期的float32乘法占位符和渐变

时间:2016-06-25 23:00:06

标签: tensorflow

嘿,我正在努力让自己熟悉张量流,并且遇到这个错误,google并没有提供太多...

错误源自' c'占位符乘法。我删除它时错误消失

代码:

x = tf.placeholder(tf.float32)
c = tf.placeholder(tf.float32)
y = x**2
g = tf.gradients(y, x) * c

tf.Session().run(g, {x:[1,1],c:[-1,-1]})

错误:

TypeError: Expected float32, got list containing Tensors of type '_Message' instead.

1 个答案:

答案 0 :(得分:2)

这里的问题是tf.gradients()返回张量的列表(即使它的参数是单个张量......不幸的是,它与其他一些TensorFlow API不一致)。因此,您必须获取返回值的第0个元素:

g = tf.gradients(y, x)[0] * c