矩阵上的张量流扫描

时间:2018-04-17 13:47:02

标签: python tensorflow scanning

如何让以下扫描示例正常工作?我打算用这个例子来测试函数f中的一些代码。

def f(prev_y, curr_y):
  fval = tf.nn.softmax(curr_y)
  return fval

a = tf.constant([[.1, .25, .3, .2, .15],
                 [.07, .35, .27, .17, .14]])
c = tf.scan(f, a, initializer=0)
with tf.Session() as sess:
  print(sess.run(c))

1 个答案:

答案 0 :(得分:1)

您的initializer=0无效。正如documented

  

如果提供initializer,则fn的输出必须与initializer具有相同的结构;并且fn的第一个参数必须与此结构匹配。

f的输出与curr_y的类型和形状相同,第二个参数与0不匹配。在这种情况下,您需要:

init = tf.constant([0., 0., 0., 0., 0.])
c = tf.scan(f, a, initializer=init)

对于这种特定情况(特别是您的f),您不会(也可能不应该)使用tf.scan,因为您的f只使用一个参数。 tf.map_fn可以完成这项工作:

c = tf.map_fn(tf.nn.softmax, a)