在Tensorflow中访问张量中的条件索引

时间:2018-08-22 10:48:00

标签: python tensorflow

假设我有一个张量X,其秩为2,第一个秩对应于某些样本x的批量大小,例如维数为K。很容易访问其中的第k个元素所有样本:X[1:batch_size,k]。但是,假设我需要为所有i访问x_i的第k_i个元素。例如,如果我有k_list = [1, 2, ..., 2],那么我知道访问所有i的x_i的第k_i个元素的唯一方法是

out=[X[i,k_list[i]] for all i in range(len(k_list))]

问题是,这使我的代码真正变慢了。反正我们可以优化此代码吗?

注意*:我实际上有k_list作为占位符。 np.shape(X)=(batch_size,K)np.shape(k_list)=(batch_size,)np.maximum(k_list)=K-1, np.minimum(k_list)=0np.shape(out)=(batch_size,1)

的大小

1 个答案:

答案 0 :(得分:1)

如果我正确理解了您的问题,那么您正在寻找gather_nd

i0 = tf.range(batch_size, dtype=tf.int32)
indices = tf.stack((i0, k_list), axis=1)
out = tf.gather_nd(X, indices)