TensorFlow:使用张量来索引另一个张量

时间:2016-03-07 11:31:53

标签: python numpy tensorflow

我有一个关于如何在TensorFlow中进行索引的基本问题。

numpy:

x = np.asarray([1,2,3,3,2,5,6,7,1,3])
e = np.asarray([0,1,0,1,1,1,0,1])
#numpy 
print x * e[x]

我可以

[1 0 3 3 0 5 0 7 1 3]

如何在TensorFlow中执行此操作?

x = np.asarray([1,2,3,3,2,5,6,7,1,3])
e = np.asarray([0,1,0,1,1,1,0,1])
x_t = tf.constant(x)
e_t = tf.constant(e)
with tf.Session():
    ????

谢谢!

1 个答案:

答案 0 :(得分:29)

幸运的是,tf.gather()在TensorFlow中支持您提出的确切案例:

result = x_t * tf.gather(e_t, x_t)

with tf.Session() as sess:
    print sess.run(result)  # ==> 'array([1, 0, 3, 3, 0, 5, 0, 7, 1, 3])'

tf.gather() op的功能不如NumPy's advanced indexing:它只支持在其第0维上提取张量的完整切片。已经请求支持更一般的索引,并且正在this GitHub issue中跟踪。