我有一个(2, 3, 4)
形状的张量t = tf.random.normal((2, 3, 4))
<tf.Tensor: id=55, shape=(2, 3, 4), dtype=float32, numpy=
array([[[-0.86664855, -0.32786712, -0.9517335 , 0.989722 ],
[-0.25011402, -0.35941386, -1.0808105 , 0.60205466],
[ 0.07523973, -0.6512919 , 1.3695312 , -1.5043781 ]],
[[ 0.33990988, -0.17364176, 0.72955394, -0.7119293 ],
[ 0.4013214 , 0.5653289 , 1.4327284 , 1.2687784 ],
[-1.1986154 , 1.3783301 , 1.714094 , 0.49866664]]],
dtype=float32)>
idx
和一组大小为(2, 3)
的索引t
,其值沿idx = tf.convert_to_tensor(np.random.randint(4, size=(2, 3)))
<tf.Tensor: id=56, shape=(2, 3), dtype=int64, numpy=
array([[2, 2, 3],
[0, 3, 1]])>
的最后一个维度进行索引
t
如何在idx
指定的索引处沿(2, 3)
的最后维度提取元素?结果应该是形状为<tf.Tensor: id=57, shape=(2, 3), dtype=int64, numpy=
array([[-0.9517335, -1.0808105, -1.5043781],
[0.33990988, 1.2687784, 1.3783301]])>
的以下张量。
t[:, :, idx] # error
t[..., idx] # error
我一直在尝试定期索引编制失败
tf.gather(t, idx, axis=2) # has shape (2, 3, 2, 3)
tf.gather_nd(t, idx) # has shape (2, )
{{1}}
似乎都没有做到这一点。
答案 0 :(得分:0)
再次考虑您要实现的目标。您要为第一轴和第二轴提取的元素的索引是什么?从您的示例中,您似乎正在考虑展平前两个维度,使t
为(6,4),并提取其第一维索引为0:6且第二维索引由{{ 1}}。
要实现此目的,必须实际指定所有尺寸的索引。我们可以先将idx
重塑为2D:
t
现在,我们将指定第一个轴的索引:
t_2d=tf.reshape(t,[-1,tf.shape(t)[-1]])
<tf.Tensor: id=55, shape=(6, 4), dtype=float32, numpy=
array([[-0.86664855, -0.32786712, -0.9517335 , 0.989722 ],
[-0.25011402, -0.35941386, -1.0808105 , 0.60205466],
[ 0.07523973, -0.6512919 , 1.3695312 , -1.5043781 ],
[ 0.33990988, -0.17364176, 0.72955394, -0.7119293 ],
[ 0.4013214 , 0.5653289 , 1.4327284 , 1.2687784 ],
[-1.1986154 , 1.3783301 , 1.714094 , 0.49866664]],
dtype=float32)>
按idx_0=tf.reshape(tf.range(t_2d.shape[0]),idx.shape)
<tf.Tensor: id=62, shape=(2, 3), dtype=int32, numpy=
array([[0, 1, 2],
[3, 4, 5]], dtype=int32)>
的期望加入第一和第二轴的索引:
tf.gather_nd
最后:
indices=tf.stack([idx_0,idx],axis=-1)
<tf.Tensor: id=64, shape=(2, 3, 2), dtype=int32, numpy=
array([[[0, 2],
[1, 2],
[2, 3]],
[[3, 0],
[4, 3],
[5, 1]]], dtype=int32)>