张量流多维索引

时间:2020-04-20 19:31:02

标签: python numpy tensorflow

我有

  • 4维浮点张量y

  • 3维整数张量y_index,包含要提取的y的4维索引

我想用numpy的3个for循环做的很简单:

y = np.random.randint(100,size=(5,10,20,3))
y_index= np.random.randint(3,size=(5,10,20))
y_slice = np.zeros_like(y_index)
for i in range(y.shape[0]):
    for j in range(y.shape[1]):
        for k in range(y.shape[2]):
            y_slice[i,j,k] = y[i,j,k,y_index[i,j,k]]
y_slice

如何在tensorflow中有效地做到这一点?我想我需要使用tf.gether_nd ...

1 个答案:

答案 0 :(得分:1)

您可以执行以下操作。基本上,首先将y的最后一个维度以外的所有维度展平,然后为展平的y创建一个索引。完成索引后,将其重塑为正确的形状。

y = tf.constant(np.random.normal(size=(5,10,20,3)), dtype='float32')
y_index = tf.constant(np.random.randint(3, size=(5,10,20)), dtype='int32')
# Creating an index like [(0,y_index[0]), (1, y_index[1]), ...]
inds = tf.stack([tf.range(5*10*20),tf.reshape(y_index,[-1])],axis=1)

y_slice = tf.reshape(tf.gather_nd(tf.reshape(y,[-1,3]),inds),[5,10,20])