相当于在tensorflow中枚举以使用tf.map_fn中的索引

时间:2018-06-01 10:16:54

标签: python tensorflow

我有一个张量,我使用tf.map_fn逐行处理。现在我想将索引作为参数包含在我传递给tf.map_fn的函数中。在numpy中我可以使用enumerate获取该信息并将其传递给我的lambda函数。这是numpy中的一个例子,我将0添加到第一行,1添加到第二行,依此类推:

a = np.array([[2, 1], [4, 2], [-1, 2]])

def some_function(x, i):
    return x + i

res = map(lambda (i, row): some_function(row, i), enumerate(a))
print(res)

> [array([2, 1]), array([5, 3]), array([1, 4])]

我还没有找到相当于张量流中enumerate的等价物,但我不知道如何在张量流中获得相同的结果。有人知道使用什么来使它在tensorflow中工作吗?这是一个示例代码,我在a的每一行添加1:

import tensorflow as tf

a = tf.constant([[2, 1], [4, 2], [-1, 2]])

with tf.Session() as sess:
    res = tf.map_fn(lambda row: some_function(row, 1), a)
    print(res.eval())

> [[3 2]
   [5 3]
   [0 3]]

感谢任何可以帮助我解决这个问题的人。

1 个答案:

答案 0 :(得分:3)

tf.map_fn()可以有很多输入/输出。因此,您可以使用tf.range()来构建行索引的张量并使用它:

import tensorflow as tf

def some_function(x, i):
    return x + i

a = tf.constant([[2, 1], [4, 2], [-1, 2]])
a_rows = tf.expand_dims(tf.range(tf.shape(a)[0], dtype=tf.int32), 1)

res, _ = tf.map_fn(lambda x: (some_function(x[0], x[1]), x[1]), 
                   (a, a_rows), dtype=(tf.int32, tf.int32))

with tf.Session() as sess:
    print(res.eval())
    # [[2 1]
    #  [5 3]
    #  [1 4]]

注意:在许多情况下,“一行一行地处理矩阵”可以一次完成,例如通过广播,而不是使用循环:

import tensorflow as tf

a = tf.constant([[2, 1], [4, 2], [-1, 2]])
a_rows = tf.expand_dims(tf.range(tf.shape(a)[0], dtype=tf.int32), 1)
res = a + a_rows

with tf.Session() as sess:
    print(res.eval())
    # [[2 1]
    #  [5 3]
    #  [1 4]]