假设我有一个张量X
,其秩为2,第一个秩对应于某些样本x
的批量大小,例如维数为K。很容易访问其中的第k个元素所有样本:X[1:batch_size,k]
。但是,假设我需要为所有i访问x_i的第k_i个元素。例如,如果我有k_list = [1, 2, ..., 2]
,那么我知道访问所有i的x_i的第k_i个元素的唯一方法是
out=[X[i,k_list[i]] for all i in range(len(k_list))]
问题是,这使我的代码真正变慢了。反正我们可以优化此代码吗?
注意*:我实际上有k_list
作为占位符。 np.shape(X)=(batch_size,K)
,np.shape(k_list)=(batch_size,)
,np.maximum(k_list)=K-1, np.minimum(k_list)=0
和np.shape(out)=(batch_size,1)
答案 0 :(得分:1)
如果我正确理解了您的问题,那么您正在寻找gather_nd
i0 = tf.range(batch_size, dtype=tf.int32)
indices = tf.stack((i0, k_list), axis=1)
out = tf.gather_nd(X, indices)