访问张量中各个元素的更好方法

时间:2019-01-31 07:12:05

标签: python python-3.x tensorflow

我正在尝试使用张量a中定义的索引来访问张量b的元素。

a=tf.constant([[1,2,3,4],[5,6,7,8]])
b=tf.constant([0,1,1,0])

我希望输出为

out = [1 6 7 4]

我尝试过的事情:

out=[]
for i in range(a.shape[1]):
    out.append(a[b[i],i])

out=tf.stack(out) #[1 6 7 4]

这给出了正确的输出,但是我正在寻找一种更好,更紧凑的方式来实现。

a的形状类似于(2,None)时,我的逻辑也不起作用,因为我无法使用range(a.shape[1])进行迭代,如果答案也包括这种情况,这将对我有帮助

谢谢

1 个答案:

答案 0 :(得分:2)

您可以使用tf.one_hot()tf.boolean_mask()

import tensorflow as tf
import numpy as np

a_tf = tf.placeholder(shape=(2,None),dtype=tf.int32)
b_tf = tf.placeholder(shape=(None,),dtype=tf.int32)

index = tf.one_hot(b_tf,a_tf.shape[0])
out = tf.boolean_mask(tf.transpose(a_tf),index)

a=np.array([[1,2,3,4],[5,6,7,8]])
b=np.array([0,1,1,0])
with tf.Session() as sess:
    print(sess.run(out,feed_dict={a_tf:a,b_tf:b}))

# print
[1 6 7 4]