我有一个张量,我使用tf.map_fn
逐行处理。现在我想将索引作为参数包含在我传递给tf.map_fn
的函数中。在numpy中我可以使用enumerate
获取该信息并将其传递给我的lambda函数。这是numpy中的一个例子,我将0添加到第一行,1添加到第二行,依此类推:
a = np.array([[2, 1], [4, 2], [-1, 2]])
def some_function(x, i):
return x + i
res = map(lambda (i, row): some_function(row, i), enumerate(a))
print(res)
> [array([2, 1]), array([5, 3]), array([1, 4])]
我还没有找到相当于张量流中enumerate
的等价物,但我不知道如何在张量流中获得相同的结果。有人知道使用什么来使它在tensorflow中工作吗?这是一个示例代码,我在a
的每一行添加1:
import tensorflow as tf
a = tf.constant([[2, 1], [4, 2], [-1, 2]])
with tf.Session() as sess:
res = tf.map_fn(lambda row: some_function(row, 1), a)
print(res.eval())
> [[3 2]
[5 3]
[0 3]]
感谢任何可以帮助我解决这个问题的人。
答案 0 :(得分:3)
tf.map_fn()
可以有很多输入/输出。因此,您可以使用tf.range()
来构建行索引的张量并使用它:
import tensorflow as tf
def some_function(x, i):
return x + i
a = tf.constant([[2, 1], [4, 2], [-1, 2]])
a_rows = tf.expand_dims(tf.range(tf.shape(a)[0], dtype=tf.int32), 1)
res, _ = tf.map_fn(lambda x: (some_function(x[0], x[1]), x[1]),
(a, a_rows), dtype=(tf.int32, tf.int32))
with tf.Session() as sess:
print(res.eval())
# [[2 1]
# [5 3]
# [1 4]]
注意:在许多情况下,“一行一行地处理矩阵”可以一次完成,例如通过广播,而不是使用循环:
import tensorflow as tf
a = tf.constant([[2, 1], [4, 2], [-1, 2]])
a_rows = tf.expand_dims(tf.range(tf.shape(a)[0], dtype=tf.int32), 1)
res = a + a_rows
with tf.Session() as sess:
print(res.eval())
# [[2 1]
# [5 3]
# [1 4]]