我想在Tensorflow中执行以下numpy代码:
input = np.array([[1,2,3]
[4,5,6]
[7,8,9]])
index1 = [0,1,2]
index2 = [2,2,0]
output = input[index1, index2]
>> output
[3,6,7]
给出如下输入:
input = tf.constant([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
我尝试了以下内容,但似乎有点过头了:
index3 = tf.range(0, input.get_shape()[0])*input.get_shape()[1] + index2
output = tf.gather(tf.reshape(input, [-1]), index3)
sess = tf.Session()
sess.run(output)
>> [3,6,7]
这只是因为我的第一个索引很方便[0,1,2]但是对于[0,0,2]来说是不可行的(除了看起来真的很长很丑)。
你会有更简单的语法,更多的张力/ pythonic?
答案 0 :(得分:4)
您可以使用tf.gather_nd
(tf.gather_nd official doc)执行此操作,如下所示:
import tensorflow as tf
inp = tf.constant([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
res=tf.gather_nd(inp,list(zip([0,1,2],[2,2,0])))
sess = tf.Session()
sess.run(res)
结果为array([3, 6, 7])
答案 1 :(得分:3)
如何使用tf.gather_nd?
In [61]: input = tf.constant([[1, 2, 3],
...: [4, 5, 6],
...: [7, 8, 9]])
In [63]: row_idx = tf.constant([0, 1, 2])
In [64]: col_idx = tf.constant([2, 2, 0])
In [65]: coords = tf.transpose(tf.pack([row_idx, col_idx]))
In [67]: sess = tf.InteractiveSession()
In [68]: tf.gather_nd(input, coords).eval()
Out[68]: array([3, 6, 7], dtype=int32)