如何在TensorFlow中沿特定轴选择张量的元素

时间:2019-09-22 20:01:51

标签: tensorflow

我有一个(2, 3, 4)形状的张量t = tf.random.normal((2, 3, 4)) <tf.Tensor: id=55, shape=(2, 3, 4), dtype=float32, numpy= array([[[-0.86664855, -0.32786712, -0.9517335 , 0.989722 ], [-0.25011402, -0.35941386, -1.0808105 , 0.60205466], [ 0.07523973, -0.6512919 , 1.3695312 , -1.5043781 ]], [[ 0.33990988, -0.17364176, 0.72955394, -0.7119293 ], [ 0.4013214 , 0.5653289 , 1.4327284 , 1.2687784 ], [-1.1986154 , 1.3783301 , 1.714094 , 0.49866664]]], dtype=float32)>

idx

和一组大小为(2, 3)的索引t,其值沿idx = tf.convert_to_tensor(np.random.randint(4, size=(2, 3))) <tf.Tensor: id=56, shape=(2, 3), dtype=int64, numpy= array([[2, 2, 3], [0, 3, 1]])> 的最后一个维度进行索引

t

如何在idx指定的索引处沿(2, 3)的最后维度提取元素?结果应该是形状为<tf.Tensor: id=57, shape=(2, 3), dtype=int64, numpy= array([[-0.9517335, -1.0808105, -1.5043781], [0.33990988, 1.2687784, 1.3783301]])> 的以下张量。

t[:, :, idx]  # error
t[..., idx]   # error

我一直在尝试定期索引编制失败

tf.gather(t, idx, axis=2)  # has shape (2, 3, 2, 3)
tf.gather_nd(t, idx)       # has shape (2, )

tf.gather / tf.gather_nd

{{1}}

似乎都没有做到这一点。

1 个答案:

答案 0 :(得分:0)

再次考虑您要实现的目标。您要为第一轴和第二轴提取的元素的索引是什么?从您的示例中,您似乎正在考虑展平前两个维度,使t为(6,4),并提取其第一维索引为0:6且第二维索引由{{ 1}}。

要实现此目的,必须实际指定所有尺寸的索引。我们可以先将idx重塑为2D:

t

现在,我们将指定第一个轴的索引:

t_2d=tf.reshape(t,[-1,tf.shape(t)[-1]])

<tf.Tensor: id=55, shape=(6, 4), dtype=float32, numpy=
array([[-0.86664855, -0.32786712, -0.9517335 ,  0.989722  ],
       [-0.25011402, -0.35941386, -1.0808105 ,  0.60205466],
       [ 0.07523973, -0.6512919 ,  1.3695312 , -1.5043781 ],
       [ 0.33990988, -0.17364176,  0.72955394, -0.7119293 ],
       [ 0.4013214 ,  0.5653289 ,  1.4327284 ,  1.2687784 ],
       [-1.1986154 ,  1.3783301 ,  1.714094  ,  0.49866664]],
      dtype=float32)>

idx_0=tf.reshape(tf.range(t_2d.shape[0]),idx.shape) <tf.Tensor: id=62, shape=(2, 3), dtype=int32, numpy= array([[0, 1, 2], [3, 4, 5]], dtype=int32)> 的期望加入第一和第二轴的索引:

tf.gather_nd

最后:

indices=tf.stack([idx_0,idx],axis=-1)

<tf.Tensor: id=64, shape=(2, 3, 2), dtype=int32, numpy=
array([[[0, 2],
        [1, 2],
        [2, 3]],

       [[3, 0],
        [4, 3],
        [5, 1]]], dtype=int32)>