Tensorflow:如何使用像numpy这样的2D索引来索引张量

时间:2017-03-31 16:10:16

标签: python python-3.x numpy indexing tensorflow

我想在Tensorflow中执行以下numpy代码:

input = np.array([[1,2,3]
                  [4,5,6]
                  [7,8,9]])
index1 = [0,1,2]
index2 = [2,2,0]
output = input[index1, index2]
>> output
[3,6,7]

给出如下输入:

input = tf.constant([[1, 2, 3],
                     [4, 5, 6],
                     [7, 8, 9]])

我尝试了以下内容,但似乎有点过头了:

index3 = tf.range(0, input.get_shape()[0])*input.get_shape()[1] + index2
output = tf.gather(tf.reshape(input, [-1]), index3)
sess = tf.Session()
sess.run(output)
>> [3,6,7]

这只是因为我的第一个索引很方便[0,1,2]但是对于[0,0,2]来说是不可行的(除了看起来真的很长很丑)。

你会有更简单的语法,更多的张力/ pythonic?

2 个答案:

答案 0 :(得分:4)

您可以使用tf.gather_nd (tf.gather_nd official doc)执行此操作,如下所示:

import tensorflow as tf
inp = tf.constant([[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]])
res=tf.gather_nd(inp,list(zip([0,1,2],[2,2,0])))
sess = tf.Session()
sess.run(res)

结果为array([3, 6, 7])

答案 1 :(得分:3)

如何使用tf.gather_nd

In [61]: input = tf.constant([[1, 2, 3],
    ...:                      [4, 5, 6],
    ...:                      [7, 8, 9]])

In [63]: row_idx = tf.constant([0, 1, 2])
In [64]: col_idx = tf.constant([2, 2, 0])
In [65]: coords = tf.transpose(tf.pack([row_idx, col_idx]))

In [67]: sess = tf.InteractiveSession()

In [68]: tf.gather_nd(input, coords).eval()
Out[68]: array([3, 6, 7], dtype=int32)