如何让以下扫描示例正常工作?我打算用这个例子来测试函数f中的一些代码。
def f(prev_y, curr_y):
fval = tf.nn.softmax(curr_y)
return fval
a = tf.constant([[.1, .25, .3, .2, .15],
[.07, .35, .27, .17, .14]])
c = tf.scan(f, a, initializer=0)
with tf.Session() as sess:
print(sess.run(c))
答案 0 :(得分:1)
您的initializer=0
无效。正如documented:
如果提供
initializer
,则fn的输出必须与initializer
具有相同的结构;并且fn
的第一个参数必须与此结构匹配。
f
的输出与curr_y
的类型和形状相同,第二个参数与0
不匹配。在这种情况下,您需要:
init = tf.constant([0., 0., 0., 0., 0.])
c = tf.scan(f, a, initializer=init)
对于这种特定情况(特别是您的f
),您不会(也可能不应该)使用tf.scan
,因为您的f
只使用一个参数。 tf.map_fn
可以完成这项工作:
c = tf.map_fn(tf.nn.softmax, a)